moved the reduce_axes helper function to ulab_tools
This commit is contained in:
parent
7c4f4dba48
commit
2c71434eab
5 changed files with 45 additions and 42 deletions
|
|
@ -632,7 +632,7 @@ ndarray_obj_t *ndarray_new_dense_ndarray(uint8_t ndim, size_t *shape, uint8_t dt
|
|||
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
|
||||
strides[ULAB_MAX_DIMS-1] = dtype == NDARRAY_BOOL ? 1 : mp_binary_get_size('@', dtype, NULL);
|
||||
for(size_t i=ULAB_MAX_DIMS; i > 1; i--) {
|
||||
strides[i-2] = strides[i-1] * shape[i-1];
|
||||
strides[i-2] = strides[i-1] * MAX(1, shape[i-1]);
|
||||
}
|
||||
return ndarray_new_ndarray(ndim, shape, strides, dtype);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,38 +63,6 @@ static void numerical_reduce_axes(ndarray_obj_t *ndarray, int8_t axis, size_t *s
|
|||
}
|
||||
}
|
||||
|
||||
static shape_strides numerical_reduce_axes_(ndarray_obj_t *ndarray, mp_obj_t axis) {
|
||||
// TODO: replace numerical_reduce_axes with this function, wherever applicable
|
||||
int8_t ax = mp_obj_get_int(axis);
|
||||
if(ax < 0) ax += ndarray->ndim;
|
||||
if((ax < 0) || (ax > ndarray->ndim - 1)) {
|
||||
mp_raise_ValueError(translate("index out of range"));
|
||||
}
|
||||
shape_strides _shape_strides;
|
||||
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
|
||||
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
|
||||
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
|
||||
_shape_strides.shape = shape;
|
||||
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
|
||||
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
|
||||
_shape_strides.strides = strides;
|
||||
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
|
||||
_shape_strides.index = 0;
|
||||
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
|
||||
} else {
|
||||
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
|
||||
if(i > _shape_strides.index) {
|
||||
_shape_strides.shape[i] = ndarray->shape[i];
|
||||
_shape_strides.strides[i] = ndarray->strides[i];
|
||||
} else {
|
||||
_shape_strides.shape[i] = ndarray->shape[i-1];
|
||||
_shape_strides.strides[i] = ndarray->strides[i-1];
|
||||
}
|
||||
}
|
||||
}
|
||||
return _shape_strides;
|
||||
}
|
||||
|
||||
#if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY
|
||||
static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
|
||||
bool anytype = optype == NUMERICAL_ALL ? 1 : 0;
|
||||
|
|
@ -148,7 +116,7 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
|
|||
} while(i < ndarray->shape[ULAB_MAX_DIMS - 4]);
|
||||
#endif
|
||||
} else {
|
||||
shape_strides _shape_strides = numerical_reduce_axes_(ndarray, axis);
|
||||
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
|
||||
ndarray_obj_t *results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_BOOL);
|
||||
uint8_t *rarray = (uint8_t *)results->array;
|
||||
if(optype == NUMERICAL_ALL) {
|
||||
|
|
|
|||
|
|
@ -17,13 +17,6 @@
|
|||
|
||||
// TODO: implement cumsum
|
||||
|
||||
typedef struct {
|
||||
uint8_t index;
|
||||
int8_t axis;
|
||||
size_t *shape;
|
||||
int32_t *strides;
|
||||
} shape_strides;
|
||||
|
||||
#define RUN_ARGMIN1(ndarray, type, array, results, rarray, index, op)\
|
||||
({\
|
||||
uint16_t best_index = 0;\
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
*/
|
||||
|
||||
|
||||
|
||||
#include <string.h>
|
||||
#include "py/runtime.h"
|
||||
|
||||
#include "ulab.h"
|
||||
|
|
@ -158,3 +158,35 @@ void *ndarray_set_float_function(uint8_t dtype) {
|
|||
}
|
||||
}
|
||||
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */
|
||||
|
||||
shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
|
||||
// TODO: replace numerical_reduce_axes with this function, wherever applicable
|
||||
int8_t ax = mp_obj_get_int(axis);
|
||||
if(ax < 0) ax += ndarray->ndim;
|
||||
if((ax < 0) || (ax > ndarray->ndim - 1)) {
|
||||
mp_raise_ValueError(translate("index out of range"));
|
||||
}
|
||||
shape_strides _shape_strides;
|
||||
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
|
||||
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
|
||||
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
|
||||
_shape_strides.shape = shape;
|
||||
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
|
||||
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
|
||||
_shape_strides.strides = strides;
|
||||
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
|
||||
_shape_strides.index = 0;
|
||||
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
|
||||
} else {
|
||||
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
|
||||
if(i > _shape_strides.index) {
|
||||
_shape_strides.shape[i] = ndarray->shape[i];
|
||||
_shape_strides.strides[i] = ndarray->strides[i];
|
||||
} else {
|
||||
_shape_strides.shape[i] = ndarray->shape[i-1];
|
||||
_shape_strides.strides[i] = ndarray->strides[i-1];
|
||||
}
|
||||
}
|
||||
}
|
||||
return _shape_strides;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,8 +11,17 @@
|
|||
#ifndef _TOOLS_
|
||||
#define _TOOLS_
|
||||
|
||||
#include "ndarray.h"
|
||||
|
||||
#define SWAP(t, a, b) { t tmp = a; a = b; b = tmp; }
|
||||
|
||||
typedef struct _shape_strides_t {
|
||||
uint8_t index;
|
||||
int8_t axis;
|
||||
size_t *shape;
|
||||
int32_t *strides;
|
||||
} shape_strides;
|
||||
|
||||
mp_float_t ndarray_get_float_uint8(void *);
|
||||
mp_float_t ndarray_get_float_int8(void *);
|
||||
mp_float_t ndarray_get_float_uint16(void *);
|
||||
|
|
@ -23,4 +32,5 @@ void *ndarray_get_float_function(uint8_t );
|
|||
uint8_t ndarray_upcast_dtype(uint8_t , uint8_t );
|
||||
void *ndarray_set_float_function(uint8_t );
|
||||
|
||||
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue