Compare commits

...

5 commits

Author SHA1 Message Date
Zoltán Vörös
774b821d11
Merge branch 'master' into keepdims 2024-12-30 22:22:15 +01:00
Zoltán Vörös
0c8c5f03fe remove out-commented code 2024-12-30 22:19:30 +01:00
Zoltán Vörös
a3fc235418 fux keepdims code 2024-12-30 22:17:25 +01:00
Zoltán Vörös
f013badcff preliminary keepdims fix 2024-12-26 16:11:38 +01:00
Zoltán Vörös
35c2b85e57 add function to deal with keepdims=True 2024-12-08 19:21:58 +01:00
7 changed files with 190 additions and 38 deletions

View file

@ -274,7 +274,7 @@ static mp_obj_t numerical_sum_mean_std_iterable(mp_obj_t oin, uint8_t optype, si
} }
} }
static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, uint8_t optype, size_t ddof) { static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, mp_obj_t keepdims, uint8_t optype, size_t ddof) {
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype) COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
uint8_t *array = (uint8_t *)ndarray->array; uint8_t *array = (uint8_t *)ndarray->array;
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis); shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
@ -372,7 +372,7 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
mp_float_t norm = (mp_float_t)_shape_strides.shape[0]; mp_float_t norm = (mp_float_t)_shape_strides.shape[0];
// re-wind the array here // re-wind the array here
farray = (mp_float_t *)results->array; farray = (mp_float_t *)results->array;
for(size_t i=0; i < results->len; i++) { for(size_t i = 0; i < results->len; i++) {
*farray++ *= norm; *farray++ *= norm;
} }
} }
@ -380,7 +380,7 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
bool isStd = optype == NUMERICAL_STD ? 1 : 0; bool isStd = optype == NUMERICAL_STD ? 1 : 0;
results = ndarray_new_dense_ndarray(_shape_strides.ndim, _shape_strides.shape, NDARRAY_FLOAT); results = ndarray_new_dense_ndarray(_shape_strides.ndim, _shape_strides.shape, NDARRAY_FLOAT);
farray = (mp_float_t *)results->array; farray = (mp_float_t *)results->array;
// we can return the 0 array here, if the degrees of freedom is larger than the length of the axis // we can return the 0 array here, if the degrees of freedom are larger than the length of the axis
if((optype == NUMERICAL_STD) && (_shape_strides.shape[0] <= ddof)) { if((optype == NUMERICAL_STD) && (_shape_strides.shape[0] <= ddof)) {
return MP_OBJ_FROM_PTR(results); return MP_OBJ_FROM_PTR(results);
} }
@ -397,11 +397,9 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
RUN_MEAN_STD(mp_float_t, array, farray, _shape_strides, div, isStd); RUN_MEAN_STD(mp_float_t, array, farray, _shape_strides, div, isStd);
} }
} }
if(results->ndim == 0) { // return a scalar here return ulab_tools_restore_dims(ndarray, results, keepdims, _shape_strides);
return mp_binary_get_val_array(results->dtype, results->array, 0);
}
return MP_OBJ_FROM_PTR(results);
} }
// we should never get to this point
return mp_const_none; return mp_const_none;
} }
#endif #endif
@ -441,7 +439,7 @@ static mp_obj_t numerical_argmin_argmax_iterable(mp_obj_t oin, uint8_t optype) {
} }
} }
static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, uint8_t optype) { static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t keepdims, mp_obj_t axis, uint8_t optype) {
// TODO: treat the flattened array // TODO: treat the flattened array
if(ndarray->len == 0) { if(ndarray->len == 0) {
mp_raise_ValueError(MP_ERROR_TEXT("attempt to get (arg)min/(arg)max of empty sequence")); mp_raise_ValueError(MP_ERROR_TEXT("attempt to get (arg)min/(arg)max of empty sequence"));
@ -521,7 +519,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
int32_t *strides = m_new0(int32_t, ULAB_MAX_DIMS); int32_t *strides = m_new0(int32_t, ULAB_MAX_DIMS);
numerical_reduce_axes(ndarray, ax, shape, strides); numerical_reduce_axes(ndarray, ax, shape, strides);
uint8_t index = ULAB_MAX_DIMS - ndarray->ndim + ax; shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
uint8_t index = _shape_strides.axis;
ndarray_obj_t *results = NULL; ndarray_obj_t *results = NULL;
@ -550,8 +550,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
if(results->len == 1) { if(results->len == 1) {
return mp_binary_get_val_array(results->dtype, results->array, 0); return mp_binary_get_val_array(results->dtype, results->array, 0);
} }
return MP_OBJ_FROM_PTR(results); return ulab_tools_restore_dims(ndarray, results, keepdims, _shape_strides);
} }
// we should never get to this point
return mp_const_none; return mp_const_none;
} }
#endif #endif
@ -560,6 +561,7 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
static const mp_arg_t allowed_args[] = { static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } , { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
{ MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, { MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } },
}; };
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
@ -567,6 +569,8 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
mp_obj_t oin = args[0].u_obj; mp_obj_t oin = args[0].u_obj;
mp_obj_t axis = args[1].u_obj; mp_obj_t axis = args[1].u_obj;
mp_obj_t keepdims = args[2].u_obj;
if((axis != mp_const_none) && (!mp_obj_is_int(axis))) { if((axis != mp_const_none) && (!mp_obj_is_int(axis))) {
mp_raise_TypeError(MP_ERROR_TEXT("axis must be None, or an integer")); mp_raise_TypeError(MP_ERROR_TEXT("axis must be None, or an integer"));
} }
@ -598,11 +602,11 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
case NUMERICAL_ARGMIN: case NUMERICAL_ARGMIN:
case NUMERICAL_ARGMAX: case NUMERICAL_ARGMAX:
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype) COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
return numerical_argmin_argmax_ndarray(ndarray, axis, optype); return numerical_argmin_argmax_ndarray(ndarray, keepdims, axis, optype);
case NUMERICAL_SUM: case NUMERICAL_SUM:
case NUMERICAL_MEAN: case NUMERICAL_MEAN:
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype) COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
return numerical_sum_mean_std_ndarray(ndarray, axis, optype, 0); return numerical_sum_mean_std_ndarray(ndarray, axis, keepdims, optype, 0);
default: default:
mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is not implemented on ndarrays")); mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is not implemented on ndarrays"));
} }
@ -1385,6 +1389,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } } , { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } } ,
{ MP_QSTR_axis, MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } }, { MP_QSTR_axis, MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_ddof, MP_ARG_KW_ONLY | MP_ARG_INT, {.u_int = 0} }, { MP_QSTR_ddof, MP_ARG_KW_ONLY | MP_ARG_INT, {.u_int = 0} },
{ MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } },
}; };
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
@ -1393,6 +1398,8 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
mp_obj_t oin = args[0].u_obj; mp_obj_t oin = args[0].u_obj;
mp_obj_t axis = args[1].u_obj; mp_obj_t axis = args[1].u_obj;
size_t ddof = args[2].u_int; size_t ddof = args[2].u_int;
mp_obj_t keepdims = args[2].u_obj;
if((axis != mp_const_none) && (mp_obj_get_int(axis) != 0) && (mp_obj_get_int(axis) != 1)) { if((axis != mp_const_none) && (mp_obj_get_int(axis) != 0) && (mp_obj_get_int(axis) != 1)) {
// this seems to pass with False, and True... // this seems to pass with False, and True...
mp_raise_ValueError(MP_ERROR_TEXT("axis must be None, or an integer")); mp_raise_ValueError(MP_ERROR_TEXT("axis must be None, or an integer"));
@ -1401,7 +1408,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
return numerical_sum_mean_std_iterable(oin, NUMERICAL_STD, ddof); return numerical_sum_mean_std_iterable(oin, NUMERICAL_STD, ddof);
} else if(mp_obj_is_type(oin, &ulab_ndarray_type)) { } else if(mp_obj_is_type(oin, &ulab_ndarray_type)) {
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin); ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin);
return numerical_sum_mean_std_ndarray(ndarray, axis, NUMERICAL_STD, ddof); return numerical_sum_mean_std_ndarray(ndarray, axis, keepdims, NUMERICAL_STD, ddof);
} else { } else {
mp_raise_TypeError(MP_ERROR_TEXT("input must be tuple, list, range, or ndarray")); mp_raise_TypeError(MP_ERROR_TEXT("input must be tuple, list, range, or ndarray"));
} }

