fix upcasting rules for ndarray + scalar
This commit is contained in:
parent
9cc3638604
commit
4edb6aa318
7 changed files with 61 additions and 43 deletions
|
|
@ -1469,7 +1469,7 @@ mp_obj_t ndarray_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) {
|
|||
if (value == MP_OBJ_SENTINEL) { // return value(s)
|
||||
return ndarray_get_slice(self, index, NULL);
|
||||
} else { // assignment to slices; the value must be an ndarray, or a scalar
|
||||
ndarray_obj_t *values = ndarray_from_mp_obj(value);
|
||||
ndarray_obj_t *values = ndarray_from_mp_obj(value, 0);
|
||||
return ndarray_get_slice(self, index, values);
|
||||
}
|
||||
return mp_const_none;
|
||||
|
|
@ -1686,38 +1686,50 @@ MP_DEFINE_CONST_FUN_OBJ_1(ndarray_tobytes_obj, ndarray_tobytes);
|
|||
#endif
|
||||
|
||||
// Binary operations
|
||||
ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t obj) {
|
||||
ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t obj, uint8_t other_type) {
|
||||
// creates an ndarray from a micropython int or float
|
||||
// if the input is an ndarray, it is returned
|
||||
// if other_type is 0, return the smallest type that can accommodate the object
|
||||
ndarray_obj_t *ndarray;
|
||||
|
||||
if(mp_obj_is_int(obj)) {
|
||||
int32_t ivalue = mp_obj_get_int(obj);
|
||||
if((ivalue >= 0) && (ivalue < 256)) {
|
||||
ndarray = ndarray_new_linear_array(1, NDARRAY_UINT8);
|
||||
uint8_t *array = (uint8_t *)ndarray->array;
|
||||
array[0] = (uint8_t)ivalue;
|
||||
} else if((ivalue > 255) && (ivalue < 65535)) {
|
||||
ndarray = ndarray_new_linear_array(1, NDARRAY_UINT16);
|
||||
uint16_t *array = (uint16_t *)ndarray->array;
|
||||
array[0] = (uint16_t)ivalue;
|
||||
} else if((ivalue < 0) && (ivalue > -128)) {
|
||||
ndarray = ndarray_new_linear_array(1, NDARRAY_INT8);
|
||||
int8_t *array = (int8_t *)ndarray->array;
|
||||
array[0] = (int8_t)ivalue;
|
||||
} else if((ivalue < -127) && (ivalue > -32767)) {
|
||||
ndarray = ndarray_new_linear_array(1, NDARRAY_INT16);
|
||||
int16_t *array = (int16_t *)ndarray->array;
|
||||
array[0] = (int16_t)ivalue;
|
||||
} else { // the integer value clearly does not fit the ulab integer types, so move on to float
|
||||
if((ivalue < -32767) || (ivalue > 32767)) {
|
||||
// the integer value clearly does not fit the ulab integer types, so move on to float
|
||||
ndarray = ndarray_new_linear_array(1, NDARRAY_FLOAT);
|
||||
mp_float_t *array = (mp_float_t *)ndarray->array;
|
||||
array[0] = (mp_float_t)ivalue;
|
||||
} else {
|
||||
uint8_t dtype;
|
||||
if(ivalue < 0) {
|
||||
if(ivalue > -128) {
|
||||
dtype = NDARRAY_INT8;
|
||||
} else {
|
||||
dtype = NDARRAY_INT16;
|
||||
}
|
||||
} else { // ivalue >= 0
|
||||
if((other_type == NDARRAY_INT8) || (other_type == NDARRAY_INT16)) {
|
||||
if(ivalue < 128) {
|
||||
dtype = NDARRAY_INT8;
|
||||
} else {
|
||||
dtype = NDARRAY_INT16;
|
||||
}
|
||||
} else { // other_type = 0 is also included here
|
||||
if(ivalue < 256) {
|
||||
dtype = NDARRAY_UINT8;
|
||||
} else {
|
||||
dtype = NDARRAY_UINT16;
|
||||
}
|
||||
}
|
||||
}
|
||||
ndarray = ndarray_new_linear_array(1, dtype);
|
||||
uint8_t width = mp_binary_get_size('@', dtype, NULL);
|
||||
memcpy(ndarray->array, &ivalue, width);
|
||||
}
|
||||
} else if(mp_obj_is_float(obj)) {
|
||||
mp_float_t fvalue = mp_obj_get_float(obj);
|
||||
ndarray = ndarray_new_linear_array(1, NDARRAY_FLOAT);
|
||||
mp_float_t *array = (mp_float_t *)ndarray->array;
|
||||
array[0] = (mp_float_t)fvalue;
|
||||
array[0] = mp_obj_get_float(obj);
|
||||
} else if(mp_obj_is_type(obj, &ulab_ndarray_type)){
|
||||
return obj;
|
||||
} else {
|
||||
|
|
@ -1735,11 +1747,11 @@ mp_obj_t ndarray_binary_op(mp_binary_op_t _op, mp_obj_t lobj, mp_obj_t robj) {
|
|||
if((op == MP_BINARY_OP_REVERSE_ADD) || (op == MP_BINARY_OP_REVERSE_MULTIPLY) ||
|
||||
(op == MP_BINARY_OP_REVERSE_POWER) || (op == MP_BINARY_OP_REVERSE_SUBTRACT) ||
|
||||
(op == MP_BINARY_OP_REVERSE_TRUE_DIVIDE)) {
|
||||
lhs = ndarray_from_mp_obj(robj);
|
||||
rhs = ndarray_from_mp_obj(lobj);
|
||||
lhs = ndarray_from_mp_obj(robj, 0);
|
||||
rhs = ndarray_from_mp_obj(lobj, lhs->dtype);
|
||||
} else {
|
||||
lhs = ndarray_from_mp_obj(lobj);
|
||||
rhs = ndarray_from_mp_obj(robj);
|
||||
lhs = ndarray_from_mp_obj(lobj, 0);
|
||||
rhs = ndarray_from_mp_obj(robj, lhs->dtype);
|
||||
}
|
||||
if(op == MP_BINARY_OP_REVERSE_ADD) {
|
||||
op = MP_BINARY_OP_ADD;
|
||||
|
|
|
|||
|
|
@ -192,7 +192,7 @@ MP_DECLARE_CONST_FUN_OBJ_1(ndarray_info_obj);
|
|||
mp_int_t ndarray_get_buffer(mp_obj_t , mp_buffer_info_t *, mp_uint_t );
|
||||
//void ndarray_attributes(mp_obj_t , qstr , mp_obj_t *);
|
||||
|
||||
ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t );
|
||||
ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
|
||||
|
||||
|
||||
#define BOOLEAN_ASSIGNMENT_LOOP(type_left, type_right, ndarray, iarray, istride, varray, vstride)\
|
||||
|
|
|
|||
|
|
@ -58,9 +58,9 @@ STATIC mp_obj_t approx_interp(size_t n_args, const mp_obj_t *pos_args, mp_map_t
|
|||
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
|
||||
|
||||
// TODO: numpy allows generic iterables
|
||||
ndarray_obj_t *x = ndarray_from_mp_obj(args[0].u_obj);
|
||||
ndarray_obj_t *xp = ndarray_from_mp_obj(args[1].u_obj); // xp must hold an increasing sequence of independent values
|
||||
ndarray_obj_t *fp = ndarray_from_mp_obj(args[2].u_obj);
|
||||
ndarray_obj_t *x = ndarray_from_mp_obj(args[0].u_obj, 0);
|
||||
ndarray_obj_t *xp = ndarray_from_mp_obj(args[1].u_obj, 0); // xp must hold an increasing sequence of independent values
|
||||
ndarray_obj_t *fp = ndarray_from_mp_obj(args[2].u_obj, 0);
|
||||
if((xp->ndim != 1) || (fp->ndim != 1) || (xp->len < 2) || (fp->len < 2) || (xp->len != fp->len)) {
|
||||
mp_raise_ValueError(translate("interp is defined for 1D arrays of equal length"));
|
||||
}
|
||||
|
|
@ -157,7 +157,7 @@ STATIC mp_obj_t approx_trapz(size_t n_args, const mp_obj_t *pos_args, mp_map_t *
|
|||
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
||||
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
|
||||
|
||||
ndarray_obj_t *y = ndarray_from_mp_obj(args[0].u_obj);
|
||||
ndarray_obj_t *y = ndarray_from_mp_obj(args[0].u_obj, 0);
|
||||
ndarray_obj_t *x;
|
||||
mp_float_t mean = MICROPY_FLOAT_CONST(0.0);
|
||||
if(y->len < 2) {
|
||||
|
|
@ -174,7 +174,7 @@ STATIC mp_obj_t approx_trapz(size_t n_args, const mp_obj_t *pos_args, mp_map_t *
|
|||
mp_float_t y1, y2, m;
|
||||
|
||||
if(args[1].u_obj != mp_const_none) {
|
||||
x = ndarray_from_mp_obj(args[1].u_obj); // x must hold an increasing sequence of independent values
|
||||
x = ndarray_from_mp_obj(args[1].u_obj, 0); // x must hold an increasing sequence of independent values
|
||||
if((x->ndim != 1) || (y->len != x->len)) {
|
||||
mp_raise_ValueError(translate("trapz is defined for 1D arrays of equal length"));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@
|
|||
#include "compare.h"
|
||||
|
||||
static mp_obj_t compare_function(mp_obj_t x1, mp_obj_t x2, uint8_t op) {
|
||||
ndarray_obj_t *lhs = ndarray_from_mp_obj(x1);
|
||||
ndarray_obj_t *rhs = ndarray_from_mp_obj(x2);
|
||||
ndarray_obj_t *lhs = ndarray_from_mp_obj(x1, 0);
|
||||
ndarray_obj_t *rhs = ndarray_from_mp_obj(x2, 0);
|
||||
uint8_t ndim = 0;
|
||||
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
|
||||
int32_t *lstrides = m_new(int32_t, ULAB_MAX_DIMS);
|
||||
|
|
@ -309,9 +309,9 @@ MP_DEFINE_CONST_FUN_OBJ_2(compare_minimum_obj, compare_minimum);
|
|||
|
||||
mp_obj_t compare_where(mp_obj_t _condition, mp_obj_t _x, mp_obj_t _y) {
|
||||
// this implementation will work with ndarrays, and scalars only
|
||||
ndarray_obj_t *c = ndarray_from_mp_obj(_condition);
|
||||
ndarray_obj_t *x = ndarray_from_mp_obj(_x);
|
||||
ndarray_obj_t *y = ndarray_from_mp_obj(_y);
|
||||
ndarray_obj_t *c = ndarray_from_mp_obj(_condition, 0);
|
||||
ndarray_obj_t *x = ndarray_from_mp_obj(_x, 0);
|
||||
ndarray_obj_t *y = ndarray_from_mp_obj(_y, 0);
|
||||
|
||||
int32_t *cstrides = m_new(int32_t, ULAB_MAX_DIMS);
|
||||
int32_t *xstrides = m_new(int32_t, ULAB_MAX_DIMS);
|
||||
|
|
|
|||
|
|
@ -43,11 +43,11 @@ static mp_obj_t vectorise_generic_vector(mp_obj_t o_in, mp_float_t (*f)(mp_float
|
|||
uint8_t *sarray = (uint8_t *)source->array;
|
||||
ndarray_obj_t *ndarray = ndarray_new_dense_ndarray(source->ndim, source->shape, NDARRAY_FLOAT);
|
||||
mp_float_t *array = (mp_float_t *)ndarray->array;
|
||||
|
||||
|
||||
#if ULAB_VECTORISE_USES_FUN_POINTER
|
||||
|
||||
|
||||
mp_float_t (*func)(void *) = ndarray_get_float_function(source->dtype);
|
||||
|
||||
|
||||
#if ULAB_MAX_DIMS > 3
|
||||
size_t i = 0;
|
||||
do {
|
||||
|
|
@ -98,7 +98,7 @@ static mp_obj_t vectorise_generic_vector(mp_obj_t o_in, mp_float_t (*f)(mp_float
|
|||
ITERATE_VECTOR(mp_float_t, array, source, sarray);
|
||||
}
|
||||
#endif /* ULAB_VECTORISE_USES_FUN_POINTER */
|
||||
|
||||
|
||||
return MP_OBJ_FROM_PTR(ndarray);
|
||||
} else if(mp_obj_is_type(o_in, &mp_type_tuple) || mp_obj_is_type(o_in, &mp_type_list) ||
|
||||
mp_obj_is_type(o_in, &mp_type_range)) { // i.e., the input is a generic iterable
|
||||
|
|
@ -247,8 +247,8 @@ MP_DEFINE_CONST_FUN_OBJ_1(vectorise_atan_obj, vectorise_atan);
|
|||
//|
|
||||
|
||||
mp_obj_t vectorise_arctan2(mp_obj_t y, mp_obj_t x) {
|
||||
ndarray_obj_t *ndarray_x = ndarray_from_mp_obj(x);
|
||||
ndarray_obj_t *ndarray_y = ndarray_from_mp_obj(y);
|
||||
ndarray_obj_t *ndarray_x = ndarray_from_mp_obj(x, 0);
|
||||
ndarray_obj_t *ndarray_y = ndarray_from_mp_obj(y, 0);
|
||||
|
||||
uint8_t ndim = 0;
|
||||
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@
|
|||
#include "user/user.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
#define ULAB_VERSION 2.8.3
|
||||
#define ULAB_VERSION 2.8.4
|
||||
#define xstr(s) str(s)
|
||||
#define str(s) #s
|
||||
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
Tue, 1 Jun 2021
|
||||
|
||||
version 2.8.4
|
||||
|
||||
fix upcasting rules for ndarray + scalar
|
||||
|
||||
Mon, 24 May 2021
|
||||
|
||||
version 2.8.3
|
||||
|
|
|
|||
Loading…
Reference in a new issue