added in-place operators
This commit is contained in:
parent
ff8d260809
commit
2e08b2566d
5 changed files with 346 additions and 23 deletions
|
|
@ -876,6 +876,35 @@ bool ndarray_can_broadcast(ndarray_obj_t *lhs, ndarray_obj_t *rhs, uint8_t *ndim
|
|||
return true;
|
||||
}
|
||||
|
||||
#if NDARRAY_HAS_INPLACE_OPS
|
||||
bool ndarray_can_broadcast_inplace(ndarray_obj_t *lhs, ndarray_obj_t *rhs, uint8_t *ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
    // Decides whether rhs can be broadcast onto lhs for an in-place operation.
    // The result is written back into lhs, so only rhs may be stretched:
    // along every axis the length on the right hand side must either
    //
    // 1. be equal to the length on the left hand side, or
    // 2. be 1 (or 0, which is accepted the same way)
    //
    // Outputs (each array must have room for ULAB_MAX_DIMS entries):
    //   shape    - receives the shape of lhs
    //   *ndim    - incremented once per axis of non-zero length; it is not
    //              reset here, so the caller has to pass in 0
    //   lstrides - strides of lhs, with 0 for axes of length < 2
    //   rstrides - strides of rhs, with 0 for axes of length < 2
    //
    // Returns true if the operands are compatible. On false, the outputs are
    // only partially filled in and must not be used.

    // the stride buffers hold int32_t, so they must be sized with int32_t
    memset(lstrides, 0, sizeof(int32_t)*ULAB_MAX_DIMS);
    memset(rstrides, 0, sizeof(int32_t)*ULAB_MAX_DIMS);
    for(uint8_t i=ULAB_MAX_DIMS; i > 0; i--) {
        size_t llen = lhs->shape[i-1];
        size_t rlen = rhs->shape[i-1];
        if((llen != rlen) && (rlen > 1)) {
            return false;
        }
        shape[i-1] = llen;
        if(llen > 0) (*ndim)++;
        // every axis of the left hand side needs its stride, otherwise the
        // iteration macros never advance beyond the first row of lhs
        lstrides[i-1] = (llen < 2) ? 0 : lhs->strides[i-1];
        rstrides[i-1] = (rlen < 2) ? 0 : rhs->strides[i-1];
    }
    return true;
}
|
||||
#endif
|
||||
|
||||
void ndarray_assign_view(ndarray_obj_t *view, ndarray_obj_t *values) {
|
||||
uint8_t ndim = 0;
|
||||
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
|
||||
|
|
@ -1247,8 +1276,9 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t obj) {
|
|||
return ndarray;
|
||||
}
|
||||
|
||||
#if NDARRAY_HAS_BINARY_OPS
|
||||
#if NDARRAY_HAS_BINARY_OPS || NDARRAY_HAS_INPLACE_OPS
|
||||
mp_obj_t ndarray_binary_op(mp_binary_op_t _op, mp_obj_t lobj, mp_obj_t robj) {
|
||||
// TODO: implement in-place operators
|
||||
// if the ndarray stands on the right hand side of the expression, simply swap the operands
|
||||
ndarray_obj_t *lhs, *rhs;
|
||||
mp_binary_op_t op = _op;
|
||||
|
|
@ -1272,15 +1302,19 @@ mp_obj_t ndarray_binary_op(mp_binary_op_t _op, mp_obj_t lobj, mp_obj_t robj) {
|
|||
} else if(op == MP_BINARY_OP_REVERSE_TRUE_DIVIDE) {
|
||||
op = MP_BINARY_OP_TRUE_DIVIDE;
|
||||
}
|
||||
// One of the operands is a scalar
|
||||
// TODO: conform to numpy with the upcasting
|
||||
// TODO: implement in-place operators
|
||||
|
||||
uint8_t ndim = 0;
|
||||
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
|
||||
int32_t *lstrides = m_new(int32_t, ULAB_MAX_DIMS);
|
||||
int32_t *rstrides = m_new(int32_t, ULAB_MAX_DIMS);
|
||||
if(!ndarray_can_broadcast(lhs, rhs, &ndim, shape, lstrides, rstrides)) {
|
||||
uint8_t broadcastable;
|
||||
if((op == MP_BINARY_OP_INPLACE_ADD) || (op == MP_BINARY_OP_INPLACE_MULTIPLY) || (op == MP_BINARY_OP_INPLACE_POWER) ||
|
||||
(op == MP_BINARY_OP_INPLACE_SUBTRACT) || (op == MP_BINARY_OP_INPLACE_TRUE_DIVIDE)) {
|
||||
broadcastable = ndarray_can_broadcast_inplace(lhs, rhs, &ndim, shape, lstrides, rstrides);
|
||||
} else {
|
||||
broadcastable = ndarray_can_broadcast(lhs, rhs, &ndim, shape, lstrides, rstrides);
|
||||
}
|
||||
if(!broadcastable) {
|
||||
mp_raise_ValueError(translate("operands could not be broadcast together"));
|
||||
m_del(size_t, shape, ULAB_MAX_DIMS);
|
||||
m_del(int32_t, lstrides, ULAB_MAX_DIMS);
|
||||
|
|
@ -1288,6 +1322,34 @@ mp_obj_t ndarray_binary_op(mp_binary_op_t _op, mp_obj_t lobj, mp_obj_t robj) {
|
|||
}
|
||||
|
||||
switch(op) {
|
||||
// first the in-place operators
|
||||
#if NDARRAY_HAS_INPLACE_ADD
|
||||
case MP_BINARY_OP_INPLACE_ADD:
|
||||
return ndarray_inplace_ams(lhs, rhs, ndim, shape, lstrides, rstrides, op);
|
||||
break;
|
||||
#endif
|
||||
#if NDARRAY_HAS_INPLACE_MULTIPLY
|
||||
case MP_BINARY_OP_INPLACE_MULTIPLY:
|
||||
return ndarray_inplace_ams(lhs, rhs, ndim, shape, lstrides, rstrides, op);
|
||||
break;
|
||||
#endif
|
||||
#if NDARRAY_HAS_INPLACE_POWER
|
||||
case MP_BINARY_OP_INPLACE_POWER:
|
||||
return ndarray_inplace_power(lhs, rhs, ndim, shape, lstrides, rstrides);
|
||||
break;
|
||||
#endif
|
||||
#if NDARRAY_HAS_INPLACE_SUBTRACT
|
||||
case MP_BINARY_OP_INPLACE_SUBTRACT:
|
||||
return ndarray_inplace_ams(lhs, rhs, ndim, shape, lstrides, rstrides, op);
|
||||
break;
|
||||
#endif
|
||||
#if NDARRAY_HAS_INPLACE_TRUE_DIVIDE
|
||||
case MP_BINARY_OP_INPLACE_TRUE_DIVIDE:
|
||||
return ndarray_inplace_divide(lhs, rhs, ndim, shape, lstrides, rstrides);
|
||||
break;
|
||||
#endif
|
||||
// end of in-place operators
|
||||
|
||||
#if NDARRAY_HAS_BINARY_OP_LESS
|
||||
case MP_BINARY_OP_LESS:
|
||||
// here we simply swap the operands
|
||||
|
|
@ -1300,8 +1362,6 @@ mp_obj_t ndarray_binary_op(mp_binary_op_t _op, mp_obj_t lobj, mp_obj_t robj) {
|
|||
return ndarray_binary_more(rhs, lhs, ndim, shape, rstrides, lstrides, MP_BINARY_OP_MORE_EQUAL);
|
||||
break;
|
||||
#endif
|
||||
// by separating the associative operators, we can save a lot of flash space,
|
||||
// because the operands can simply be swapped for half of the cases
|
||||
#if NDARRAY_HAS_BINARY_OP_EQUAL
|
||||
case MP_BINARY_OP_EQUAL:
|
||||
return ndarray_binary_equality(lhs, rhs, ndim, shape, lstrides, rstrides, MP_BINARY_OP_EQUAL);
|
||||
|
|
@ -1353,7 +1413,7 @@ mp_obj_t ndarray_binary_op(mp_binary_op_t _op, mp_obj_t lobj, mp_obj_t robj) {
|
|||
}
|
||||
return MP_OBJ_NULL;
|
||||
}
|
||||
#endif /* NDARRAY_HAS_BINARY_OPS */
|
||||
#endif /* NDARRAY_HAS_BINARY_OPS || NDARRAY_HAS_INPLACE_OPS */
|
||||
|
||||
#if NDARRAY_HAS_UNARY_OPS
|
||||
mp_obj_t ndarray_unary_op(mp_unary_op_t op, mp_obj_t self_in) {
|
||||
|
|
|
|||
|
|
@ -123,6 +123,7 @@ mp_obj_t ndarray_make_new(const mp_obj_type_t *, size_t , size_t , const mp_obj_
|
|||
mp_obj_t ndarray_subscr(mp_obj_t , mp_obj_t , mp_obj_t );
|
||||
mp_obj_t ndarray_getiter(mp_obj_t , mp_obj_iter_buf_t *);
|
||||
bool ndarray_can_broadcast(ndarray_obj_t *, ndarray_obj_t *, uint8_t *, size_t *, int32_t *, int32_t *);
|
||||
bool ndarray_can_broadcast_inplace(ndarray_obj_t *, ndarray_obj_t *, uint8_t *, size_t *, int32_t *, int32_t *);
|
||||
mp_obj_t ndarray_binary_op(mp_binary_op_t , mp_obj_t , mp_obj_t );
|
||||
mp_obj_t ndarray_unary_op(mp_unary_op_t , mp_obj_t );
|
||||
|
||||
|
|
@ -266,6 +267,23 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t );
|
|||
|
||||
#endif
|
||||
|
||||
// Applies the compound assignment OPERATOR (e.g. += or /=) element by element:
//
//     *(type_left *)larray OPERATOR *(type_right *)rarray
//
// larray and rarray must be modifiable pointers; they are advanced by the
// values found in lstrides and rstrides, so with byte pointers the strides are
// byte offsets. The iteration counts are taken from (results)->shape.
//
// Only the two innermost axes (ULAB_MAX_DIMS - 1 and ULAB_MAX_DIMS - 2) are
// walked, hence ULAB_MAX_DIMS has to be at least 2. Both loops are do-while
// loops, i.e., the body is executed once, even if the shape entry is 0.
//
// The expansion is a single statement that needs a terminating semicolon at
// the call site, and it does not leak its loop counters into the caller.
#define INPLACE_LOOP(results, type_left, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
do {\
    size_t k = 0;\
    do {\
        size_t l = 0;\
        do {\
            *((type_left *)(larray)) OPERATOR *((type_right *)(rarray));\
            (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
            (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
            l++;\
        } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
        (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
        (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
        (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
        (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
        k++;\
    } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
} while(0)
|
||||
|
||||
#define EQUALITY_LOOP(results, array, type_left, type_right, larray, lstrides, rarray, rstrides, OPERATOR)\
|
||||
size_t k = 0;\
|
||||
do {\
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#include <math.h>
|
||||
|
||||
#include "py/runtime.h"
|
||||
#include "py/objtuple.h"
|
||||
#include "ndarray.h"
|
||||
#include "ndarray_operators.h"
|
||||
|
|
@ -23,7 +24,7 @@
|
|||
*/
|
||||
|
||||
#if NDARRAY_HAS_BINARY_OP_EQUAL | NDARRAY_HAS_BINARY_OP_NOT_EQUAL
|
||||
ndarray_obj_t *ndarray_binary_equality(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
mp_obj_t ndarray_binary_equality(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides, mp_binary_op_t op) {
|
||||
|
||||
ndarray_obj_t *results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
|
||||
|
|
@ -145,7 +146,7 @@ ndarray_obj_t *ndarray_binary_equality(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
|||
#endif /* NDARRAY_HAS_BINARY_OP_EQUAL | NDARRAY_HAS_BINARY_OP_NOT_EQUAL */
|
||||
|
||||
#if NDARRAY_HAS_BINARY_OP_ADD
|
||||
ndarray_obj_t *ndarray_binary_add(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
mp_obj_t ndarray_binary_add(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
|
||||
|
||||
ndarray_obj_t *results = NULL;
|
||||
|
|
@ -222,7 +223,7 @@ ndarray_obj_t *ndarray_binary_add(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
|||
#endif /* NDARRAY_HAS_BINARY_OP_ADD */
|
||||
|
||||
#if NDARRAY_HAS_BINARY_OP_MULTIPLY
|
||||
ndarray_obj_t *ndarray_binary_multiply(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
mp_obj_t ndarray_binary_multiply(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
|
||||
|
||||
ndarray_obj_t *results = NULL;
|
||||
|
|
@ -299,7 +300,7 @@ ndarray_obj_t *ndarray_binary_multiply(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
|||
#endif /* NDARRAY_HAS_BINARY_OP_MULTIPLY */
|
||||
|
||||
#if NDARRAY_HAS_BINARY_OP_MORE | NDARRAY_HAS_BINARY_OP_MORE_EQUAL | NDARRAY_HAS_BINARY_OP_LESS | NDARRAY_HAS_BINARY_OP_LESS_EQUAL
|
||||
ndarray_obj_t *ndarray_binary_more(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
mp_obj_t ndarray_binary_more(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides, mp_binary_op_t op) {
|
||||
|
||||
ndarray_obj_t *results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
|
||||
|
|
@ -444,7 +445,7 @@ ndarray_obj_t *ndarray_binary_more(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
|||
#endif /* NDARRAY_HAS_BINARY_OP_MORE | NDARRAY_HAS_BINARY_OP_MORE_EQUAL | NDARRAY_HAS_BINARY_OP_LESS | NDARRAY_HAS_BINARY_OP_LESS_EQUAL */
|
||||
|
||||
#if NDARRAY_HAS_BINARY_OP_SUBTRACT
|
||||
ndarray_obj_t *ndarray_binary_subtract(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
mp_obj_t ndarray_binary_subtract(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
|
||||
|
||||
ndarray_obj_t *results = NULL;
|
||||
|
|
@ -543,7 +544,7 @@ ndarray_obj_t *ndarray_binary_subtract(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
|||
#endif /* NDARRAY_HAS_BINARY_OP_SUBTRACT */
|
||||
|
||||
#if NDARRAY_HAS_BINARY_OP_TRUE_DIVIDE
|
||||
ndarray_obj_t *ndarray_binary_true_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
mp_obj_t ndarray_binary_true_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
|
||||
|
||||
ndarray_obj_t *results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
|
||||
|
|
@ -617,7 +618,7 @@ ndarray_obj_t *ndarray_binary_true_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs
|
|||
#endif /* NDARRAY_HAS_BINARY_OP_TRUE_DIVIDE */
|
||||
|
||||
#if NDARRAY_HAS_BINARY_OP_POWER
|
||||
ndarray_obj_t *ndarray_binary_power(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
mp_obj_t ndarray_binary_power(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
||||
uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
|
||||
|
||||
// TODO: numpy upcasts the results to int64, if the inputs are of integer type,
|
||||
|
|
@ -691,3 +692,83 @@ ndarray_obj_t *ndarray_binary_power(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
|
|||
return MP_OBJ_FROM_PTR(results);
|
||||
}
|
||||
#endif /* NDARRAY_HAS_BINARY_OP_POWER */
|
||||
|
||||
#if NDARRAY_HAS_INPLACE_ADD || NDARRAY_HAS_INPLACE_MULTIPLY || NDARRAY_HAS_INPLACE_SUBTRACT
|
||||
mp_obj_t ndarray_inplace_ams(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
    uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides, uint8_t optype) {
    // In-place addition, multiplication, and subtraction (hence "ams"):
    // lhs is updated element by element with the values of rhs.
    //
    // optype selects the operation (MP_BINARY_OP_INPLACE_ADD,
    // MP_BINARY_OP_INPLACE_MULTIPLY, or MP_BINARY_OP_INPLACE_SUBTRACT); any
    // other value, or an operation that is compiled out, leaves lhs untouched.
    // lstrides and rstrides are the strides used for walking the two arrays.
    // ndim and shape are not used in this function.
    //
    // Raises TypeError, if rhs is a float array, but lhs is not.
    // Returns lhs itself, converted to an mp_obj_t.

    if((lhs->dtype != NDARRAY_FLOAT) && (rhs->dtype == NDARRAY_FLOAT)) {
        mp_raise_TypeError(translate("cannot cast output with casting rule"));
    }
    uint8_t *larray = (uint8_t *)lhs->array;
    uint8_t *rarray = (uint8_t *)rhs->array;

    // UNWRAP_INPLACE_OPERATOR also reads the local variable named rhs
    #if NDARRAY_HAS_INPLACE_ADD
    if(optype == MP_BINARY_OP_INPLACE_ADD) {
        UNWRAP_INPLACE_OPERATOR(lhs, larray, lstrides, rarray, rstrides, +=);
    }
    #endif
    // the multiplication has to depend on its own switch, and not on that of the addition
    #if NDARRAY_HAS_INPLACE_MULTIPLY
    if(optype == MP_BINARY_OP_INPLACE_MULTIPLY) {
        UNWRAP_INPLACE_OPERATOR(lhs, larray, lstrides, rarray, rstrides, *=);
    }
    #endif
    #if NDARRAY_HAS_INPLACE_SUBTRACT
    if(optype == MP_BINARY_OP_INPLACE_SUBTRACT) {
        UNWRAP_INPLACE_OPERATOR(lhs, larray, lstrides, rarray, rstrides, -=);
    }
    #endif

    return MP_OBJ_FROM_PTR(lhs);
}
|
||||
#endif /* NDARRAY_HAS_INPLACE_ADD || NDARRAY_HAS_INPLACE_MULTIPLY || NDARRAY_HAS_INPLACE_SUBTRACT */
|
||||
|
||||
#if NDARRAY_HAS_INPLACE_TRUE_DIVIDE
|
||||
mp_obj_t ndarray_inplace_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
    uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
    // In-place true division: every element of lhs is divided by the
    // corresponding element of rhs, and the quotient is stored in lhs.
    //
    // lstrides and rstrides are the strides used for walking the two arrays.
    // ndim and shape are not used in this function.
    //
    // Raises TypeError, if lhs is not a float array.
    // Returns lhs itself, converted to an mp_obj_t.

    if((lhs->dtype != NDARRAY_FLOAT)) {
        mp_raise_TypeError(translate("results cannot be cast to specified type"));
    }
    uint8_t *larray = (uint8_t *)lhs->array;
    uint8_t *rarray = (uint8_t *)rhs->array;

    // lhs is known to be float at this point, so the element type of the
    // right hand side is the only thing that has to be resolved; each branch
    // must therefore inspect rhs->dtype
    if(rhs->dtype == NDARRAY_UINT8) {
        INPLACE_LOOP(lhs, mp_float_t, uint8_t, larray, lstrides, rarray, rstrides, /=);
    } else if(rhs->dtype == NDARRAY_INT8) {
        INPLACE_LOOP(lhs, mp_float_t, int8_t, larray, lstrides, rarray, rstrides, /=);
    } else if(rhs->dtype == NDARRAY_UINT16) {
        INPLACE_LOOP(lhs, mp_float_t, uint16_t, larray, lstrides, rarray, rstrides, /=);
    } else if(rhs->dtype == NDARRAY_INT16) {
        INPLACE_LOOP(lhs, mp_float_t, int16_t, larray, lstrides, rarray, rstrides, /=);
    } else if(rhs->dtype == NDARRAY_FLOAT) {
        INPLACE_LOOP(lhs, mp_float_t, mp_float_t, larray, lstrides, rarray, rstrides, /=);
    }
    return MP_OBJ_FROM_PTR(lhs);
}
|
||||
#endif /* NDARRAY_HAS_INPLACE_TRUE_DIVIDE */
|
||||
|
||||
#if NDARRAY_HAS_INPLACE_POWER
|
||||
mp_obj_t ndarray_inplace_power(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
    uint8_t ndim, size_t *shape, int32_t *lstrides, int32_t *rstrides) {
    // In-place exponentiation: every element of lhs is raised to the power
    // given by the corresponding element of rhs, and the result is stored in lhs.
    //
    // lstrides and rstrides are the strides used for walking the two arrays.
    // ndim and shape are not used in this function.
    //
    // Raises TypeError, if lhs is not a float array.
    // Returns lhs itself, converted to an mp_obj_t.

    if((lhs->dtype != NDARRAY_FLOAT)) {
        mp_raise_TypeError(translate("results cannot be cast to specified type"));
    }
    uint8_t *larray = (uint8_t *)lhs->array;
    uint8_t *rarray = (uint8_t *)rhs->array;

    // lhs is known to be float at this point, so the element type of the
    // right hand side is the only thing that has to be resolved; each branch
    // must therefore inspect rhs->dtype
    if(rhs->dtype == NDARRAY_UINT8) {
        INPLACE_POWER(lhs, mp_float_t, uint8_t, larray, lstrides, rarray, rstrides);
    } else if(rhs->dtype == NDARRAY_INT8) {
        INPLACE_POWER(lhs, mp_float_t, int8_t, larray, lstrides, rarray, rstrides);
    } else if(rhs->dtype == NDARRAY_UINT16) {
        INPLACE_POWER(lhs, mp_float_t, uint16_t, larray, lstrides, rarray, rstrides);
    } else if(rhs->dtype == NDARRAY_INT16) {
        INPLACE_POWER(lhs, mp_float_t, int16_t, larray, lstrides, rarray, rstrides);
    } else if(rhs->dtype == NDARRAY_FLOAT) {
        INPLACE_POWER(lhs, mp_float_t, mp_float_t, larray, lstrides, rarray, rstrides);
    }
    return MP_OBJ_FROM_PTR(lhs);
}
|
||||
#endif /* NDARRAY_HAS_INPLACE_POWER */
|
||||
|
|
|
|||
|
|
@ -1,9 +1,166 @@
|
|||
#include "ndarray.h"
|
||||
|
||||
ndarray_obj_t *ndarray_binary_equality(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *, mp_binary_op_t );
|
||||
ndarray_obj_t *ndarray_binary_add(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
ndarray_obj_t *ndarray_binary_multiply(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
ndarray_obj_t *ndarray_binary_more(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *, mp_binary_op_t );
|
||||
ndarray_obj_t *ndarray_binary_subtract(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
ndarray_obj_t *ndarray_binary_true_divide(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
ndarray_obj_t *ndarray_binary_power(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
mp_obj_t ndarray_binary_equality(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *, mp_binary_op_t );
|
||||
mp_obj_t ndarray_binary_add(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
mp_obj_t ndarray_binary_multiply(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
mp_obj_t ndarray_binary_more(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *, mp_binary_op_t );
|
||||
mp_obj_t ndarray_binary_power(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
mp_obj_t ndarray_binary_subtract(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
mp_obj_t ndarray_binary_true_divide(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
|
||||
mp_obj_t ndarray_inplace_ams(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *, uint8_t );
|
||||
mp_obj_t ndarray_inplace_power(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
mp_obj_t ndarray_inplace_divide(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
|
||||
|
||||
// Helper of UNWRAP_INPLACE_OPERATOR for a left hand side with integer element
// type type_left: selects the element type of the right hand side, and runs
// INPLACE_LOOP. Any dtype of rhs that is not uint8, int8, or uint16 is read as
// int16_t. As in UNWRAP_INPLACE_OPERATOR, rhs is the variable of that name at
// the call site.
#define UNWRAP_INPLACE_INTEGER_LHS(lhs, type_left, larray, lstrides, rarray, rstrides, OPERATOR)\
    if(rhs->dtype == NDARRAY_UINT8) {\
        INPLACE_LOOP(lhs, type_left, uint8_t, larray, lstrides, rarray, rstrides, OPERATOR);\
    } else if(rhs->dtype == NDARRAY_INT8) {\
        INPLACE_LOOP(lhs, type_left, int8_t, larray, lstrides, rarray, rstrides, OPERATOR);\
    } else if(rhs->dtype == NDARRAY_UINT16) {\
        INPLACE_LOOP(lhs, type_left, uint16_t, larray, lstrides, rarray, rstrides, OPERATOR);\
    } else {\
        INPLACE_LOOP(lhs, type_left, int16_t, larray, lstrides, rarray, rstrides, OPERATOR);\
    }

// Resolves the element types of both operands at run time, and executes
// INPLACE_LOOP with the matching pair of C types and the compound assignment
// OPERATOR (e.g. +=).
//
// The right hand side is not a macro argument: the expansion refers to a
// variable called rhs, which must be in scope, wherever the macro is used.
//
// For an integer lhs, any rhs dtype other than uint8, int8, and uint16 is read
// as int16_t; for a float lhs, any rhs dtype other than the four integer types
// is read as mp_float_t. If the dtype of lhs is none of the five known types,
// nothing happens.
#define UNWRAP_INPLACE_OPERATOR(lhs, larray, lstrides, rarray, rstrides, OPERATOR)\
do {\
    if((lhs)->dtype == NDARRAY_UINT8) {\
        UNWRAP_INPLACE_INTEGER_LHS((lhs), uint8_t, (larray), (lstrides), (rarray), (rstrides), OPERATOR)\
    } else if((lhs)->dtype == NDARRAY_INT8) {\
        UNWRAP_INPLACE_INTEGER_LHS((lhs), int8_t, (larray), (lstrides), (rarray), (rstrides), OPERATOR)\
    } else if((lhs)->dtype == NDARRAY_UINT16) {\
        UNWRAP_INPLACE_INTEGER_LHS((lhs), uint16_t, (larray), (lstrides), (rarray), (rstrides), OPERATOR)\
    } else if((lhs)->dtype == NDARRAY_INT16) {\
        UNWRAP_INPLACE_INTEGER_LHS((lhs), int16_t, (larray), (lstrides), (rarray), (rstrides), OPERATOR)\
    } else if((lhs)->dtype == NDARRAY_FLOAT) {\
        if(rhs->dtype == NDARRAY_UINT8) {\
            INPLACE_LOOP((lhs), mp_float_t, uint8_t, (larray), (lstrides), (rarray), (rstrides), OPERATOR);\
        } else if(rhs->dtype == NDARRAY_INT8) {\
            INPLACE_LOOP((lhs), mp_float_t, int8_t, (larray), (lstrides), (rarray), (rstrides), OPERATOR);\
        } else if(rhs->dtype == NDARRAY_UINT16) {\
            INPLACE_LOOP((lhs), mp_float_t, uint16_t, (larray), (lstrides), (rarray), (rstrides), OPERATOR);\
        } else if(rhs->dtype == NDARRAY_INT16) {\
            INPLACE_LOOP((lhs), mp_float_t, int16_t, (larray), (lstrides), (rarray), (rstrides), OPERATOR);\
        } else {\
            INPLACE_LOOP((lhs), mp_float_t, mp_float_t, (larray), (lstrides), (rarray), (rstrides), OPERATOR);\
        }\
    }\
} while(0)
|
||||
|
||||
#if ULAB_MAX_DIMS == 1
// One-dimensional in-place power: each element read through larray is replaced
// by pow(left element, right element), where the right element is read through
// rarray. Both pointers must be modifiable; they are advanced by the last entry
// of lstrides and rstrides, respectively. The iteration count is taken from the
// last entry of (results)->shape; the body runs once, even if that entry is 0.
#define INPLACE_POWER(results, type_left, type_right, larray, lstrides, rarray, rstrides)\
do {\
    size_t l = 0;\
    do {\
        *((type_left *)(larray)) = MICROPY_FLOAT_C_FUN(pow)(*((type_left *)(larray)), *((type_right *)(rarray)));\
        (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
        (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
    } while(++l < (results)->shape[ULAB_MAX_DIMS - 1]);\
} while(0)
#endif /* ULAB_MAX_DIMS == 1 */
|
||||
|
||||
#if ULAB_MAX_DIMS == 2
// Two-dimensional in-place power: each element read through larray is replaced
// by pow(left element, right element), where the right element is read through
// rarray. Both pointers must be modifiable. After the innermost axis has been
// traversed, the pointers are rewound to the beginning of the row, and then
// moved on by the stride of the next axis. Iteration counts are taken from
// (results)->shape; each loop body runs once, even if the shape entry is 0.
#define INPLACE_POWER(results, type_left, type_right, larray, lstrides, rarray, rstrides)\
do {\
    size_t k = 0;\
    do {\
        size_t l = 0;\
        do {\
            *((type_left *)(larray)) = MICROPY_FLOAT_C_FUN(pow)(*((type_left *)(larray)), *((type_right *)(rarray)));\
            (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
            (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
        } while(++l < (results)->shape[ULAB_MAX_DIMS - 1]);\
        (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
        (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
        (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
        (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
    } while(++k < (results)->shape[ULAB_MAX_DIMS - 2]);\
} while(0)
#endif /* ULAB_MAX_DIMS == 2 */
|
||||
|
||||
#if ULAB_MAX_DIMS == 3
// Three-dimensional in-place power: each element read through larray is
// replaced by pow(left element, right element), where the right element is read
// through rarray. Both pointers must be modifiable. Whenever an axis has been
// traversed, the pointers are rewound along that axis, and then moved on by the
// stride of the next outer axis. Iteration counts are taken from
// (results)->shape; each loop body runs once, even if the shape entry is 0.
#define INPLACE_POWER(results, type_left, type_right, larray, lstrides, rarray, rstrides)\
do {\
    size_t j = 0;\
    do {\
        size_t k = 0;\
        do {\
            size_t l = 0;\
            do {\
                *((type_left *)(larray)) = MICROPY_FLOAT_C_FUN(pow)(*((type_left *)(larray)), *((type_right *)(rarray)));\
                (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
                (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
            } while(++l < (results)->shape[ULAB_MAX_DIMS - 1]);\
            (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
            (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
            (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
            (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
        } while(++k < (results)->shape[ULAB_MAX_DIMS - 2]);\
        (larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS - 2];\
        (larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
        (rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS - 2];\
        (rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
    } while(++j < (results)->shape[ULAB_MAX_DIMS - 3]);\
} while(0)
#endif /* ULAB_MAX_DIMS == 3 */
|
||||
|
||||
#if ULAB_MAX_DIMS == 4
// Four-dimensional in-place power: each element read through larray is
// replaced by pow(left element, right element), where the right element is read
// through rarray. Both pointers must be modifiable. Whenever an axis has been
// traversed, the pointers are rewound along that axis, and then moved on by the
// stride of the next outer axis. Iteration counts are taken from
// (results)->shape; each loop body runs once, even if the shape entry is 0.
#define INPLACE_POWER(results, type_left, type_right, larray, lstrides, rarray, rstrides)\
do {\
    size_t i = 0;\
    do {\
        size_t j = 0;\
        do {\
            size_t k = 0;\
            do {\
                size_t l = 0;\
                do {\
                    *((type_left *)(larray)) = MICROPY_FLOAT_C_FUN(pow)(*((type_left *)(larray)), *((type_right *)(rarray)));\
                    (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
                    (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
                } while(++l < (results)->shape[ULAB_MAX_DIMS - 1]);\
                (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
                (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
                (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
                (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
            } while(++k < (results)->shape[ULAB_MAX_DIMS - 2]);\
            (larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS - 2];\
            (larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
            (rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS - 2];\
            (rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
        } while(++j < (results)->shape[ULAB_MAX_DIMS - 3]);\
        (larray) -= (lstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS - 3];\
        (larray) += (lstrides)[ULAB_MAX_DIMS - 4];\
        (rarray) -= (rstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS - 3];\
        (rarray) += (rstrides)[ULAB_MAX_DIMS - 4];\
    } while(++i < (results)->shape[ULAB_MAX_DIMS - 4]);\
} while(0)
#endif /* ULAB_MAX_DIMS == 4 */
|
||||
|
|
|
|||
|
|
@ -69,7 +69,14 @@
|
|||
#define NDARRAY_HAS_BINARY_OP_NOT_EQUAL (1)
|
||||
#define NDARRAY_HAS_BINARY_OP_POWER (1)
|
||||
#define NDARRAY_HAS_BINARY_OP_SUBTRACT (1)
|
||||
#define NDARRAY_HAS_BINARY_OP_TRUE_DIVIDE (1)
|
||||
#define NDARRAY_HAS_BINARY_OP_TRUE_DIVIDE (1)
|
||||
|
||||
#define NDARRAY_HAS_INPLACE_OPS (1)
|
||||
#define NDARRAY_HAS_INPLACE_ADD (1)
|
||||
#define NDARRAY_HAS_INPLACE_MULTIPLY (1)
|
||||
#define NDARRAY_HAS_INPLACE_POWER (1)
|
||||
#define NDARRAY_HAS_INPLACE_SUBTRACT (1)
|
||||
#define NDARRAY_HAS_INPLACE_TRUE_DIVIDE (1)
|
||||
|
||||
// the ndarray unary operators
|
||||
#define NDARRAY_HAS_UNARY_OPS (1)
|
||||
|
|
|
|||
Loading…
Reference in a new issue