View file

@ -33,7 +33,7 @@
#include "user/user.h" #include "user/user.h"
#include "utils/utils.h" #include "utils/utils.h"
#define ULAB_VERSION 6.7.0 #define ULAB_VERSION 6.7.1
#define xstr(s) str(s) #define xstr(s) str(s)
#define str(s) #s #define str(s) #s

View file

@ -162,6 +162,15 @@ void *ndarray_set_float_function(uint8_t dtype) {
} }
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */ #endif /* NDARRAY_BINARY_USES_FUN_POINTER */
int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) {
int8_t ax = mp_obj_get_int(axis);
if(ax < 0) ax += ndim;
if((ax < 0) || (ax > ndim - 1)) {
mp_raise_ValueError(MP_ERROR_TEXT("axis is out of bounds"));
}
return ax;
}
shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) { shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
// TODO: replace numerical_reduce_axes with this function, wherever applicable // TODO: replace numerical_reduce_axes with this function, wherever applicable
// This function should be used, whenever a tensor is contracted; // This function should be used, whenever a tensor is contracted;
@ -172,38 +181,36 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
} }
shape_strides _shape_strides; shape_strides _shape_strides;
_shape_strides.increment = 0;
// this is the contracted dimension (won't be overwritten for axis == None)
_shape_strides.ndim = 0;
if(axis == mp_const_none) {
_shape_strides.shape = ndarray->shape;
_shape_strides.strides = ndarray->strides;
return _shape_strides;
}
size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1); size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1);
_shape_strides.shape = shape; _shape_strides.shape = shape;
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1); int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1);
_shape_strides.strides = strides; _shape_strides.strides = strides;
_shape_strides.increment = 0;
// this is the contracted dimension (won't be overwritten for axis == None)
_shape_strides.ndim = 0;
memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS); memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS); memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);
if(axis == mp_const_none) { _shape_strides.axis = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)
return _shape_strides;
}
uint8_t index = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)
if(axis != mp_const_none) { // i.e., axis is an integer if(axis != mp_const_none) { // i.e., axis is an integer
int8_t ax = mp_obj_get_int(axis); int8_t ax = tools_get_axis(axis, ndarray->ndim);
if(ax < 0) ax += ndarray->ndim; _shape_strides.axis = ULAB_MAX_DIMS - ndarray->ndim + ax;
if((ax < 0) || (ax > ndarray->ndim - 1)) {
mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
}
index = ULAB_MAX_DIMS - ndarray->ndim + ax;
_shape_strides.ndim = ndarray->ndim - 1; _shape_strides.ndim = ndarray->ndim - 1;
} }
// move the value stored at index to the leftmost position, and align everything else to the right // move the value stored at index to the leftmost position, and align everything else to the right
_shape_strides.shape[0] = ndarray->shape[index]; _shape_strides.shape[0] = ndarray->shape[_shape_strides.axis];
_shape_strides.strides[0] = ndarray->strides[index]; _shape_strides.strides[0] = ndarray->strides[_shape_strides.axis];
for(uint8_t i = 0; i < index; i++) { for(uint8_t i = 0; i < _shape_strides.axis; i++) {
// entries to the right of index must be shifted by one position to the left // entries to the right of index must be shifted by one position to the left
_shape_strides.shape[i + 1] = ndarray->shape[i]; _shape_strides.shape[i + 1] = ndarray->shape[i];
_shape_strides.strides[i + 1] = ndarray->strides[i]; _shape_strides.strides[i + 1] = ndarray->strides[i];
@ -213,16 +220,37 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
_shape_strides.increment = 1; _shape_strides.increment = 1;
} }
if(_shape_strides.ndim == 0) {
_shape_strides.ndim = 1;
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
_shape_strides.strides[ULAB_MAX_DIMS - 1] = ndarray->itemsize;
}
return _shape_strides; return _shape_strides;
} }
int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) { mp_obj_t ulab_tools_restore_dims(ndarray_obj_t *ndarray, ndarray_obj_t *results, mp_obj_t keepdims, shape_strides _shape_strides) {
int8_t ax = mp_obj_get_int(axis); // restores the contracted dimension, if keepdims is True
if(ax < 0) ax += ndim; if((ndarray->ndim == 1) && (keepdims != mp_const_true)) {
if((ax < 0) || (ax > ndim - 1)) { // since the original array has already been contracted and
mp_raise_ValueError(MP_ERROR_TEXT("axis is out of bounds")); // we don't want to keep the dimensions here, we have to return a scalar
return mp_binary_get_val_array(results->dtype, results->array, 0);
} }
return ax;
if(keepdims == mp_const_true) {
results->ndim += 1;
for(int8_t i = 0; i < ULAB_MAX_DIMS; i++) {
results->shape[i] = ndarray->shape[i];
}
results->shape[_shape_strides.axis] = 1;
results->strides[ULAB_MAX_DIMS - 1] = ndarray->itemsize;
for(uint8_t i = ULAB_MAX_DIMS; i > 1; i--) {
results->strides[i - 2] = results->strides[i - 1] * results->shape[i - 1];
}
}
return MP_OBJ_FROM_PTR(results);
} }
#if ULAB_MAX_DIMS > 1 #if ULAB_MAX_DIMS > 1

View file

@ -17,6 +17,7 @@
typedef struct _shape_strides_t { typedef struct _shape_strides_t {
uint8_t increment; uint8_t increment;
uint8_t axis;
uint8_t ndim; uint8_t ndim;
size_t *shape; size_t *shape;
int32_t *strides; int32_t *strides;
@ -34,6 +35,7 @@ void *ndarray_set_float_function(uint8_t );
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t ); shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
int8_t tools_get_axis(mp_obj_t , uint8_t ); int8_t tools_get_axis(mp_obj_t , uint8_t );
mp_obj_t ulab_tools_restore_dims(ndarray_obj_t * , ndarray_obj_t * , mp_obj_t , shape_strides );
ndarray_obj_t *tools_object_is_square(mp_obj_t ); ndarray_obj_t *tools_object_is_square(mp_obj_t );
uint8_t ulab_binary_get_size(uint8_t ); uint8_t ulab_binary_get_size(uint8_t );

View file

@ -1,3 +1,15 @@
Mon, 30 Dec 2024
version 6.7.1
add keepdims keyword argument to numerical functions
Sun, 15 Dec 2024
version 6.7.0
add scipy.integrate module
Sun, 24 Nov 2024 Sun, 24 Nov 2024
version 6.6.1 version 6.6.1

23
tests/2d/numpy/sum.py Normal file
View file

@ -0,0 +1,23 @@
try:
from ulab import numpy as np
except ImportError:
import numpy as np
for dtype in (np.uint8, np.int8, np.uint16, np.int8, np.float):
a = np.array(range(12), dtype=dtype)
b = a.reshape((3, 4))
print(a)
print(b)
print()
print(np.sum(a))
print(np.sum(a, axis=0))
print(np.sum(a, axis=0, keepdims=True))
print()
print(np.sum(b))
print(np.sum(b, axis=0))
print(np.sum(b, axis=1))
print(np.sum(b, axis=0, keepdims=True))
print(np.sum(b, axis=1, keepdims=True))

80
tests/2d/numpy/sum.py.exp Normal file
View file

@ -0,0 +1,80 @@
array([0, 1, 2, ..., 9, 10, 11], dtype=uint8)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=uint8)
66
66
array([66], dtype=uint8)
66
array([12, 15, 18, 21], dtype=uint8)
array([6, 22, 38], dtype=uint8)
array([[12, 15, 18, 21]], dtype=uint8)
array([[6],
[22],
[38]], dtype=uint8)
array([0, 1, 2, ..., 9, 10, 11], dtype=int8)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=int8)
66
66
array([66], dtype=int8)
66
array([12, 15, 18, 21], dtype=int8)
array([6, 22, 38], dtype=int8)
array([[12, 15, 18, 21]], dtype=int8)
array([[6],
[22],
[38]], dtype=int8)
array([0, 1, 2, ..., 9, 10, 11], dtype=uint16)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=uint16)
66
66
array([66], dtype=uint16)
66
array([12, 15, 18, 21], dtype=uint16)
array([6, 22, 38], dtype=uint16)
array([[12, 15, 18, 21]], dtype=uint16)
array([[6],
[22],
[38]], dtype=uint16)
array([0, 1, 2, ..., 9, 10, 11], dtype=int8)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=int8)
66
66
array([66], dtype=int8)
66
array([12, 15, 18, 21], dtype=int8)
array([6, 22, 38], dtype=int8)
array([[12, 15, 18, 21]], dtype=int8)
array([[6],
[22],
[38]], dtype=int8)
array([0.0, 1.0, 2.0, ..., 9.0, 10.0, 11.0], dtype=float64)
array([[0.0, 1.0, 2.0, 3.0],
[4.0, 5.0, 6.0, 7.0],
[8.0, 9.0, 10.0, 11.0]], dtype=float64)
66.0
66.0
array([66.0], dtype=float64)
66.0
array([12.0, 15.0, 18.0, 21.0], dtype=float64)
array([6.0, 22.0, 38.0], dtype=float64)
array([[12.0, 15.0, 18.0, 21.0]], dtype=float64)
array([[6.0],
[22.0],
[38.0]], dtype=float64)