moved the reduce_axes helper function to ulab_tools

This commit is contained in:
Zoltán Vörös 2021-02-09 07:00:47 +01:00
parent 7c4f4dba48
commit 2c71434eab
5 changed files with 45 additions and 42 deletions

View file

@ -632,7 +632,7 @@ ndarray_obj_t *ndarray_new_dense_ndarray(uint8_t ndim, size_t *shape, uint8_t dt
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS); int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
strides[ULAB_MAX_DIMS-1] = dtype == NDARRAY_BOOL ? 1 : mp_binary_get_size('@', dtype, NULL); strides[ULAB_MAX_DIMS-1] = dtype == NDARRAY_BOOL ? 1 : mp_binary_get_size('@', dtype, NULL);
for(size_t i=ULAB_MAX_DIMS; i > 1; i--) { for(size_t i=ULAB_MAX_DIMS; i > 1; i--) {
strides[i-2] = strides[i-1] * shape[i-1]; strides[i-2] = strides[i-1] * MAX(1, shape[i-1]);
} }
return ndarray_new_ndarray(ndim, shape, strides, dtype); return ndarray_new_ndarray(ndim, shape, strides, dtype);
} }

View file

@ -63,38 +63,6 @@ static void numerical_reduce_axes(ndarray_obj_t *ndarray, int8_t axis, size_t *s
} }
} }
static shape_strides numerical_reduce_axes_(ndarray_obj_t *ndarray, mp_obj_t axis) {
// TODO: replace numerical_reduce_axes with this function, wherever applicable
int8_t ax = mp_obj_get_int(axis);
if(ax < 0) ax += ndarray->ndim;
if((ax < 0) || (ax > ndarray->ndim - 1)) {
mp_raise_ValueError(translate("index out of range"));
}
shape_strides _shape_strides;
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
_shape_strides.shape = shape;
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
_shape_strides.strides = strides;
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
_shape_strides.index = 0;
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
} else {
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
if(i > _shape_strides.index) {
_shape_strides.shape[i] = ndarray->shape[i];
_shape_strides.strides[i] = ndarray->strides[i];
} else {
_shape_strides.shape[i] = ndarray->shape[i-1];
_shape_strides.strides[i] = ndarray->strides[i-1];
}
}
}
return _shape_strides;
}
#if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY #if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY
static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) { static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
bool anytype = optype == NUMERICAL_ALL ? 1 : 0; bool anytype = optype == NUMERICAL_ALL ? 1 : 0;
@ -148,7 +116,7 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
} while(i < ndarray->shape[ULAB_MAX_DIMS - 4]); } while(i < ndarray->shape[ULAB_MAX_DIMS - 4]);
#endif #endif
} else { } else {
shape_strides _shape_strides = numerical_reduce_axes_(ndarray, axis); shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
ndarray_obj_t *results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_BOOL); ndarray_obj_t *results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_BOOL);
uint8_t *rarray = (uint8_t *)results->array; uint8_t *rarray = (uint8_t *)results->array;
if(optype == NUMERICAL_ALL) { if(optype == NUMERICAL_ALL) {

View file

@ -17,13 +17,6 @@
// TODO: implement cumsum // TODO: implement cumsum
typedef struct {
uint8_t index;
int8_t axis;
size_t *shape;
int32_t *strides;
} shape_strides;
#define RUN_ARGMIN1(ndarray, type, array, results, rarray, index, op)\ #define RUN_ARGMIN1(ndarray, type, array, results, rarray, index, op)\
({\ ({\
uint16_t best_index = 0;\ uint16_t best_index = 0;\

View file

@ -9,7 +9,7 @@
*/ */
#include <string.h>
#include "py/runtime.h" #include "py/runtime.h"
#include "ulab.h" #include "ulab.h"
@ -158,3 +158,35 @@ void *ndarray_set_float_function(uint8_t dtype) {
} }
} }
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */ #endif /* NDARRAY_BINARY_USES_FUN_POINTER */
shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
// TODO: replace numerical_reduce_axes with this function, wherever applicable
int8_t ax = mp_obj_get_int(axis);
if(ax < 0) ax += ndarray->ndim;
if((ax < 0) || (ax > ndarray->ndim - 1)) {
mp_raise_ValueError(translate("index out of range"));
}
shape_strides _shape_strides;
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
_shape_strides.shape = shape;
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
_shape_strides.strides = strides;
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
_shape_strides.index = 0;
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
} else {
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
if(i > _shape_strides.index) {
_shape_strides.shape[i] = ndarray->shape[i];
_shape_strides.strides[i] = ndarray->strides[i];
} else {
_shape_strides.shape[i] = ndarray->shape[i-1];
_shape_strides.strides[i] = ndarray->strides[i-1];
}
}
}
return _shape_strides;
}

View file

@ -11,8 +11,17 @@
#ifndef _TOOLS_ #ifndef _TOOLS_
#define _TOOLS_ #define _TOOLS_
#include "ndarray.h"
#define SWAP(t, a, b) { t tmp = a; a = b; b = tmp; } #define SWAP(t, a, b) { t tmp = a; a = b; b = tmp; }
typedef struct _shape_strides_t {
uint8_t index;
int8_t axis;
size_t *shape;
int32_t *strides;
} shape_strides;
mp_float_t ndarray_get_float_uint8(void *); mp_float_t ndarray_get_float_uint8(void *);
mp_float_t ndarray_get_float_int8(void *); mp_float_t ndarray_get_float_int8(void *);
mp_float_t ndarray_get_float_uint16(void *); mp_float_t ndarray_get_float_uint16(void *);
@ -23,4 +32,5 @@ void *ndarray_get_float_function(uint8_t );
uint8_t ndarray_upcast_dtype(uint8_t , uint8_t ); uint8_t ndarray_upcast_dtype(uint8_t , uint8_t );
void *ndarray_set_float_function(uint8_t ); void *ndarray_set_float_function(uint8_t );
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
#endif #endif