fix upcasting rules for ndarray + scalar

This commit is contained in:
Zoltán Vörös 2021-06-01 17:32:18 +02:00
parent 9cc3638604
commit 4edb6aa318
7 changed files with 61 additions and 43 deletions

View file

@ -1469,7 +1469,7 @@ mp_obj_t ndarray_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) {
if (value == MP_OBJ_SENTINEL) { // return value(s)
return ndarray_get_slice(self, index, NULL);
} else { // assignment to slices; the value must be an ndarray, or a scalar
ndarray_obj_t *values = ndarray_from_mp_obj(value);
ndarray_obj_t *values = ndarray_from_mp_obj(value, 0);
return ndarray_get_slice(self, index, values);
}
return mp_const_none;
@ -1686,38 +1686,50 @@ MP_DEFINE_CONST_FUN_OBJ_1(ndarray_tobytes_obj, ndarray_tobytes);
#endif
// Binary operations
ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t obj) {
ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t obj, uint8_t other_type) {
// creates an ndarray from a micropython int or float
// if the input is an ndarray, it is returned
// if other_type is 0, return the smallest type that can accommodate the object
ndarray_obj_t *ndarray;
if(mp_obj_is_int(obj)) {
int32_t ivalue = mp_obj_get_int(obj);
if((ivalue >= 0) && (ivalue < 256)) {
ndarray = ndarray_new_linear_array(1, NDARRAY_UINT8);
uint8_t *array = (uint8_t *)ndarray->array;
array[0] = (uint8_t)ivalue;
} else if((ivalue > 255) && (ivalue < 65535)) {
ndarray = ndarray_new_linear_array(1, NDARRAY_UINT16);
uint16_t *array = (uint16_t *)ndarray->array;
array[0] = (uint16_t)ivalue;
} else if((ivalue < 0) && (ivalue > -128)) {
ndarray = ndarray_new_linear_array(1, NDARRAY_INT8);
int8_t *array = (int8_t *)ndarray->array;
array[0] = (int8_t)ivalue;
} else if((ivalue < -127) && (ivalue > -32767)) {
ndarray = ndarray_new_linear_array(1, NDARRAY_INT16);
int16_t *array = (int16_t *)ndarray->array;
array[0] = (int16_t)ivalue;
} else { // the integer value clearly does not fit the ulab integer types, so move on to float
if((ivalue < -32767) || (ivalue > 32767)) {
// the integer value clearly does not fit the ulab integer types, so move on to float
ndarray = ndarray_new_linear_array(1, NDARRAY_FLOAT);
mp_float_t *array = (mp_float_t *)ndarray->array;
array[0] = (mp_float_t)ivalue;
} else {
uint8_t dtype;
if(ivalue < 0) {
if(ivalue > -128) {
dtype = NDARRAY_INT8;
} else {
dtype = NDARRAY_INT16;
}
} else { // ivalue >= 0
if((other_type == NDARRAY_INT8) || (other_type == NDARRAY_INT16)) {
if(ivalue < 128) {
dtype = NDARRAY_INT8;
} else {
dtype = NDARRAY_INT16;
}
} else { // other_type = 0 is also included here
if(ivalue < 256) {
dtype = NDARRAY_UINT8;
} else {
dtype = NDARRAY_UINT16;
}
}
}
ndarray = ndarray_new_linear_array(1, dtype);
uint8_t width = mp_binary_get_size('@', dtype, NULL);
memcpy(ndarray->array, &ivalue, width);
}
} else if(mp_obj_is_float(obj)) {
mp_float_t fvalue = mp_obj_get_float(obj);
ndarray = ndarray_new_linear_array(1, NDARRAY_FLOAT);
mp_float_t *array = (mp_float_t *)ndarray->array;
array[0] = (mp_float_t)fvalue;
array[0] = mp_obj_get_float(obj);
} else if(mp_obj_is_type(obj, &ulab_ndarray_type)){
return obj;
} else {
@ -1735,11 +1747,11 @@ mp_obj_t ndarray_binary_op(mp_binary_op_t _op, mp_obj_t lobj, mp_obj_t robj) {
if((op == MP_BINARY_OP_REVERSE_ADD) || (op == MP_BINARY_OP_REVERSE_MULTIPLY) ||
(op == MP_BINARY_OP_REVERSE_POWER) || (op == MP_BINARY_OP_REVERSE_SUBTRACT) ||
(op == MP_BINARY_OP_REVERSE_TRUE_DIVIDE)) {
lhs = ndarray_from_mp_obj(robj);
rhs = ndarray_from_mp_obj(lobj);
lhs = ndarray_from_mp_obj(robj, 0);
rhs = ndarray_from_mp_obj(lobj, lhs->dtype);
} else {
lhs = ndarray_from_mp_obj(lobj);
rhs = ndarray_from_mp_obj(robj);
lhs = ndarray_from_mp_obj(lobj, 0);
rhs = ndarray_from_mp_obj(robj, lhs->dtype);
}
if(op == MP_BINARY_OP_REVERSE_ADD) {
op = MP_BINARY_OP_ADD;

View file

@ -192,7 +192,7 @@ MP_DECLARE_CONST_FUN_OBJ_1(ndarray_info_obj);
mp_int_t ndarray_get_buffer(mp_obj_t , mp_buffer_info_t *, mp_uint_t );
//void ndarray_attributes(mp_obj_t , qstr , mp_obj_t *);
ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t );
ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
#define BOOLEAN_ASSIGNMENT_LOOP(type_left, type_right, ndarray, iarray, istride, varray, vstride)\

View file

@ -58,9 +58,9 @@ STATIC mp_obj_t approx_interp(size_t n_args, const mp_obj_t *pos_args, mp_map_t
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
// TODO: numpy allows generic iterables
ndarray_obj_t *x = ndarray_from_mp_obj(args[0].u_obj);
ndarray_obj_t *xp = ndarray_from_mp_obj(args[1].u_obj); // xp must hold an increasing sequence of independent values
ndarray_obj_t *fp = ndarray_from_mp_obj(args[2].u_obj);
ndarray_obj_t *x = ndarray_from_mp_obj(args[0].u_obj, 0);
ndarray_obj_t *xp = ndarray_from_mp_obj(args[1].u_obj, 0); // xp must hold an increasing sequence of independent values
ndarray_obj_t *fp = ndarray_from_mp_obj(args[2].u_obj, 0);
if((xp->ndim != 1) || (fp->ndim != 1) || (xp->len < 2) || (fp->len < 2) || (xp->len != fp->len)) {
mp_raise_ValueError(translate("interp is defined for 1D arrays of equal length"));
}
@ -157,7 +157,7 @@ STATIC mp_obj_t approx_trapz(size_t n_args, const mp_obj_t *pos_args, mp_map_t *
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
ndarray_obj_t *y = ndarray_from_mp_obj(args[0].u_obj);
ndarray_obj_t *y = ndarray_from_mp_obj(args[0].u_obj, 0);
ndarray_obj_t *x;
mp_float_t mean = MICROPY_FLOAT_CONST(0.0);
if(y->len < 2) {
@ -174,7 +174,7 @@ STATIC mp_obj_t approx_trapz(size_t n_args, const mp_obj_t *pos_args, mp_map_t *
mp_float_t y1, y2, m;
if(args[1].u_obj != mp_const_none) {
x = ndarray_from_mp_obj(args[1].u_obj); // x must hold an increasing sequence of independent values
x = ndarray_from_mp_obj(args[1].u_obj, 0); // x must hold an increasing sequence of independent values
if((x->ndim != 1) || (y->len != x->len)) {
mp_raise_ValueError(translate("trapz is defined for 1D arrays of equal length"));
}

View file

@ -23,8 +23,8 @@
#include "compare.h"
static mp_obj_t compare_function(mp_obj_t x1, mp_obj_t x2, uint8_t op) {
ndarray_obj_t *lhs = ndarray_from_mp_obj(x1);
ndarray_obj_t *rhs = ndarray_from_mp_obj(x2);
ndarray_obj_t *lhs = ndarray_from_mp_obj(x1, 0);
ndarray_obj_t *rhs = ndarray_from_mp_obj(x2, 0);
uint8_t ndim = 0;
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
int32_t *lstrides = m_new(int32_t, ULAB_MAX_DIMS);
@ -309,9 +309,9 @@ MP_DEFINE_CONST_FUN_OBJ_2(compare_minimum_obj, compare_minimum);
mp_obj_t compare_where(mp_obj_t _condition, mp_obj_t _x, mp_obj_t _y) {
// this implementation will work with ndarrays, and scalars only
ndarray_obj_t *c = ndarray_from_mp_obj(_condition);
ndarray_obj_t *x = ndarray_from_mp_obj(_x);
ndarray_obj_t *y = ndarray_from_mp_obj(_y);
ndarray_obj_t *c = ndarray_from_mp_obj(_condition, 0);
ndarray_obj_t *x = ndarray_from_mp_obj(_x, 0);
ndarray_obj_t *y = ndarray_from_mp_obj(_y, 0);
int32_t *cstrides = m_new(int32_t, ULAB_MAX_DIMS);
int32_t *xstrides = m_new(int32_t, ULAB_MAX_DIMS);

View file

@ -43,11 +43,11 @@ static mp_obj_t vectorise_generic_vector(mp_obj_t o_in, mp_float_t (*f)(mp_float
uint8_t *sarray = (uint8_t *)source->array;
ndarray_obj_t *ndarray = ndarray_new_dense_ndarray(source->ndim, source->shape, NDARRAY_FLOAT);
mp_float_t *array = (mp_float_t *)ndarray->array;
#if ULAB_VECTORISE_USES_FUN_POINTER
mp_float_t (*func)(void *) = ndarray_get_float_function(source->dtype);
#if ULAB_MAX_DIMS > 3
size_t i = 0;
do {
@ -98,7 +98,7 @@ static mp_obj_t vectorise_generic_vector(mp_obj_t o_in, mp_float_t (*f)(mp_float
ITERATE_VECTOR(mp_float_t, array, source, sarray);
}
#endif /* ULAB_VECTORISE_USES_FUN_POINTER */
return MP_OBJ_FROM_PTR(ndarray);
} else if(mp_obj_is_type(o_in, &mp_type_tuple) || mp_obj_is_type(o_in, &mp_type_list) ||
mp_obj_is_type(o_in, &mp_type_range)) { // i.e., the input is a generic iterable
@ -247,8 +247,8 @@ MP_DEFINE_CONST_FUN_OBJ_1(vectorise_atan_obj, vectorise_atan);
//|
mp_obj_t vectorise_arctan2(mp_obj_t y, mp_obj_t x) {
ndarray_obj_t *ndarray_x = ndarray_from_mp_obj(x);
ndarray_obj_t *ndarray_y = ndarray_from_mp_obj(y);
ndarray_obj_t *ndarray_x = ndarray_from_mp_obj(x, 0);
ndarray_obj_t *ndarray_y = ndarray_from_mp_obj(y, 0);
uint8_t ndim = 0;
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);

View file

@ -34,7 +34,7 @@
#include "user/user.h"
#include "utils/utils.h"
#define ULAB_VERSION 2.8.3
#define ULAB_VERSION 2.8.4
#define xstr(s) str(s)
#define str(s) #s
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)

View file

@ -1,3 +1,9 @@
Tue, 1 Jun 2021
version 2.8.4
fix upcasting rules for ndarray + scalar
Mon, 24 May 2021
version 2.8.3