support circuitpython-style split type objects

This commit is contained in:
Jeff Epler 2021-07-06 10:10:35 -05:00
parent 161a728848
commit deda9d74d8
4 changed files with 24 additions and 11 deletions

View file

@ -49,6 +49,13 @@ typedef struct _mp_obj_slice_t {
#define MP_ERROR_TEXT(x) x
#endif
#if !defined(MP_TYPE_FLAG_FULL)
#define MP_TYPE_CALL call
#define mp_type_call(t) t->call
#define MP_TYPE_FLAG_FULL (0)
#define EXTENDED_FIELDS(...) __VA_ARGS__
#endif
#if !CIRCUITPY
#define translate(x) MP_ERROR_TEXT(x)
#define ndarray_set_value(a, b, c, d) mp_binary_set_val_array(a, b, c, d)

View file

@ -549,7 +549,7 @@ static mp_obj_t vectorise_vectorized_function_call(mp_obj_t self_in, size_t n_ar
ndarray_obj_t *ndarray = ndarray_new_dense_ndarray(source->ndim, source->shape, self->otypes);
for(size_t i=0; i < source->len; i++) {
avalue[0] = mp_binary_get_val_array(source->dtype, source->array, i);
fvalue = self->type->call(self->fun, 1, 0, avalue);
fvalue = self->type->MP_TYPE_CALL(self->fun, 1, 0, avalue);
ndarray_set_value(self->otypes, ndarray->array, i, fvalue);
}
return MP_OBJ_FROM_PTR(ndarray);
@ -561,14 +561,14 @@ static mp_obj_t vectorise_vectorized_function_call(mp_obj_t self_in, size_t n_ar
mp_obj_t iterable = mp_getiter(args[0], &iter_buf);
size_t i=0;
while ((avalue[0] = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) {
fvalue = self->type->call(self->fun, 1, 0, avalue);
fvalue = self->type->MP_TYPE_CALL(self->fun, 1, 0, avalue);
ndarray_set_value(self->otypes, ndarray->array, i, fvalue);
i++;
}
return MP_OBJ_FROM_PTR(ndarray);
} else if(mp_obj_is_int(args[0]) || mp_obj_is_float(args[0])) {
ndarray_obj_t *ndarray = ndarray_new_linear_array(1, self->otypes);
fvalue = self->type->call(self->fun, 1, 0, args);
fvalue = self->type->MP_TYPE_CALL(self->fun, 1, 0, args);
ndarray_set_value(self->otypes, ndarray->array, 0, fvalue);
return MP_OBJ_FROM_PTR(ndarray);
} else {
@ -579,8 +579,11 @@ static mp_obj_t vectorise_vectorized_function_call(mp_obj_t self_in, size_t n_ar
const mp_obj_type_t vectorise_function_type = {
{ &mp_type_type },
.flags = MP_TYPE_FLAG_FULL,
.name = MP_QSTR_,
EXTENDED_FIELDS(
.call = vectorise_vectorized_function_call,
)
};
//| def vectorize(
@ -605,7 +608,7 @@ static mp_obj_t vectorise_vectorize(size_t n_args, const mp_obj_t *pos_args, mp_
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
const mp_obj_type_t *type = mp_obj_get_type(args[0].u_obj);
if(type->call == NULL) {
if(mp_type_call(type) == NULL) {
mp_raise_TypeError(translate("first argument must be a callable"));
}
mp_obj_t _otypes = args[1].u_obj;

View file

@ -30,7 +30,7 @@ static mp_float_t optimize_python_call(const mp_obj_type_t *type, mp_obj_t fun,
// where f is defined in python. Takes a float, returns a float.
// The array of mp_obj_t type must be supplied, as must the number of parameters (a, b, c...) in nparams
fargs[0] = mp_obj_new_float(x);
return mp_obj_get_float(type->call(fun, nparams+1, 0, fargs));
return mp_obj_get_float(type->MP_TYPE_CALL(fun, nparams+1, 0, fargs));
}
#if ULAB_SCIPY_OPTIMIZE_HAS_BISECT
@ -70,7 +70,7 @@ STATIC mp_obj_t optimize_bisect(size_t n_args, const mp_obj_t *pos_args, mp_map_
mp_obj_t fun = args[0].u_obj;
const mp_obj_type_t *type = mp_obj_get_type(fun);
if(type->call == NULL) {
if(mp_type_call(type) == NULL) {
mp_raise_TypeError(translate("first argument must be a function"));
}
mp_float_t xtol = mp_obj_get_float(args[3].u_obj);
@ -140,7 +140,7 @@ STATIC mp_obj_t optimize_fmin(size_t n_args, const mp_obj_t *pos_args, mp_map_t
mp_obj_t fun = args[0].u_obj;
const mp_obj_type_t *type = mp_obj_get_type(fun);
if(type->call == NULL) {
if(mp_type_call(type) == NULL) {
mp_raise_TypeError(translate("first argument must be a function"));
}
@ -276,7 +276,7 @@ mp_obj_t optimize_curve_fit(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
mp_obj_t fun = args[0].u_obj;
const mp_obj_type_t *type = mp_obj_get_type(fun);
if(type->call == NULL) {
if(mp_type_call(type) == NULL) {
mp_raise_TypeError(translate("first argument must be a function"));
}
@ -365,7 +365,7 @@ static mp_obj_t optimize_newton(size_t n_args, const mp_obj_t *pos_args, mp_map_
mp_obj_t fun = args[0].u_obj;
const mp_obj_type_t *type = mp_obj_get_type(fun);
if(type->call == NULL) {
if(mp_type_call(type) == NULL) {
mp_raise_TypeError(translate("first argument must be a function"));
}
mp_float_t x = mp_obj_get_float(args[1].u_obj);

View file

@ -89,12 +89,15 @@ STATIC MP_DEFINE_CONST_DICT(ulab_ndarray_locals_dict, ulab_ndarray_locals_dict_t
const mp_obj_type_t ulab_ndarray_type = {
{ &mp_type_type },
.flags = MP_TYPE_FLAG_FULL
#if defined(MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE) && defined(MP_TYPE_FLAG_EQ_HAS_NEQ_TEST)
.flags = MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE | MP_TYPE_FLAG_EQ_HAS_NEQ_TEST,
| MP_TYPE_FLAG_EQ_CHECKS_OTHER_TYPE | MP_TYPE_FLAG_EQ_HAS_NEQ_TEST,
#endif
.name = MP_QSTR_ndarray,
.print = ndarray_print,
.make_new = ndarray_make_new,
.locals_dict = (mp_obj_dict_t*)&ulab_ndarray_locals_dict,
EXTENDED_FIELDS(
#if NDARRAY_IS_SLICEABLE
.subscr = ndarray_subscr,
#endif
@ -111,7 +114,7 @@ const mp_obj_type_t ulab_ndarray_type = {
.attr = ndarray_properties_attr,
#endif
.buffer_p = { .get_buffer = ndarray_get_buffer, },
.locals_dict = (mp_obj_dict_t*)&ulab_ndarray_locals_dict,
)
};
#if ULAB_HAS_DTYPE_OBJECT