add function to deal with keepdims=True
This commit is contained in:
parent
303e8d790a
commit
35c2b85e57
3 changed files with 32 additions and 6 deletions
|
|
@ -380,7 +380,7 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
|
||||||
bool isStd = optype == NUMERICAL_STD ? 1 : 0;
|
bool isStd = optype == NUMERICAL_STD ? 1 : 0;
|
||||||
results = ndarray_new_dense_ndarray(_shape_strides.ndim, _shape_strides.shape, NDARRAY_FLOAT);
|
results = ndarray_new_dense_ndarray(_shape_strides.ndim, _shape_strides.shape, NDARRAY_FLOAT);
|
||||||
farray = (mp_float_t *)results->array;
|
farray = (mp_float_t *)results->array;
|
||||||
// we can return the 0 array here, if the degrees of freedom is larger than the length of the axis
|
// we can return the 0 array here, if the degrees of freedom are larger than the length of the axis
|
||||||
if((optype == NUMERICAL_STD) && (_shape_strides.shape[0] <= ddof)) {
|
if((optype == NUMERICAL_STD) && (_shape_strides.shape[0] <= ddof)) {
|
||||||
return MP_OBJ_FROM_PTR(results);
|
return MP_OBJ_FROM_PTR(results);
|
||||||
}
|
}
|
||||||
|
|
@ -397,9 +397,6 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
|
||||||
RUN_MEAN_STD(mp_float_t, array, farray, _shape_strides, div, isStd);
|
RUN_MEAN_STD(mp_float_t, array, farray, _shape_strides, div, isStd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(results->ndim == 0) { // return a scalar here
|
|
||||||
return mp_binary_get_val_array(results->dtype, results->array, 0);
|
|
||||||
}
|
|
||||||
return MP_OBJ_FROM_PTR(results);
|
return MP_OBJ_FROM_PTR(results);
|
||||||
}
|
}
|
||||||
return mp_const_none;
|
return mp_const_none;
|
||||||
|
|
@ -560,6 +557,7 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
|
||||||
static const mp_arg_t allowed_args[] = {
|
static const mp_arg_t allowed_args[] = {
|
||||||
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
|
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
|
||||||
{ MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
|
{ MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
|
||||||
|
{ MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } },
|
||||||
};
|
};
|
||||||
|
|
||||||
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
||||||
|
|
@ -567,6 +565,8 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
|
||||||
|
|
||||||
mp_obj_t oin = args[0].u_obj;
|
mp_obj_t oin = args[0].u_obj;
|
||||||
mp_obj_t axis = args[1].u_obj;
|
mp_obj_t axis = args[1].u_obj;
|
||||||
|
mp_obj_t keepdims = args[2].u_obj;
|
||||||
|
|
||||||
if((axis != mp_const_none) && (!mp_obj_is_int(axis))) {
|
if((axis != mp_const_none) && (!mp_obj_is_int(axis))) {
|
||||||
mp_raise_TypeError(MP_ERROR_TEXT("axis must be None, or an integer"));
|
mp_raise_TypeError(MP_ERROR_TEXT("axis must be None, or an integer"));
|
||||||
}
|
}
|
||||||
|
|
@ -578,6 +578,7 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
|
||||||
#endif
|
#endif
|
||||||
if(mp_obj_is_type(oin, &mp_type_tuple) || mp_obj_is_type(oin, &mp_type_list) ||
|
if(mp_obj_is_type(oin, &mp_type_tuple) || mp_obj_is_type(oin, &mp_type_list) ||
|
||||||
mp_obj_is_type(oin, &mp_type_range)) {
|
mp_obj_is_type(oin, &mp_type_range)) {
|
||||||
|
mp_obj_t *result = NULL;
|
||||||
switch(optype) {
|
switch(optype) {
|
||||||
case NUMERICAL_MIN:
|
case NUMERICAL_MIN:
|
||||||
case NUMERICAL_ARGMIN:
|
case NUMERICAL_ARGMIN:
|
||||||
|
|
@ -602,14 +603,14 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
|
||||||
case NUMERICAL_SUM:
|
case NUMERICAL_SUM:
|
||||||
case NUMERICAL_MEAN:
|
case NUMERICAL_MEAN:
|
||||||
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
|
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
|
||||||
return numerical_sum_mean_std_ndarray(ndarray, axis, optype, 0);
|
result = numerical_sum_mean_std_ndarray(ndarray, axis, optype, 0);
|
||||||
default:
|
default:
|
||||||
mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is not implemented on ndarrays"));
|
mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is not implemented on ndarrays"));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
mp_raise_TypeError(MP_ERROR_TEXT("input must be tuple, list, range, or ndarray"));
|
mp_raise_TypeError(MP_ERROR_TEXT("input must be tuple, list, range, or ndarray"));
|
||||||
}
|
}
|
||||||
return mp_const_none;
|
return ulab_tools_restore_dims(result, keepdims);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if ULAB_NUMPY_HAS_SORT | NDARRAY_HAS_SORT
|
#if ULAB_NUMPY_HAS_SORT | NDARRAY_HAS_SORT
|
||||||
|
|
|
||||||
|
|
@ -225,6 +225,30 @@ int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) {
|
||||||
return ax;
|
return ax;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mp_obj_t ulab_tools_restore_dims(mp_obj_t *result, mp_obj_t keepdims, mp_obj_t axis, uint8_t ndim) {
|
||||||
|
// restores the contracted dimension, if keepdims is True
|
||||||
|
ndarray_obj_t *_result = MP_OBJ_TO_PTR(result);
|
||||||
|
if(keepdims == mp_const_true) {
|
||||||
|
_result->ndim += 1;
|
||||||
|
int8_t = tools_get_axis(axis, _result->ndim + 1);
|
||||||
|
|
||||||
|
// shift values from the right to the left in the strides and shape arrays
|
||||||
|
for(uint8_t i = ULAB_MAX_DIMS - _result->ndim + ax - 1; i > 0; i--) {
|
||||||
|
_result->shape[i - 1] = _result->shape[i];
|
||||||
|
_result->strides[i - 1] = _result->strides[i];
|
||||||
|
}
|
||||||
|
_result->shape[ULAB_MAX_DIMS - _result->ndim + ax] = 1;
|
||||||
|
_result->strides[ULAB_MAX_DIMS - _result->ndim + ax] = _result->strides[ULAB_MAX_DIMS - _result->ndim + ax + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if(keepdims == mp_const_false) {
|
||||||
|
if(results->ndim == 0) { // return a scalar here
|
||||||
|
return mp_binary_get_val_array(results->dtype, results->array, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
#if ULAB_MAX_DIMS > 1
|
#if ULAB_MAX_DIMS > 1
|
||||||
ndarray_obj_t *tools_object_is_square(mp_obj_t obj) {
|
ndarray_obj_t *tools_object_is_square(mp_obj_t obj) {
|
||||||
// Returns an ndarray, if the object is a square ndarray,
|
// Returns an ndarray, if the object is a square ndarray,
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ void *ndarray_set_float_function(uint8_t );
|
||||||
|
|
||||||
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
|
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
|
||||||
int8_t tools_get_axis(mp_obj_t , uint8_t );
|
int8_t tools_get_axis(mp_obj_t , uint8_t );
|
||||||
|
mp_obj_t ulab_tools_restore_dims(mp_obj_t *, mp_obj_t , mp_obj_t , uint8_t );
|
||||||
ndarray_obj_t *tools_object_is_square(mp_obj_t );
|
ndarray_obj_t *tools_object_is_square(mp_obj_t );
|
||||||
|
|
||||||
uint8_t ulab_binary_get_size(uint8_t );
|
uint8_t ulab_binary_get_size(uint8_t );
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue