implement binary add for complex arrays

This commit is contained in:
Zoltán Vörös 2021-12-15 06:52:42 +01:00
parent 3557e16cd1
commit 864ab31766
3 changed files with 155 additions and 0 deletions

View file

@ -17,6 +17,7 @@
#include "ndarray_operators.h"
#include "ulab.h"
#include "ulab_tools.h"
#include "numpy/carray/carray.h"
/*
This file contains the actual implementations of the various
@ -161,6 +162,42 @@ mp_obj_t ndarray_binary_equality(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
mp_obj_t ndarray_binary_add(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
#if ULAB_SUPPORTS_COMPLEX
if((lhs->dtype == NDARRAY_COMPLEX) || (rhs->dtype == NDARRAY_COMPLEX)) {
ndarray_obj_t *results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_COMPLEX);
mp_float_t *resarray = (mp_float_t *)results->array;
uint8_t *larray = (uint8_t *)lhs->array;
uint8_t *rarray = (uint8_t *)rhs->array;
uint8_t *lo = larray, *ro = rarray;
int32_t *left_strides = lstrides;
int32_t *right_strides = rstrides;
uint8_t rdtype = rhs->dtype;
// align the complex array to the left
if(rhs->dtype == NDARRAY_COMPLEX) {
lo = (uint8_t *)rhs->array;
ro = (uint8_t *)lhs->array;
rdtype = lhs->dtype;
left_strides = rstrides;
right_strides = lstrides;
}
larray = lo;
rarray = ro;
carray_binary_add(results, resarray, larray, rarray, left_strides, right_strides, rdtype == NDARRAY_COMPLEX ? NDARRAY_FLOAT : rdtype);
if((lhs->dtype == NDARRAY_COMPLEX) && (rhs->dtype == NDARRAY_COMPLEX)) {
larray = lo + sizeof(mp_float_t);
rarray = ro + sizeof(mp_float_t);
resarray = (mp_float_t *)results->array;
resarray++;
carray_binary_add(results, resarray, larray, rarray, left_strides, right_strides, NDARRAY_FLOAT);
}
return MP_OBJ_FROM_PTR(results);
}
#endif
ndarray_obj_t *results = NULL;
uint8_t *larray = (uint8_t *)lhs->array;
uint8_t *rarray = (uint8_t *)rhs->array;

View file

@ -20,6 +20,7 @@
#include "../../ulab.h"
#include "../../ndarray.h"
#include "carray.h"
#if ULAB_SUPPORTS_COMPLEX
@ -109,4 +110,20 @@ mp_obj_t carray_abs(ndarray_obj_t *source, ndarray_obj_t *target) {
return MP_OBJ_FROM_PTR(target);
}
void carray_binary_add(ndarray_obj_t *results, mp_float_t *resarray, uint8_t *larray, uint8_t *rarray,
int32_t *lstrides, int32_t *rstrides, uint8_t rdtype) {
if(rdtype == NDARRAY_UINT8) {
BINARY_LOOP_COMPLEX(results, resarray, uint8_t, larray, lstrides, rarray, rstrides, +);
} else if(rdtype == NDARRAY_INT8) {
BINARY_LOOP_COMPLEX(results, resarray, int8_t, larray, lstrides, rarray, rstrides, +);
} else if(rdtype == NDARRAY_UINT16) {
BINARY_LOOP_COMPLEX(results, resarray, uint16_t, larray, lstrides, rarray, rstrides, +);
} else if(rdtype == NDARRAY_INT16) {
BINARY_LOOP_COMPLEX(results, resarray, int16_t, larray, lstrides, rarray, rstrides, +);
} else if(rdtype == NDARRAY_FLOAT) {
BINARY_LOOP_COMPLEX(results, resarray, mp_float_t, larray, lstrides, rarray, rstrides, +);
}
}
#endif

View file

@ -16,5 +16,106 @@ MP_DECLARE_CONST_FUN_OBJ_1(carray_real_obj);
MP_DECLARE_CONST_FUN_OBJ_1(carray_imag_obj);
mp_obj_t carray_abs(ndarray_obj_t *, ndarray_obj_t *);
void carray_binary_add(ndarray_obj_t *, mp_float_t *, uint8_t *, uint8_t *, int32_t *, int32_t *, uint8_t);
#if ULAB_MAX_DIMS == 1
#define BINARY_LOOP_COMPLEX(results, resarray, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
size_t l = 0;\
do {\
*(resarray) = *((mp_float_t *)(larray)++ OPERATOR *((type_right *)(rarray)++;\
(resarray) += 2;\
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
#endif /* ULAB_MAX_DIMS == 1 */
#if ULAB_MAX_DIMS == 2
#define BINARY_LOOP_COMPLEX(results, resarray, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
size_t k = 0;\
do {\
size_t l = 0;\
do {\
*(resarray) = *((mp_float_t *)(larray)) OPERATOR *((type_right *)(rarray));\
(resarray) += 2;\
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
k++;\
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
#endif /* ULAB_MAX_DIMS == 2 */
#if ULAB_MAX_DIMS == 3
#define BINARY_LOOP_COMPLEX(results, resarray, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
size_t j = 0;\
do {\
size_t k = 0;\
do {\
size_t l = 0;\
do {\
*(resarray) = *((mp_float_t *)(larray)) OPERATOR *((type_right *)(rarray));\
(resarray) += 2;\
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
k++;\
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
j++;\
} while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
#endif /* ULAB_MAX_DIMS == 3 */
#if ULAB_MAX_DIMS == 4
#define BINARY_LOOP_COMPLEX(results, resarray, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
size_t i = 0;\
do {\
size_t j = 0;\
do {\
size_t k = 0;\
do {\
size_t l = 0;\
do {\
*(resarray) = *((mp_float_t *)(larray)) OPERATOR *((type_right *)(rarray));\
(resarray) += 2;\
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
k++;\
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
j++;\
} while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 4];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 4];\
i++;\
} while(i < (results)->shape[ULAB_MAX_DIMS - 4]);\
#endif /* ULAB_MAX_DIMS == 4 */
#endif