fix array initialisation from complex array
This commit is contained in:
parent
4855baa8cc
commit
c11dac322d
4 changed files with 83 additions and 70 deletions
149
code/ndarray.c
149
code/ndarray.c
|
|
@ -741,6 +741,85 @@ ndarray_obj_t *ndarray_copy_view(ndarray_obj_t *source) {
|
|||
return ndarray;
|
||||
}
|
||||
|
||||
ndarray_obj_t *ndarray_copy_view_convert_type(ndarray_obj_t *source, uint8_t dtype) {
|
||||
// creates a copy, similar to ndarray_copy_view, but it also converts the dtype, if necessary
|
||||
if(dtype == source->dtype) {
|
||||
return ndarray_copy_view(source);
|
||||
}
|
||||
ndarray_obj_t *ndarray = ndarray_new_dense_ndarray(source->ndim, source->shape, dtype);
|
||||
uint8_t *sarray = (uint8_t *)source->array;
|
||||
uint8_t *array = (uint8_t *)ndarray->array;
|
||||
|
||||
#if ULAB_SUPPORTS_COMPLEX
|
||||
uint8_t complex_size = 2 * sizeof(mp_float_t);
|
||||
#endif
|
||||
|
||||
#if ULAB_MAX_DIMS > 3
|
||||
size_t i = 0;
|
||||
do {
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 2
|
||||
size_t j = 0;
|
||||
do {
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
size_t k = 0;
|
||||
do {
|
||||
#endif
|
||||
size_t l = 0;
|
||||
do {
|
||||
mp_obj_t item;
|
||||
#if ULAB_SUPPORTS_COMPLEX
|
||||
if(source->dtype == NDARRAY_COMPLEX) {
|
||||
if(dtype != NDARRAY_COMPLEX) {
|
||||
mp_raise_TypeError(translate("cannot convert complex type"));
|
||||
} else {
|
||||
memcpy(array, sarray, complex_size);
|
||||
}
|
||||
} else {
|
||||
#endif
|
||||
if((source->dtype == NDARRAY_FLOAT) && (dtype != NDARRAY_FLOAT)) {
|
||||
// floats must be treated separately, because they can't directly be converted to integer types
|
||||
mp_float_t f = ndarray_get_float_value(sarray, source->dtype);
|
||||
item = mp_obj_new_int((int32_t)MICROPY_FLOAT_C_FUN(floor)(f));
|
||||
} else {
|
||||
item = mp_binary_get_val_array(source->dtype, sarray, 0);
|
||||
}
|
||||
#if ULAB_SUPPORTS_COMPLEX
|
||||
if(dtype == NDARRAY_COMPLEX) {
|
||||
ndarray_set_value(NDARRAY_FLOAT, array, 0, item);
|
||||
} else {
|
||||
ndarray_set_value(dtype, array, 0, item);
|
||||
}
|
||||
}
|
||||
#else
|
||||
ndarray_set_value(dtype, array, 0, item);
|
||||
#endif
|
||||
array += ndarray->itemsize;
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 1];
|
||||
l++;
|
||||
} while(l < source->shape[ULAB_MAX_DIMS - 1]);
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
sarray -= source->strides[ULAB_MAX_DIMS - 1] * source->shape[ULAB_MAX_DIMS-1];
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 2];
|
||||
k++;
|
||||
} while(k < source->shape[ULAB_MAX_DIMS - 2]);
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 2
|
||||
sarray -= source->strides[ULAB_MAX_DIMS - 2] * source->shape[ULAB_MAX_DIMS-2];
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 3];
|
||||
j++;
|
||||
} while(j < source->shape[ULAB_MAX_DIMS - 3]);
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 3
|
||||
sarray -= source->strides[ULAB_MAX_DIMS - 3] * source->shape[ULAB_MAX_DIMS-3];
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 4];
|
||||
i++;
|
||||
} while(i < source->shape[ULAB_MAX_DIMS - 4]);
|
||||
#endif
|
||||
return MP_OBJ_FROM_PTR(ndarray);
|
||||
}
|
||||
|
||||
#if NDARRAY_HAS_BYTESWAP
|
||||
mp_obj_t ndarray_byteswap(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||
// changes the endiannes of an array
|
||||
|
|
@ -952,75 +1031,7 @@ STATIC mp_obj_t ndarray_make_new_core(const mp_obj_type_t *type, size_t n_args,
|
|||
|
||||
if(mp_obj_is_type(args[0], &ulab_ndarray_type)) {
|
||||
ndarray_obj_t *source = MP_OBJ_TO_PTR(args[0]);
|
||||
if(dtype == source->dtype) {
|
||||
return ndarray_copy_view(source);
|
||||
}
|
||||
ndarray_obj_t *target = ndarray_new_dense_ndarray(source->ndim, source->shape, dtype);
|
||||
uint8_t *sarray = (uint8_t *)source->array;
|
||||
uint8_t *tarray = (uint8_t *)target->array;
|
||||
|
||||
#if ULAB_SUPPORTS_COMPLEX
|
||||
uint8_t complex_size = 2 * sizeof(mp_float_t);
|
||||
#endif
|
||||
|
||||
#if ULAB_MAX_DIMS > 3
|
||||
size_t i = 0;
|
||||
do {
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 2
|
||||
size_t j = 0;
|
||||
do {
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
size_t k = 0;
|
||||
do {
|
||||
#endif
|
||||
size_t l = 0;
|
||||
do {
|
||||
mp_obj_t item;
|
||||
#if ULAB_SUPPORTS_COMPLEX
|
||||
if(source->dtype == NDARRAY_COMPLEX) {
|
||||
if(dtype != NDARRAY_COMPLEX) {
|
||||
mp_raise_TypeError(translate("cannot convert complex type"));
|
||||
} else {
|
||||
memcpy(tarray, sarray, complex_size);
|
||||
}
|
||||
} else {
|
||||
#endif
|
||||
if((source->dtype == NDARRAY_FLOAT) && (dtype != NDARRAY_FLOAT)) {
|
||||
// floats must be treated separately, because they can't directly be converted to integer types
|
||||
mp_float_t f = ndarray_get_float_value(sarray, source->dtype);
|
||||
item = mp_obj_new_int((int32_t)MICROPY_FLOAT_C_FUN(floor)(f));
|
||||
} else {
|
||||
item = mp_binary_get_val_array(source->dtype, sarray, 0);
|
||||
}
|
||||
ndarray_set_value(dtype, tarray, 0, item);
|
||||
#if ULAB_SUPPORTS_COMPLEX
|
||||
}
|
||||
#endif
|
||||
tarray += target->itemsize;
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 1];
|
||||
l++;
|
||||
} while(l < source->shape[ULAB_MAX_DIMS - 1]);
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
sarray -= source->strides[ULAB_MAX_DIMS - 1] * source->shape[ULAB_MAX_DIMS-1];
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 2];
|
||||
k++;
|
||||
} while(k < source->shape[ULAB_MAX_DIMS - 2]);
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 2
|
||||
sarray -= source->strides[ULAB_MAX_DIMS - 2] * source->shape[ULAB_MAX_DIMS-2];
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 3];
|
||||
j++;
|
||||
} while(j < source->shape[ULAB_MAX_DIMS - 3]);
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 3
|
||||
sarray -= source->strides[ULAB_MAX_DIMS - 3] * source->shape[ULAB_MAX_DIMS-3];
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 4];
|
||||
i++;
|
||||
} while(i < source->shape[ULAB_MAX_DIMS - 4]);
|
||||
#endif
|
||||
return MP_OBJ_FROM_PTR(target);
|
||||
return MP_OBJ_FROM_PTR(ndarray_copy_view_convert_type(source, dtype));
|
||||
} else {
|
||||
// assume that the input is an iterable
|
||||
return MP_OBJ_FROM_PTR(ndarray_from_iterable(args[0], dtype));
|
||||
|
|
|
|||
|
|
@ -143,6 +143,7 @@ ndarray_obj_t *ndarray_new_linear_array(size_t , uint8_t );
|
|||
ndarray_obj_t *ndarray_new_view(ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t );
|
||||
bool ndarray_is_dense(ndarray_obj_t *);
|
||||
ndarray_obj_t *ndarray_copy_view(ndarray_obj_t *);
|
||||
ndarray_obj_t *ndarray_copy_view_convert_type(ndarray_obj_t *, uint8_t );
|
||||
void ndarray_copy_array(ndarray_obj_t *, ndarray_obj_t *, uint8_t );
|
||||
|
||||
MP_DECLARE_CONST_FUN_OBJ_KW(ndarray_array_constructor_obj);
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2021 Zoltán Vörös
|
||||
* Copyright (c) 2021-2022 Zoltán Vörös
|
||||
*/
|
||||
|
||||
#include <math.h>
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
MP_DECLARE_CONST_FUN_OBJ_1(carray_real_obj);
|
||||
MP_DECLARE_CONST_FUN_OBJ_1(carray_imag_obj);
|
||||
MP_DECLARE_CONST_FUN_OBJ_1(carray_conjugate_obj);
|
||||
MP_DECLARE_CONST_FUN_OBJ_1(carray_sort_complex_obj);
|
||||
|
||||
mp_obj_t carray_abs(ndarray_obj_t *, ndarray_obj_t *);
|
||||
mp_obj_t carray_binary_add(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
|
|
|
|||
Loading…
Reference in a new issue