fix array initialisation from complex array

This commit is contained in:
Zoltán Vörös 2022-01-01 09:20:00 +01:00
parent 4855baa8cc
commit c11dac322d
4 changed files with 83 additions and 70 deletions

View file

@ -741,6 +741,85 @@ ndarray_obj_t *ndarray_copy_view(ndarray_obj_t *source) {
return ndarray;
}
ndarray_obj_t *ndarray_copy_view_convert_type(ndarray_obj_t *source, uint8_t dtype) {
// creates a copy, similar to ndarray_copy_view, but it also converts the dtype, if necessary
if(dtype == source->dtype) {
return ndarray_copy_view(source);
}
ndarray_obj_t *ndarray = ndarray_new_dense_ndarray(source->ndim, source->shape, dtype);
uint8_t *sarray = (uint8_t *)source->array;
uint8_t *array = (uint8_t *)ndarray->array;
#if ULAB_SUPPORTS_COMPLEX
uint8_t complex_size = 2 * sizeof(mp_float_t);
#endif
#if ULAB_MAX_DIMS > 3
size_t i = 0;
do {
#endif
#if ULAB_MAX_DIMS > 2
size_t j = 0;
do {
#endif
#if ULAB_MAX_DIMS > 1
size_t k = 0;
do {
#endif
size_t l = 0;
do {
mp_obj_t item;
#if ULAB_SUPPORTS_COMPLEX
if(source->dtype == NDARRAY_COMPLEX) {
if(dtype != NDARRAY_COMPLEX) {
mp_raise_TypeError(translate("cannot convert complex type"));
} else {
memcpy(array, sarray, complex_size);
}
} else {
#endif
if((source->dtype == NDARRAY_FLOAT) && (dtype != NDARRAY_FLOAT)) {
// floats must be treated separately, because they can't directly be converted to integer types
mp_float_t f = ndarray_get_float_value(sarray, source->dtype);
item = mp_obj_new_int((int32_t)MICROPY_FLOAT_C_FUN(floor)(f));
} else {
item = mp_binary_get_val_array(source->dtype, sarray, 0);
}
#if ULAB_SUPPORTS_COMPLEX
if(dtype == NDARRAY_COMPLEX) {
ndarray_set_value(NDARRAY_FLOAT, array, 0, item);
} else {
ndarray_set_value(dtype, array, 0, item);
}
}
#else
ndarray_set_value(dtype, array, 0, item);
#endif
array += ndarray->itemsize;
sarray += source->strides[ULAB_MAX_DIMS - 1];
l++;
} while(l < source->shape[ULAB_MAX_DIMS - 1]);
#if ULAB_MAX_DIMS > 1
sarray -= source->strides[ULAB_MAX_DIMS - 1] * source->shape[ULAB_MAX_DIMS-1];
sarray += source->strides[ULAB_MAX_DIMS - 2];
k++;
} while(k < source->shape[ULAB_MAX_DIMS - 2]);
#endif
#if ULAB_MAX_DIMS > 2
sarray -= source->strides[ULAB_MAX_DIMS - 2] * source->shape[ULAB_MAX_DIMS-2];
sarray += source->strides[ULAB_MAX_DIMS - 3];
j++;
} while(j < source->shape[ULAB_MAX_DIMS - 3]);
#endif
#if ULAB_MAX_DIMS > 3
sarray -= source->strides[ULAB_MAX_DIMS - 3] * source->shape[ULAB_MAX_DIMS-3];
sarray += source->strides[ULAB_MAX_DIMS - 4];
i++;
} while(i < source->shape[ULAB_MAX_DIMS - 4]);
#endif
return MP_OBJ_FROM_PTR(ndarray);
}
#if NDARRAY_HAS_BYTESWAP
mp_obj_t ndarray_byteswap(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
// changes the endiannes of an array
@ -952,75 +1031,7 @@ STATIC mp_obj_t ndarray_make_new_core(const mp_obj_type_t *type, size_t n_args,
if(mp_obj_is_type(args[0], &ulab_ndarray_type)) {
ndarray_obj_t *source = MP_OBJ_TO_PTR(args[0]);
if(dtype == source->dtype) {
return ndarray_copy_view(source);
}
ndarray_obj_t *target = ndarray_new_dense_ndarray(source->ndim, source->shape, dtype);
uint8_t *sarray = (uint8_t *)source->array;
uint8_t *tarray = (uint8_t *)target->array;
#if ULAB_SUPPORTS_COMPLEX
uint8_t complex_size = 2 * sizeof(mp_float_t);
#endif
#if ULAB_MAX_DIMS > 3
size_t i = 0;
do {
#endif
#if ULAB_MAX_DIMS > 2
size_t j = 0;
do {
#endif
#if ULAB_MAX_DIMS > 1
size_t k = 0;
do {
#endif
size_t l = 0;
do {
mp_obj_t item;
#if ULAB_SUPPORTS_COMPLEX
if(source->dtype == NDARRAY_COMPLEX) {
if(dtype != NDARRAY_COMPLEX) {
mp_raise_TypeError(translate("cannot convert complex type"));
} else {
memcpy(tarray, sarray, complex_size);
}
} else {
#endif
if((source->dtype == NDARRAY_FLOAT) && (dtype != NDARRAY_FLOAT)) {
// floats must be treated separately, because they can't directly be converted to integer types
mp_float_t f = ndarray_get_float_value(sarray, source->dtype);
item = mp_obj_new_int((int32_t)MICROPY_FLOAT_C_FUN(floor)(f));
} else {
item = mp_binary_get_val_array(source->dtype, sarray, 0);
}
ndarray_set_value(dtype, tarray, 0, item);
#if ULAB_SUPPORTS_COMPLEX
}
#endif
tarray += target->itemsize;
sarray += source->strides[ULAB_MAX_DIMS - 1];
l++;
} while(l < source->shape[ULAB_MAX_DIMS - 1]);
#if ULAB_MAX_DIMS > 1
sarray -= source->strides[ULAB_MAX_DIMS - 1] * source->shape[ULAB_MAX_DIMS-1];
sarray += source->strides[ULAB_MAX_DIMS - 2];
k++;
} while(k < source->shape[ULAB_MAX_DIMS - 2]);
#endif
#if ULAB_MAX_DIMS > 2
sarray -= source->strides[ULAB_MAX_DIMS - 2] * source->shape[ULAB_MAX_DIMS-2];
sarray += source->strides[ULAB_MAX_DIMS - 3];
j++;
} while(j < source->shape[ULAB_MAX_DIMS - 3]);
#endif
#if ULAB_MAX_DIMS > 3
sarray -= source->strides[ULAB_MAX_DIMS - 3] * source->shape[ULAB_MAX_DIMS-3];
sarray += source->strides[ULAB_MAX_DIMS - 4];
i++;
} while(i < source->shape[ULAB_MAX_DIMS - 4]);
#endif
return MP_OBJ_FROM_PTR(target);
return MP_OBJ_FROM_PTR(ndarray_copy_view_convert_type(source, dtype));
} else {
// assume that the input is an iterable
return MP_OBJ_FROM_PTR(ndarray_from_iterable(args[0], dtype));

View file

@ -143,6 +143,7 @@ ndarray_obj_t *ndarray_new_linear_array(size_t , uint8_t );
ndarray_obj_t *ndarray_new_view(ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t );
bool ndarray_is_dense(ndarray_obj_t *);
ndarray_obj_t *ndarray_copy_view(ndarray_obj_t *);
ndarray_obj_t *ndarray_copy_view_convert_type(ndarray_obj_t *, uint8_t );
void ndarray_copy_array(ndarray_obj_t *, ndarray_obj_t *, uint8_t );
MP_DECLARE_CONST_FUN_OBJ_KW(ndarray_array_constructor_obj);

View file

@ -6,7 +6,7 @@
*
* The MIT License (MIT)
*
* Copyright (c) 2021 Zoltán Vörös
* Copyright (c) 2021-2022 Zoltán Vörös
*/
#include <math.h>

View file

@ -15,6 +15,7 @@
MP_DECLARE_CONST_FUN_OBJ_1(carray_real_obj);
MP_DECLARE_CONST_FUN_OBJ_1(carray_imag_obj);
MP_DECLARE_CONST_FUN_OBJ_1(carray_conjugate_obj);
MP_DECLARE_CONST_FUN_OBJ_1(carray_sort_complex_obj);
mp_obj_t carray_abs(ndarray_obj_t *, ndarray_obj_t *);
mp_obj_t carray_binary_add(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);