660 lines
26 KiB
C
660 lines
26 KiB
C
|
|
/*
|
|
* This file is part of the micropython-ulab project,
|
|
*
|
|
* https://github.com/v923z/micropython-ulab
|
|
*
|
|
* The MIT License (MIT)
|
|
*
|
|
* Copyright (c) 2020-2021 Zoltán Vörös
|
|
* 2020 Jeff Epler for Adafruit Industries
|
|
*/
|
|
|
|
#include <math.h>
|
|
#include <stdlib.h>
|
|
#include <string.h>
|
|
#include "py/obj.h"
|
|
#include "py/runtime.h"
|
|
#include "py/misc.h"
|
|
|
|
#include "../ulab.h"
|
|
#include "../ndarray_operators.h"
|
|
#include "../ulab_tools.h"
|
|
#include "carray/carray_tools.h"
|
|
#include "compare.h"
|
|
|
|
#ifdef ULAB_NUMPY_HAS_BINCOUNT
|
|
mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
|
static const mp_arg_t allowed_args[] = {
|
|
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
|
|
{ MP_QSTR_weights, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
|
|
{ MP_QSTR_minlength, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
|
|
};
|
|
|
|
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
|
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
|
|
|
|
return mp_const_none;
|
|
}
|
|
|
|
MP_DEFINE_CONST_FUN_OBJ_KW(compare_bincount_obj, 1, compare_bincount);
|
|
#endif /* ULAB_NUMPY_HAS_BINCOUNT */
|
|
|
|
static mp_obj_t compare_function(mp_obj_t x1, mp_obj_t x2, uint8_t op) {
|
|
ndarray_obj_t *lhs = ndarray_from_mp_obj(x1, 0);
|
|
ndarray_obj_t *rhs = ndarray_from_mp_obj(x2, 0);
|
|
#if ULAB_SUPPORTS_COMPLEX
|
|
if((lhs->dtype == NDARRAY_COMPLEX) || (rhs->dtype == NDARRAY_COMPLEX)) {
|
|
NOT_IMPLEMENTED_FOR_COMPLEX()
|
|
}
|
|
#endif
|
|
uint8_t ndim = 0;
|
|
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
|
|
int32_t *lstrides = m_new(int32_t, ULAB_MAX_DIMS);
|
|
int32_t *rstrides = m_new(int32_t, ULAB_MAX_DIMS);
|
|
if(!ndarray_can_broadcast(lhs, rhs, &ndim, shape, lstrides, rstrides)) {
|
|
mp_raise_ValueError(MP_ERROR_TEXT("operands could not be broadcast together"));
|
|
m_del(size_t, shape, ULAB_MAX_DIMS);
|
|
m_del(int32_t, lstrides, ULAB_MAX_DIMS);
|
|
m_del(int32_t, rstrides, ULAB_MAX_DIMS);
|
|
}
|
|
|
|
uint8_t *larray = (uint8_t *)lhs->array;
|
|
uint8_t *rarray = (uint8_t *)rhs->array;
|
|
|
|
if(op == COMPARE_EQUAL) {
|
|
return ndarray_binary_equality(lhs, rhs, ndim, shape, lstrides, rstrides, MP_BINARY_OP_EQUAL);
|
|
} else if(op == COMPARE_NOT_EQUAL) {
|
|
return ndarray_binary_equality(lhs, rhs, ndim, shape, lstrides, rstrides, MP_BINARY_OP_NOT_EQUAL);
|
|
}
|
|
// These are the upcasting rules
|
|
// float always becomes float
|
|
// operation on identical types preserves type
|
|
// uint8 + int8 => int16
|
|
// uint8 + int16 => int16
|
|
// uint8 + uint16 => uint16
|
|
// int8 + int16 => int16
|
|
// int8 + uint16 => uint16
|
|
// uint16 + int16 => float
|
|
// The parameters of RUN_COMPARE_LOOP are
|
|
// typecode of result, type_out, type_left, type_right, lhs operand, rhs operand, operator
|
|
if(lhs->dtype == NDARRAY_UINT8) {
|
|
if(rhs->dtype == NDARRAY_UINT8) {
|
|
RUN_COMPARE_LOOP(NDARRAY_UINT8, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_INT8) {
|
|
RUN_COMPARE_LOOP(NDARRAY_INT16, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_UINT16) {
|
|
RUN_COMPARE_LOOP(NDARRAY_UINT16, uint16_t, uint8_t, uint16_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_INT16) {
|
|
RUN_COMPARE_LOOP(NDARRAY_INT16, int16_t, uint8_t, int16_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_FLOAT) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, uint8_t, mp_float_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
}
|
|
} else if(lhs->dtype == NDARRAY_INT8) {
|
|
if(rhs->dtype == NDARRAY_UINT8) {
|
|
RUN_COMPARE_LOOP(NDARRAY_INT16, int16_t, int8_t, uint8_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_INT8) {
|
|
RUN_COMPARE_LOOP(NDARRAY_INT8, int8_t, int8_t, int8_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_UINT16) {
|
|
RUN_COMPARE_LOOP(NDARRAY_INT16, int16_t, int8_t, uint16_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_INT16) {
|
|
RUN_COMPARE_LOOP(NDARRAY_INT16, int16_t, int8_t, int16_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_FLOAT) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, int8_t, mp_float_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
}
|
|
} else if(lhs->dtype == NDARRAY_UINT16) {
|
|
if(rhs->dtype == NDARRAY_UINT8) {
|
|
RUN_COMPARE_LOOP(NDARRAY_UINT16, uint16_t, uint16_t, uint8_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_INT8) {
|
|
RUN_COMPARE_LOOP(NDARRAY_UINT16, uint16_t, uint16_t, int8_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_UINT16) {
|
|
RUN_COMPARE_LOOP(NDARRAY_UINT16, uint16_t, uint16_t, uint16_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_INT16) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, uint16_t, int16_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_FLOAT) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, uint16_t, mp_float_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
}
|
|
} else if(lhs->dtype == NDARRAY_INT16) {
|
|
if(rhs->dtype == NDARRAY_UINT8) {
|
|
RUN_COMPARE_LOOP(NDARRAY_INT16, int16_t, int16_t, uint8_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_INT8) {
|
|
RUN_COMPARE_LOOP(NDARRAY_INT16, int16_t, int16_t, int8_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_UINT16) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, int16_t, uint16_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_INT16) {
|
|
RUN_COMPARE_LOOP(NDARRAY_INT16, int16_t, int16_t, int16_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_FLOAT) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, int16_t, mp_float_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
}
|
|
} else if(lhs->dtype == NDARRAY_FLOAT) {
|
|
if(rhs->dtype == NDARRAY_UINT8) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, mp_float_t, uint8_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_INT8) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, mp_float_t, int8_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_UINT16) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, mp_float_t, uint16_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_INT16) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, mp_float_t, int16_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
} else if(rhs->dtype == NDARRAY_FLOAT) {
|
|
RUN_COMPARE_LOOP(NDARRAY_FLOAT, mp_float_t, mp_float_t, mp_float_t, larray, lstrides, rarray, rstrides, ndim, shape, op);
|
|
}
|
|
}
|
|
return mp_const_none; // we should never reach this point
|
|
}
|
|
|
|
#if ULAB_NUMPY_HAS_EQUAL | ULAB_NUMPY_HAS_NOTEQUAL
|
|
static mp_obj_t compare_equal_helper(mp_obj_t x1, mp_obj_t x2, uint8_t comptype) {
|
|
// scalar comparisons should return a single object of mp_obj_t type
|
|
mp_obj_t result = compare_function(x1, x2, comptype);
|
|
if((mp_obj_is_int(x1) || mp_obj_is_float(x1)) && (mp_obj_is_int(x2) || mp_obj_is_float(x2))) {
|
|
mp_obj_iter_buf_t iter_buf;
|
|
mp_obj_t iterable = mp_getiter(result, &iter_buf);
|
|
mp_obj_t item = mp_iternext(iterable);
|
|
return item;
|
|
}
|
|
return result;
|
|
}
|
|
#endif
|
|
|
|
#if ULAB_NUMPY_HAS_CLIP
|
|
//| def clip(
|
|
//| a: _ScalarOrArrayLike,
|
|
//| a_min: _ScalarOrArrayLike,
|
|
//| a_max: _ScalarOrArrayLike,
|
|
//| ) -> _ScalarOrNdArray:
|
|
//| """
|
|
//| Clips (limits) the values in an array.
|
|
//|
|
|
//| :param a: Scalar or array containing elements to clip.
|
|
//| :param a_min: Minimum value, it will be broadcast against ``a``.
|
|
//| :param a_max: Maximum value, it will be broadcast against ``a``.
|
|
//| :return:
|
|
//| A scalar or array with the elements of ``a``, but where
|
|
//| values < ``a_min`` are replaced with ``a_min``, and those
|
|
//| > ``a_max`` with ``a_max``.
|
|
//| """
|
|
//| ...
|
|
mp_obj_t compare_clip(mp_obj_t x1, mp_obj_t x2, mp_obj_t x3) {
|
|
// Note: this function could be made faster by implementing a single-loop comparison in
|
|
// RUN_COMPARE_LOOP. However, that would add around 2 kB of compile size, while we
|
|
// would not gain a factor of two in speed, since the two comparisons should still be
|
|
// evaluated. In contrast, calling the function twice adds only 140 bytes to the firmware
|
|
if(mp_obj_is_int(x1) || mp_obj_is_float(x1)) {
|
|
mp_float_t v1 = mp_obj_get_float(x1);
|
|
mp_float_t v2 = mp_obj_get_float(x2);
|
|
mp_float_t v3 = mp_obj_get_float(x3);
|
|
if(v1 < v2) {
|
|
return x2;
|
|
} else if(v1 > v3) {
|
|
return x3;
|
|
} else {
|
|
return x1;
|
|
}
|
|
} else { // assume ndarrays
|
|
return compare_function(x2, compare_function(x1, x3, COMPARE_MINIMUM), COMPARE_MAXIMUM);
|
|
}
|
|
}
|
|
|
|
MP_DEFINE_CONST_FUN_OBJ_3(compare_clip_obj, compare_clip);
|
|
#endif
|
|
|
|
#if ULAB_NUMPY_HAS_EQUAL
|
|
//| def equal(x: _ScalarOrArrayLike, y: _ScalarOrArrayLike) -> _ScalarOrNdArray:
|
|
//| """
|
|
//| Returns ``x == y`` element-wise.
|
|
//|
|
|
//| :param x, y:
|
|
//| Input scalar or array. If ``x.shape != y.shape`` they must
|
|
//| be broadcastable to a common shape (which becomes the
|
|
//| shape of the output.)
|
|
//| :return:
|
|
//| A boolean scalar or array with the element-wise result of ``x == y``.
|
|
//| """
|
|
//| ...
|
|
mp_obj_t compare_equal(mp_obj_t x1, mp_obj_t x2) {
|
|
return compare_equal_helper(x1, x2, COMPARE_EQUAL);
|
|
}
|
|
|
|
MP_DEFINE_CONST_FUN_OBJ_2(compare_equal_obj, compare_equal);
|
|
#endif
|
|
|
|
#if ULAB_NUMPY_HAS_NOTEQUAL
|
|
//| def not_equal(
|
|
//| x: _ScalarOrArrayLike,
|
|
//| y: _ScalarOrArrayLike,
|
|
//| ) -> Union[_bool, ulab.numpy.ndarray]:
|
|
//| """
|
|
//| Returns ``x != y`` element-wise.
|
|
//|
|
|
//| :param x, y:
|
|
//| Input scalar or array. If ``x.shape != y.shape`` they must
|
|
//| be broadcastable to a common shape (which becomes the
|
|
//| shape of the output.)
|
|
//| :return:
|
|
//| A boolean scalar or array with the element-wise result of ``x != y``.
|
|
//| """
|
|
//| ...
|
|
mp_obj_t compare_not_equal(mp_obj_t x1, mp_obj_t x2) {
|
|
return compare_equal_helper(x1, x2, COMPARE_NOT_EQUAL);
|
|
}
|
|
|
|
MP_DEFINE_CONST_FUN_OBJ_2(compare_not_equal_obj, compare_not_equal);
|
|
#endif
|
|
|
|
#if ULAB_NUMPY_HAS_ISFINITE | ULAB_NUMPY_HAS_ISINF
|
|
static mp_obj_t compare_isinf_isfinite(mp_obj_t _x, uint8_t mask) {
|
|
// mask should signify, whether the function is called from isinf (mask = 1),
|
|
// or from isfinite (mask = 0)
|
|
if(mp_obj_is_int(_x)) {
|
|
if(mask) {
|
|
return mp_const_false;
|
|
} else {
|
|
return mp_const_true;
|
|
}
|
|
} else if(mp_obj_is_float(_x)) {
|
|
mp_float_t x = mp_obj_get_float(_x);
|
|
if(isnan(x)) {
|
|
return mp_const_false;
|
|
}
|
|
if(mask) { // called from isinf
|
|
return isinf(x) ? mp_const_true : mp_const_false;
|
|
} else { // called from isfinite
|
|
return isinf(x) ? mp_const_false : mp_const_true;
|
|
}
|
|
} else if(mp_obj_is_type(_x, &ulab_ndarray_type)) {
|
|
ndarray_obj_t *x = MP_OBJ_TO_PTR(_x);
|
|
COMPLEX_DTYPE_NOT_IMPLEMENTED(x->dtype)
|
|
ndarray_obj_t *results = ndarray_new_dense_ndarray(x->ndim, x->shape, NDARRAY_BOOL);
|
|
// At this point, results is all False
|
|
uint8_t *rarray = (uint8_t *)results->array;
|
|
if(x->dtype != NDARRAY_FLOAT) {
|
|
// int types can never be infinite...
|
|
if(!mask) {
|
|
// ...so flip all values in the array, if the function was called from isfinite
|
|
memset(rarray, 1, results->len);
|
|
}
|
|
return MP_OBJ_FROM_PTR(results);
|
|
}
|
|
uint8_t *xarray = (uint8_t *)x->array;
|
|
|
|
ITERATOR_HEAD();
|
|
mp_float_t value = *(mp_float_t *)xarray;
|
|
if(isnan(value)) {
|
|
*rarray++ = 0;
|
|
} else {
|
|
*rarray++ = isinf(value) ? mask : 1 - mask;
|
|
}
|
|
ITERATOR_TAIL(x, xarray);
|
|
return MP_OBJ_FROM_PTR(results);
|
|
} else {
|
|
mp_raise_TypeError(MP_ERROR_TEXT("wrong input type"));
|
|
}
|
|
return mp_const_none;
|
|
}
|
|
#endif
|
|
|
|
#if ULAB_NUMPY_HAS_ISFINITE
|
|
//| def isfinite(x: _ScalarOrNdArray) -> Union[_bool, ulab.numpy.ndarray]:
|
|
//| """
|
|
//| Tests element-wise for finiteness (i.e., it should not be infinity or a NaN).
|
|
//|
|
|
//| :param x: Input scalar or ndarray.
|
|
//| :return:
|
|
//| A boolean scalar or array with True where ``x`` is finite, and
|
|
//| False otherwise.
|
|
//| """
|
|
//| ...
|
|
mp_obj_t compare_isfinite(mp_obj_t _x) {
|
|
return compare_isinf_isfinite(_x, 0);
|
|
}
|
|
|
|
MP_DEFINE_CONST_FUN_OBJ_1(compare_isfinite_obj, compare_isfinite);
|
|
#endif
|
|
|
|
#if ULAB_NUMPY_HAS_ISINF
|
|
//| def isinf(x: _ScalarOrNdArray) -> Union[_bool, ulab.numpy.ndarray]:
|
|
//| """
|
|
//| Tests element-wise for positive or negative infinity.
|
|
//|
|
|
//| :param x: Input scalar or ndarray.
|
|
//| :return:
|
|
//| A boolean scalar or array with True where ``x`` is positive or
|
|
//| negative infinity, and False otherwise.
|
|
//| """
|
|
//| ...
|
|
mp_obj_t compare_isinf(mp_obj_t _x) {
|
|
return compare_isinf_isfinite(_x, 1);
|
|
}
|
|
|
|
MP_DEFINE_CONST_FUN_OBJ_1(compare_isinf_obj, compare_isinf);
|
|
#endif
|
|
|
|
#if ULAB_NUMPY_HAS_MAXIMUM
|
|
//| def maximum(x1: _ScalarOrArrayLike, x2: _ScalarOrArrayLike) -> _ScalarOrNdArray:
|
|
//| """
|
|
//| Returns the element-wise maximum.
|
|
//|
|
|
//| :param x1, x2:
|
|
//| Input scalar or array. If ``x.shape != y.shape`` they must
|
|
//| be broadcastable to a common shape (which becomes the
|
|
//| shape of the output.)
|
|
//| :return:
|
|
//| A scalar or array with the element-wise maximum of ``x1`` and ``x2``.
|
|
//| """
|
|
//| ...
|
|
mp_obj_t compare_maximum(mp_obj_t x1, mp_obj_t x2) {
|
|
// extra round, so that we can return maximum(3, 4) properly
|
|
mp_obj_t result = compare_function(x1, x2, COMPARE_MAXIMUM);
|
|
if((mp_obj_is_int(x1) || mp_obj_is_float(x1)) && (mp_obj_is_int(x2) || mp_obj_is_float(x2))) {
|
|
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(result);
|
|
return mp_binary_get_val_array(ndarray->dtype, ndarray->array, 0);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
MP_DEFINE_CONST_FUN_OBJ_2(compare_maximum_obj, compare_maximum);
|
|
#endif
|
|
|
|
#if ULAB_NUMPY_HAS_MINIMUM
|
|
|
|
//| def minimum(x1: _ScalarOrArrayLike, x2: _ScalarOrArrayLike) -> _ScalarOrNdArray:
|
|
//| """
|
|
//| Returns the element-wise minimum.
|
|
//|
|
|
//| :param x1, x2:
|
|
//| Input scalar or array. If ``x.shape != y.shape`` they must
|
|
//| be broadcastable to a common shape (which becomes the
|
|
//| shape of the output.)
|
|
//| :return:
|
|
//| A scalar or array with the element-wise minimum of ``x1`` and ``x2``.
|
|
//| """
|
|
//| ...
|
|
mp_obj_t compare_minimum(mp_obj_t x1, mp_obj_t x2) {
|
|
// extra round, so that we can return minimum(3, 4) properly
|
|
mp_obj_t result = compare_function(x1, x2, COMPARE_MINIMUM);
|
|
if((mp_obj_is_int(x1) || mp_obj_is_float(x1)) && (mp_obj_is_int(x2) || mp_obj_is_float(x2))) {
|
|
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(result);
|
|
return mp_binary_get_val_array(ndarray->dtype, ndarray->array, 0);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
MP_DEFINE_CONST_FUN_OBJ_2(compare_minimum_obj, compare_minimum);
|
|
#endif
|
|
|
|
#if ULAB_NUMPY_HAS_NONZERO
|
|
|
|
//| def nonzero(x: _ScalarOrArrayLike) -> ulab.numpy.ndarray:
|
|
//| """
|
|
//| Returns the indices of elements that are non-zero.
|
|
//|
|
|
//| :param x:
|
|
//| Input scalar or array. If ``x`` is a scalar, it is treated
|
|
//| as a single-element 1-d array.
|
|
//| :return:
|
|
//| An array of indices that are non-zero.
|
|
//| """
|
|
//| ...
|
|
mp_obj_t compare_nonzero(mp_obj_t x) {
|
|
ndarray_obj_t *ndarray_x = ndarray_from_mp_obj(x, 0);
|
|
// since ndarray_new_linear_array calls m_new0, the content of zero is a single zero
|
|
ndarray_obj_t *zero = ndarray_new_linear_array(1, NDARRAY_UINT8);
|
|
|
|
uint8_t ndim = 0;
|
|
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
|
|
int32_t *x_strides = m_new(int32_t, ULAB_MAX_DIMS);
|
|
int32_t *zero_strides = m_new(int32_t, ULAB_MAX_DIMS);
|
|
// we don't actually have to inspect the outcome of ndarray_can_broadcast,
|
|
// because the right hand side is a linear array with a single element
|
|
ndarray_can_broadcast(ndarray_x, zero, &ndim, shape, x_strides, zero_strides);
|
|
|
|
// equal_obj is a Boolean ndarray
|
|
mp_obj_t equal_obj = ndarray_binary_equality(ndarray_x, zero, ndim, shape, x_strides, zero_strides, MP_BINARY_OP_NOT_EQUAL);
|
|
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(equal_obj);
|
|
|
|
// these are no longer needed, get rid of them
|
|
m_del(size_t, shape, ULAB_MAX_DIMS);
|
|
m_del(int32_t, x_strides, ULAB_MAX_DIMS);
|
|
m_del(int32_t, zero_strides, ULAB_MAX_DIMS);
|
|
|
|
uint8_t *array = (uint8_t *)ndarray->array;
|
|
uint8_t *origin = (uint8_t *)ndarray->array;
|
|
|
|
// First, count the number of Trues:
|
|
uint16_t count = 0;
|
|
size_t indices[ULAB_MAX_DIMS];
|
|
|
|
#if ULAB_MAX_DIMS > 3
|
|
indices[3] = 0;
|
|
do {
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 2
|
|
indices[2] = 0;
|
|
do {
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 1
|
|
indices[1] = 0;
|
|
do {
|
|
#endif
|
|
indices[0] = 0;
|
|
do {
|
|
if(*array != 0) {
|
|
count++;
|
|
}
|
|
array += ndarray->strides[ULAB_MAX_DIMS - 1];
|
|
indices[0]++;
|
|
} while(indices[0] < ndarray->shape[ULAB_MAX_DIMS - 1]);
|
|
#if ULAB_MAX_DIMS > 1
|
|
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS-1];
|
|
array += ndarray->strides[ULAB_MAX_DIMS - 2];
|
|
indices[1]++;
|
|
} while(indices[1] < ndarray->shape[ULAB_MAX_DIMS - 2]);
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 2
|
|
array -= ndarray->strides[ULAB_MAX_DIMS - 2] * ndarray->shape[ULAB_MAX_DIMS-2];
|
|
array += ndarray->strides[ULAB_MAX_DIMS - 3];
|
|
indices[2]++;
|
|
} while(indices[2] < ndarray->shape[ULAB_MAX_DIMS - 3]);
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 3
|
|
array -= ndarray->strides[ULAB_MAX_DIMS - 3] * ndarray->shape[ULAB_MAX_DIMS-3];
|
|
array += ndarray->strides[ULAB_MAX_DIMS - 4];
|
|
indices[3]++;
|
|
} while(indices[3] < ndarray->shape[ULAB_MAX_DIMS - 4]);
|
|
#endif
|
|
|
|
mp_obj_t *items = m_new(mp_obj_t, ndarray->ndim);
|
|
uint16_t *arrays[ULAB_MAX_DIMS];
|
|
|
|
for(uint8_t i = 0; i < ndarray->ndim; i++) {
|
|
ndarray_obj_t *item_array = ndarray_new_linear_array(count, NDARRAY_UINT16);
|
|
uint16_t *iarray = (uint16_t *)item_array->array;
|
|
arrays[ULAB_MAX_DIMS - 1 - i] = iarray;
|
|
items[ndarray->ndim - 1 - i] = MP_OBJ_FROM_PTR(item_array);
|
|
}
|
|
array = origin;
|
|
count = 0;
|
|
|
|
#if ULAB_MAX_DIMS > 3
|
|
indices[3] = 0;
|
|
do {
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 2
|
|
indices[2] = 0;
|
|
do {
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 1
|
|
indices[1] = 0;
|
|
do {
|
|
#endif
|
|
indices[0] = 0;
|
|
do {
|
|
if(*array != 0) {
|
|
for(uint8_t d = 0; d < ndarray->ndim; d++) {
|
|
arrays[ULAB_MAX_DIMS - 1 - d][count] = indices[d];
|
|
}
|
|
count++;
|
|
}
|
|
array += ndarray->strides[ULAB_MAX_DIMS - 1];
|
|
indices[0]++;
|
|
} while(indices[0] < ndarray->shape[ULAB_MAX_DIMS - 1]);
|
|
#if ULAB_MAX_DIMS > 1
|
|
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS-1];
|
|
array += ndarray->strides[ULAB_MAX_DIMS - 2];
|
|
indices[1]++;
|
|
} while(indices[1] < ndarray->shape[ULAB_MAX_DIMS - 2]);
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 2
|
|
array -= ndarray->strides[ULAB_MAX_DIMS - 2] * ndarray->shape[ULAB_MAX_DIMS-2];
|
|
array += ndarray->strides[ULAB_MAX_DIMS - 3];
|
|
indices[2]++;
|
|
} while(indices[2] < ndarray->shape[ULAB_MAX_DIMS - 3]);
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 3
|
|
array -= ndarray->strides[ULAB_MAX_DIMS - 3] * ndarray->shape[ULAB_MAX_DIMS-3];
|
|
array += ndarray->strides[ULAB_MAX_DIMS - 4];
|
|
indices[3]++;
|
|
} while(indices[3] < ndarray->shape[ULAB_MAX_DIMS - 4]);
|
|
#endif
|
|
|
|
return mp_obj_new_tuple(ndarray->ndim, items);
|
|
}
|
|
|
|
MP_DEFINE_CONST_FUN_OBJ_1(compare_nonzero_obj, compare_nonzero);
|
|
#endif /* ULAB_NUMPY_HAS_NONZERO */
|
|
|
|
#if ULAB_NUMPY_HAS_WHERE
|
|
|
|
//| def where(
|
|
//| condition: _ScalarOrArrayLike,
|
|
//| x: _ScalarOrArrayLike,
|
|
//| y: _ScalarOrArrayLike,
|
|
//| ) -> ulab.numpy.ndarray:
|
|
//| """
|
|
//| Returns elements from ``x`` or ``y`` depending on ``condition``.
|
|
//|
|
|
//| :param condition:
|
|
//| Input scalar or array. If an element (or scalar) is truthy,
|
|
//| the corresponding element from ``x`` is chosen, otherwise
|
|
//| ``y`` is used. ``condition``, ``x`` and ``y`` must also be
|
|
//| broadcastable to the same shape (which becomes the output
|
|
//| shape.)
|
|
//| :param x, y:
|
|
//| Input scalar or array.
|
|
//| :return:
|
|
//| An array with elements from ``x`` when ``condition`` is
|
|
//| truthy, and ``y`` elsewhere.
|
|
//| """
|
|
//| ...
|
|
mp_obj_t compare_where(mp_obj_t _condition, mp_obj_t _x, mp_obj_t _y) {
|
|
// this implementation will work with ndarrays, and scalars only
|
|
ndarray_obj_t *c = ndarray_from_mp_obj(_condition, 0);
|
|
ndarray_obj_t *x = ndarray_from_mp_obj(_x, 0);
|
|
ndarray_obj_t *y = ndarray_from_mp_obj(_y, 0);
|
|
|
|
COMPLEX_DTYPE_NOT_IMPLEMENTED(c->dtype)
|
|
COMPLEX_DTYPE_NOT_IMPLEMENTED(x->dtype)
|
|
COMPLEX_DTYPE_NOT_IMPLEMENTED(y->dtype)
|
|
|
|
int32_t *cstrides = m_new(int32_t, ULAB_MAX_DIMS);
|
|
int32_t *xstrides = m_new(int32_t, ULAB_MAX_DIMS);
|
|
int32_t *ystrides = m_new(int32_t, ULAB_MAX_DIMS);
|
|
|
|
size_t *oshape = m_new(size_t, ULAB_MAX_DIMS);
|
|
|
|
uint8_t ndim;
|
|
|
|
// establish the broadcasting conditions first
|
|
// if any two of the arrays can be broadcast together, then
|
|
// the three arrays can also be broadcast together
|
|
if(!ndarray_can_broadcast(c, x, &ndim, oshape, cstrides, ystrides) ||
|
|
!ndarray_can_broadcast(c, y, &ndim, oshape, cstrides, ystrides) ||
|
|
!ndarray_can_broadcast(x, y, &ndim, oshape, xstrides, ystrides)) {
|
|
mp_raise_ValueError(MP_ERROR_TEXT("operands could not be broadcast together"));
|
|
}
|
|
|
|
ndim = MAX(MAX(c->ndim, x->ndim), y->ndim);
|
|
|
|
for(uint8_t i = 1; i <= ndim; i++) {
|
|
cstrides[ULAB_MAX_DIMS - i] = c->shape[ULAB_MAX_DIMS - i] < 2 ? 0 : c->strides[ULAB_MAX_DIMS - i];
|
|
xstrides[ULAB_MAX_DIMS - i] = x->shape[ULAB_MAX_DIMS - i] < 2 ? 0 : x->strides[ULAB_MAX_DIMS - i];
|
|
ystrides[ULAB_MAX_DIMS - i] = y->shape[ULAB_MAX_DIMS - i] < 2 ? 0 : y->strides[ULAB_MAX_DIMS - i];
|
|
oshape[ULAB_MAX_DIMS - i] = MAX(MAX(c->shape[ULAB_MAX_DIMS - i], x->shape[ULAB_MAX_DIMS - i]), y->shape[ULAB_MAX_DIMS - i]);
|
|
}
|
|
|
|
uint8_t out_dtype = ndarray_upcast_dtype(x->dtype, y->dtype);
|
|
ndarray_obj_t *out = ndarray_new_dense_ndarray(ndim, oshape, out_dtype);
|
|
|
|
mp_float_t (*cfunc)(void *) = ndarray_get_float_function(c->dtype);
|
|
mp_float_t (*xfunc)(void *) = ndarray_get_float_function(x->dtype);
|
|
mp_float_t (*yfunc)(void *) = ndarray_get_float_function(y->dtype);
|
|
mp_float_t (*ofunc)(void *, mp_float_t ) = ndarray_set_float_function(out->dtype);
|
|
|
|
uint8_t *oarray = (uint8_t *)out->array;
|
|
uint8_t *carray = (uint8_t *)c->array;
|
|
uint8_t *xarray = (uint8_t *)x->array;
|
|
uint8_t *yarray = (uint8_t *)y->array;
|
|
|
|
#if ULAB_MAX_DIMS > 3
|
|
size_t i = 0;
|
|
do {
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 2
|
|
size_t j = 0;
|
|
do {
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 1
|
|
size_t k = 0;
|
|
do {
|
|
#endif
|
|
size_t l = 0;
|
|
do {
|
|
mp_float_t value;
|
|
mp_float_t cvalue = cfunc(carray);
|
|
if(cvalue != MICROPY_FLOAT_CONST(0.0)) {
|
|
value = xfunc(xarray);
|
|
} else {
|
|
value = yfunc(yarray);
|
|
}
|
|
ofunc(oarray, value);
|
|
oarray += out->itemsize;
|
|
carray += cstrides[ULAB_MAX_DIMS - 1];
|
|
xarray += xstrides[ULAB_MAX_DIMS - 1];
|
|
yarray += ystrides[ULAB_MAX_DIMS - 1];
|
|
l++;
|
|
} while(l < out->shape[ULAB_MAX_DIMS - 1]);
|
|
#if ULAB_MAX_DIMS > 1
|
|
carray -= cstrides[ULAB_MAX_DIMS - 1] * c->shape[ULAB_MAX_DIMS-1];
|
|
carray += cstrides[ULAB_MAX_DIMS - 2];
|
|
xarray -= xstrides[ULAB_MAX_DIMS - 1] * x->shape[ULAB_MAX_DIMS-1];
|
|
xarray += xstrides[ULAB_MAX_DIMS - 2];
|
|
yarray -= ystrides[ULAB_MAX_DIMS - 1] * y->shape[ULAB_MAX_DIMS-1];
|
|
yarray += ystrides[ULAB_MAX_DIMS - 2];
|
|
k++;
|
|
} while(k < out->shape[ULAB_MAX_DIMS - 2]);
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 2
|
|
carray -= cstrides[ULAB_MAX_DIMS - 2] * c->shape[ULAB_MAX_DIMS-2];
|
|
carray += cstrides[ULAB_MAX_DIMS - 3];
|
|
xarray -= xstrides[ULAB_MAX_DIMS - 2] * x->shape[ULAB_MAX_DIMS-2];
|
|
xarray += xstrides[ULAB_MAX_DIMS - 3];
|
|
yarray -= ystrides[ULAB_MAX_DIMS - 2] * y->shape[ULAB_MAX_DIMS-2];
|
|
yarray += ystrides[ULAB_MAX_DIMS - 3];
|
|
j++;
|
|
} while(j < out->shape[ULAB_MAX_DIMS - 3]);
|
|
#endif
|
|
#if ULAB_MAX_DIMS > 3
|
|
carray -= cstrides[ULAB_MAX_DIMS - 3] * c->shape[ULAB_MAX_DIMS-3];
|
|
carray += cstrides[ULAB_MAX_DIMS - 4];
|
|
xarray -= xstrides[ULAB_MAX_DIMS - 3] * x->shape[ULAB_MAX_DIMS-3];
|
|
xarray += xstrides[ULAB_MAX_DIMS - 4];
|
|
yarray -= ystrides[ULAB_MAX_DIMS - 3] * y->shape[ULAB_MAX_DIMS-3];
|
|
yarray += ystrides[ULAB_MAX_DIMS - 4];
|
|
i++;
|
|
} while(i < out->shape[ULAB_MAX_DIMS - 4]);
|
|
#endif
|
|
return MP_OBJ_FROM_PTR(out);
|
|
}
|
|
|
|
MP_DEFINE_CONST_FUN_OBJ_3(compare_where_obj, compare_where);
|
|
#endif
|