add savetxt implementation

This commit is contained in:
Zoltán Vörös 2022-01-18 21:20:57 +01:00
parent b756c36f73
commit 209e7ff251
4 changed files with 102 additions and 3 deletions

View file

@ -17,6 +17,7 @@
#include "py/stream.h"
#include "../../ndarray.h"
#include "../../ulab_tools.h"
#include "io.h"
#define ULAB_IO_BUFFER_SIZE 128
@ -243,7 +244,7 @@ static mp_obj_t io_save(mp_obj_t file, mp_obj_t ndarray_) {
// test for endianness
uint16_t x = 1;
int8_t native_endiannes = (x >> 8) == 1 ? '>' : '<';
int8_t native_endianness = (x >> 8) == 1 ? '>' : '<';
mp_obj_t open_args[2] = {
file,
@ -259,7 +260,7 @@ static mp_obj_t io_save(mp_obj_t file, mp_obj_t ndarray_) {
memcpy(buffer, "\x93NUMPY\x01\x00\x76\x00{'descr': '", 21);
offset += 21;
buffer[offset] = native_endiannes;
buffer[offset] = native_endianness;
if((ndarray->dtype == NDARRAY_UINT8) || (ndarray->dtype == NDARRAY_INT8)) {
// for single-byte data, the endianness doesn't matter
buffer[offset] = '|';
@ -373,3 +374,93 @@ static mp_obj_t io_save(mp_obj_t file, mp_obj_t ndarray_) {
MP_DEFINE_CONST_FUN_OBJ_2(io_save_obj, io_save);
#endif /* ULAB_NUMPY_HAS_SAVE */
#if ULAB_NUMPY_HAS_SAVETXT
static int8_t io_format_number(ndarray_obj_t *ndarray, mp_float_t (*func)(void *), uint8_t *array, char *buffer, char delimiter) {
#if ULAB_SUPPORTS_COMPLEX
if(ndarray->dtype == NDARRAY_COMPLEX) {
mp_float_t real = func(array);
mp_float_t imag = func(array + ndarray->itemsize / 2);
if(imag >= MICROPY_FLOAT_CONST(0.0)) {
return sprintf(buffer, "%.8e+%.8ej%c", real, imag, delimiter);
} else {
return sprintf(buffer, "%.8e-%.8ej%c", real, -imag, delimiter);
}
}
#endif
return sprintf(buffer, "%.8e%c", func(array), delimiter);
}
static mp_obj_t io_savetxt(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = mp_const_none } },
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = mp_const_none } },
// { MP_QSTR_header, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = } },
// { MP_QSTR_footer, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = } },
// { MP_QSTR_comments, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = } },
};
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
if(!mp_obj_is_str(args[0].u_obj) || !mp_obj_is_type(args[1].u_obj, &ulab_ndarray_type)) {
mp_raise_TypeError(translate("wrong input type"));
}
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(args[1].u_obj);
#if ULAB_MAX_DIMS > 2
if(ndarray->ndim > 2) {
mp_raise_ValueError(translate("array has too many dimensions"));
}
#endif
mp_obj_t open_args[2] = {
args[0].u_obj,
MP_OBJ_NEW_QSTR(MP_QSTR_w)
};
mp_obj_t stream = mp_builtin_open(2, open_args, (mp_map_t *)&mp_const_empty_map);
const mp_stream_p_t *stream_p = mp_get_stream(stream);
char *buffer = m_new(char, ULAB_IO_BUFFER_SIZE);
int error;
uint8_t *array = (uint8_t *)ndarray->array;
mp_float_t (*func)(void *) = ndarray_get_float_function(ndarray->dtype);
char delimiter = '\n';
if(ndarray->ndim > 1) {
delimiter = ' ';
}
#if ULAB_MAX_DIMS > 1
size_t k = 0;
do {
#endif
size_t l = 0;
do {
int8_t chars = io_format_number(ndarray, func, array, buffer, delimiter);
if(chars > 0) {
stream_p->write(stream, buffer, chars, &error);
}
array += ndarray->strides[ULAB_MAX_DIMS - 1];
l++;
} while(l < ndarray->shape[ULAB_MAX_DIMS - 1]);
#if ULAB_MAX_DIMS > 1
if(ndarray->ndim > 1) {
buffer[0] = '\n';
stream_p->write(stream, buffer, 1, &error);
}
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS-1];
array += ndarray->strides[ULAB_MAX_DIMS - 2];
k++;
} while(k < ndarray->shape[ULAB_MAX_DIMS - 2]);
#endif
stream_p->ioctl(stream, MP_STREAM_CLOSE, 0, &error);
return mp_const_none;
}
MP_DEFINE_CONST_FUN_OBJ_KW(io_savetxt_obj, 2, io_savetxt);
#endif /* ULAB_NUMPY_HAS_SAVETXT */

View file

@ -11,7 +11,8 @@
#ifndef _ULAB_IO_
#define _ULAB_IO_
MP_DECLARE_CONST_FUN_OBJ_2(io_save_obj);
MP_DECLARE_CONST_FUN_OBJ_1(io_load_obj);
MP_DECLARE_CONST_FUN_OBJ_2(io_save_obj);
MP_DECLARE_CONST_FUN_OBJ_KW(io_savetxt_obj);
#endif

View file

@ -278,6 +278,9 @@ static const mp_rom_map_elem_t ulab_numpy_globals_table[] = {
#if ULAB_NUMPY_HAS_SAVE
{ MP_OBJ_NEW_QSTR(MP_QSTR_save), (mp_obj_t)&io_save_obj },
#endif
#if ULAB_NUMPY_HAS_SAVETXT
{ MP_OBJ_NEW_QSTR(MP_QSTR_savetxt), (mp_obj_t)&io_savetxt_obj },
#endif
#if ULAB_NUMPY_HAS_SIZE
{ MP_OBJ_NEW_QSTR(MP_QSTR_size), (mp_obj_t)&transform_size_obj },
#endif

View file

@ -494,6 +494,10 @@
#define ULAB_NUMPY_HAS_SAVE (1)
#endif
#ifndef ULAB_NUMPY_HAS_SAVETXT
#define ULAB_NUMPY_HAS_SAVETXT (1)
#endif
#ifndef ULAB_NUMPY_HAS_SIZE
#define ULAB_NUMPY_HAS_SIZE (1)
#endif