implement binary add for complex arrays
This commit is contained in:
parent
3557e16cd1
commit
864ab31766
3 changed files with 155 additions and 0 deletions
|
|
@ -17,6 +17,7 @@
|
|||
#include "ndarray_operators.h"
|
||||
#include "ulab.h"
|
||||
#include "ulab_tools.h"
|
||||
#include "numpy/carray/carray.h"
|
||||
|
||||
/*
|
||||
This file contains the actual implementations of the various
|
||||
|
|
@ -161,6 +162,42 @@ mp_obj_t ndarray_binary_equality(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
|||
mp_obj_t ndarray_binary_add(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
|
||||
|
||||
#if ULAB_SUPPORTS_COMPLEX
|
||||
if((lhs->dtype == NDARRAY_COMPLEX) || (rhs->dtype == NDARRAY_COMPLEX)) {
|
||||
ndarray_obj_t *results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_COMPLEX);
|
||||
mp_float_t *resarray = (mp_float_t *)results->array;
|
||||
|
||||
uint8_t *larray = (uint8_t *)lhs->array;
|
||||
uint8_t *rarray = (uint8_t *)rhs->array;
|
||||
uint8_t *lo = larray, *ro = rarray;
|
||||
int32_t *left_strides = lstrides;
|
||||
int32_t *right_strides = rstrides;
|
||||
uint8_t rdtype = rhs->dtype;
|
||||
|
||||
// align the complex array to the left
|
||||
if(rhs->dtype == NDARRAY_COMPLEX) {
|
||||
lo = (uint8_t *)rhs->array;
|
||||
ro = (uint8_t *)lhs->array;
|
||||
rdtype = lhs->dtype;
|
||||
left_strides = rstrides;
|
||||
right_strides = lstrides;
|
||||
}
|
||||
|
||||
larray = lo;
|
||||
rarray = ro;
|
||||
carray_binary_add(results, resarray, larray, rarray, left_strides, right_strides, rdtype == NDARRAY_COMPLEX ? NDARRAY_FLOAT : rdtype);
|
||||
|
||||
if((lhs->dtype == NDARRAY_COMPLEX) && (rhs->dtype == NDARRAY_COMPLEX)) {
|
||||
larray = lo + sizeof(mp_float_t);
|
||||
rarray = ro + sizeof(mp_float_t);
|
||||
resarray = (mp_float_t *)results->array;
|
||||
resarray++;
|
||||
carray_binary_add(results, resarray, larray, rarray, left_strides, right_strides, NDARRAY_FLOAT);
|
||||
}
|
||||
return MP_OBJ_FROM_PTR(results);
|
||||
}
|
||||
#endif
|
||||
|
||||
ndarray_obj_t *results = NULL;
|
||||
uint8_t *larray = (uint8_t *)lhs->array;
|
||||
uint8_t *rarray = (uint8_t *)rhs->array;
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "../../ulab.h"
|
||||
#include "../../ndarray.h"
|
||||
#include "carray.h"
|
||||
|
||||
#if ULAB_SUPPORTS_COMPLEX
|
||||
|
||||
|
|
@ -109,4 +110,20 @@ mp_obj_t carray_abs(ndarray_obj_t *source, ndarray_obj_t *target) {
|
|||
return MP_OBJ_FROM_PTR(target);
|
||||
}
|
||||
|
||||
void carray_binary_add(ndarray_obj_t *results, mp_float_t *resarray, uint8_t *larray, uint8_t *rarray,
|
||||
int32_t *lstrides, int32_t *rstrides, uint8_t rdtype) {
|
||||
|
||||
if(rdtype == NDARRAY_UINT8) {
|
||||
BINARY_LOOP_COMPLEX(results, resarray, uint8_t, larray, lstrides, rarray, rstrides, +);
|
||||
} else if(rdtype == NDARRAY_INT8) {
|
||||
BINARY_LOOP_COMPLEX(results, resarray, int8_t, larray, lstrides, rarray, rstrides, +);
|
||||
} else if(rdtype == NDARRAY_UINT16) {
|
||||
BINARY_LOOP_COMPLEX(results, resarray, uint16_t, larray, lstrides, rarray, rstrides, +);
|
||||
} else if(rdtype == NDARRAY_INT16) {
|
||||
BINARY_LOOP_COMPLEX(results, resarray, int16_t, larray, lstrides, rarray, rstrides, +);
|
||||
} else if(rdtype == NDARRAY_FLOAT) {
|
||||
BINARY_LOOP_COMPLEX(results, resarray, mp_float_t, larray, lstrides, rarray, rstrides, +);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -16,5 +16,106 @@ MP_DECLARE_CONST_FUN_OBJ_1(carray_real_obj);
|
|||
MP_DECLARE_CONST_FUN_OBJ_1(carray_imag_obj);
|
||||
|
||||
mp_obj_t carray_abs(ndarray_obj_t *, ndarray_obj_t *);
|
||||
void carray_binary_add(ndarray_obj_t *, mp_float_t *, uint8_t *, uint8_t *, int32_t *, int32_t *, uint8_t);
|
||||
|
||||
#if ULAB_MAX_DIMS == 1
|
||||
#define BINARY_LOOP_COMPLEX(results, resarray, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
|
||||
size_t l = 0;\
|
||||
do {\
|
||||
*(resarray) = *((mp_float_t *)(larray)++ OPERATOR *((type_right *)(rarray)++;\
|
||||
(resarray) += 2;\
|
||||
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
|
||||
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
|
||||
l++;\
|
||||
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
|
||||
|
||||
#endif /* ULAB_MAX_DIMS == 1 */
|
||||
|
||||
#if ULAB_MAX_DIMS == 2
|
||||
#define BINARY_LOOP_COMPLEX(results, resarray, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
|
||||
size_t k = 0;\
|
||||
do {\
|
||||
size_t l = 0;\
|
||||
do {\
|
||||
*(resarray) = *((mp_float_t *)(larray)) OPERATOR *((type_right *)(rarray));\
|
||||
(resarray) += 2;\
|
||||
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
|
||||
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
|
||||
l++;\
|
||||
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
|
||||
(larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
|
||||
(larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
|
||||
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
|
||||
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
|
||||
k++;\
|
||||
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
|
||||
|
||||
#endif /* ULAB_MAX_DIMS == 2 */
|
||||
|
||||
#if ULAB_MAX_DIMS == 3
|
||||
#define BINARY_LOOP_COMPLEX(results, resarray, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
|
||||
size_t j = 0;\
|
||||
do {\
|
||||
size_t k = 0;\
|
||||
do {\
|
||||
size_t l = 0;\
|
||||
do {\
|
||||
*(resarray) = *((mp_float_t *)(larray)) OPERATOR *((type_right *)(rarray));\
|
||||
(resarray) += 2;\
|
||||
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
|
||||
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
|
||||
l++;\
|
||||
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
|
||||
(larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
|
||||
(larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
|
||||
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
|
||||
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
|
||||
k++;\
|
||||
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
|
||||
(larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
|
||||
(larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
|
||||
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
|
||||
(rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
|
||||
j++;\
|
||||
} while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
|
||||
|
||||
#endif /* ULAB_MAX_DIMS == 3 */
|
||||
|
||||
#if ULAB_MAX_DIMS == 4
|
||||
#define BINARY_LOOP_COMPLEX(results, resarray, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
|
||||
size_t i = 0;\
|
||||
do {\
|
||||
size_t j = 0;\
|
||||
do {\
|
||||
size_t k = 0;\
|
||||
do {\
|
||||
size_t l = 0;\
|
||||
do {\
|
||||
*(resarray) = *((mp_float_t *)(larray)) OPERATOR *((type_right *)(rarray));\
|
||||
(resarray) += 2;\
|
||||
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
|
||||
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
|
||||
l++;\
|
||||
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
|
||||
(larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
|
||||
(larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
|
||||
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
|
||||
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
|
||||
k++;\
|
||||
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
|
||||
(larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
|
||||
(larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
|
||||
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
|
||||
(rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
|
||||
j++;\
|
||||
} while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
|
||||
(larray) -= (lstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
|
||||
(larray) += (lstrides)[ULAB_MAX_DIMS - 4];\
|
||||
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
|
||||
(rarray) += (rstrides)[ULAB_MAX_DIMS - 4];\
|
||||
i++;\
|
||||
} while(i < (results)->shape[ULAB_MAX_DIMS - 4]);\
|
||||
|
||||
#endif /* ULAB_MAX_DIMS == 4 */
|
||||
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue