moved the reduce_axes helper function to ulab_tools
This commit is contained in:
parent
7c4f4dba48
commit
2c71434eab
5 changed files with 45 additions and 42 deletions
|
|
@ -632,7 +632,7 @@ ndarray_obj_t *ndarray_new_dense_ndarray(uint8_t ndim, size_t *shape, uint8_t dt
|
||||||
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
|
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
|
||||||
strides[ULAB_MAX_DIMS-1] = dtype == NDARRAY_BOOL ? 1 : mp_binary_get_size('@', dtype, NULL);
|
strides[ULAB_MAX_DIMS-1] = dtype == NDARRAY_BOOL ? 1 : mp_binary_get_size('@', dtype, NULL);
|
||||||
for(size_t i=ULAB_MAX_DIMS; i > 1; i--) {
|
for(size_t i=ULAB_MAX_DIMS; i > 1; i--) {
|
||||||
strides[i-2] = strides[i-1] * shape[i-1];
|
strides[i-2] = strides[i-1] * MAX(1, shape[i-1]);
|
||||||
}
|
}
|
||||||
return ndarray_new_ndarray(ndim, shape, strides, dtype);
|
return ndarray_new_ndarray(ndim, shape, strides, dtype);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -63,38 +63,6 @@ static void numerical_reduce_axes(ndarray_obj_t *ndarray, int8_t axis, size_t *s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static shape_strides numerical_reduce_axes_(ndarray_obj_t *ndarray, mp_obj_t axis) {
|
|
||||||
// TODO: replace numerical_reduce_axes with this function, wherever applicable
|
|
||||||
int8_t ax = mp_obj_get_int(axis);
|
|
||||||
if(ax < 0) ax += ndarray->ndim;
|
|
||||||
if((ax < 0) || (ax > ndarray->ndim - 1)) {
|
|
||||||
mp_raise_ValueError(translate("index out of range"));
|
|
||||||
}
|
|
||||||
shape_strides _shape_strides;
|
|
||||||
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
|
|
||||||
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
|
|
||||||
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
|
|
||||||
_shape_strides.shape = shape;
|
|
||||||
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
|
|
||||||
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
|
|
||||||
_shape_strides.strides = strides;
|
|
||||||
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
|
|
||||||
_shape_strides.index = 0;
|
|
||||||
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
|
|
||||||
} else {
|
|
||||||
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
|
|
||||||
if(i > _shape_strides.index) {
|
|
||||||
_shape_strides.shape[i] = ndarray->shape[i];
|
|
||||||
_shape_strides.strides[i] = ndarray->strides[i];
|
|
||||||
} else {
|
|
||||||
_shape_strides.shape[i] = ndarray->shape[i-1];
|
|
||||||
_shape_strides.strides[i] = ndarray->strides[i-1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return _shape_strides;
|
|
||||||
}
|
|
||||||
|
|
||||||
#if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY
|
#if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY
|
||||||
static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
|
static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
|
||||||
bool anytype = optype == NUMERICAL_ALL ? 1 : 0;
|
bool anytype = optype == NUMERICAL_ALL ? 1 : 0;
|
||||||
|
|
@ -148,7 +116,7 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
|
||||||
} while(i < ndarray->shape[ULAB_MAX_DIMS - 4]);
|
} while(i < ndarray->shape[ULAB_MAX_DIMS - 4]);
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
shape_strides _shape_strides = numerical_reduce_axes_(ndarray, axis);
|
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
|
||||||
ndarray_obj_t *results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_BOOL);
|
ndarray_obj_t *results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_BOOL);
|
||||||
uint8_t *rarray = (uint8_t *)results->array;
|
uint8_t *rarray = (uint8_t *)results->array;
|
||||||
if(optype == NUMERICAL_ALL) {
|
if(optype == NUMERICAL_ALL) {
|
||||||
|
|
|
||||||
|
|
@ -17,13 +17,6 @@
|
||||||
|
|
||||||
// TODO: implement cumsum
|
// TODO: implement cumsum
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
uint8_t index;
|
|
||||||
int8_t axis;
|
|
||||||
size_t *shape;
|
|
||||||
int32_t *strides;
|
|
||||||
} shape_strides;
|
|
||||||
|
|
||||||
#define RUN_ARGMIN1(ndarray, type, array, results, rarray, index, op)\
|
#define RUN_ARGMIN1(ndarray, type, array, results, rarray, index, op)\
|
||||||
({\
|
({\
|
||||||
uint16_t best_index = 0;\
|
uint16_t best_index = 0;\
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
#include <string.h>
|
||||||
#include "py/runtime.h"
|
#include "py/runtime.h"
|
||||||
|
|
||||||
#include "ulab.h"
|
#include "ulab.h"
|
||||||
|
|
@ -158,3 +158,35 @@ void *ndarray_set_float_function(uint8_t dtype) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */
|
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */
|
||||||
|
|
||||||
|
shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
|
||||||
|
// TODO: replace numerical_reduce_axes with this function, wherever applicable
|
||||||
|
int8_t ax = mp_obj_get_int(axis);
|
||||||
|
if(ax < 0) ax += ndarray->ndim;
|
||||||
|
if((ax < 0) || (ax > ndarray->ndim - 1)) {
|
||||||
|
mp_raise_ValueError(translate("index out of range"));
|
||||||
|
}
|
||||||
|
shape_strides _shape_strides;
|
||||||
|
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
|
||||||
|
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
|
||||||
|
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
|
||||||
|
_shape_strides.shape = shape;
|
||||||
|
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
|
||||||
|
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
|
||||||
|
_shape_strides.strides = strides;
|
||||||
|
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
|
||||||
|
_shape_strides.index = 0;
|
||||||
|
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
|
||||||
|
} else {
|
||||||
|
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
|
||||||
|
if(i > _shape_strides.index) {
|
||||||
|
_shape_strides.shape[i] = ndarray->shape[i];
|
||||||
|
_shape_strides.strides[i] = ndarray->strides[i];
|
||||||
|
} else {
|
||||||
|
_shape_strides.shape[i] = ndarray->shape[i-1];
|
||||||
|
_shape_strides.strides[i] = ndarray->strides[i-1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return _shape_strides;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,17 @@
|
||||||
#ifndef _TOOLS_
|
#ifndef _TOOLS_
|
||||||
#define _TOOLS_
|
#define _TOOLS_
|
||||||
|
|
||||||
|
#include "ndarray.h"
|
||||||
|
|
||||||
#define SWAP(t, a, b) { t tmp = a; a = b; b = tmp; }
|
#define SWAP(t, a, b) { t tmp = a; a = b; b = tmp; }
|
||||||
|
|
||||||
|
typedef struct _shape_strides_t {
|
||||||
|
uint8_t index;
|
||||||
|
int8_t axis;
|
||||||
|
size_t *shape;
|
||||||
|
int32_t *strides;
|
||||||
|
} shape_strides;
|
||||||
|
|
||||||
mp_float_t ndarray_get_float_uint8(void *);
|
mp_float_t ndarray_get_float_uint8(void *);
|
||||||
mp_float_t ndarray_get_float_int8(void *);
|
mp_float_t ndarray_get_float_int8(void *);
|
||||||
mp_float_t ndarray_get_float_uint16(void *);
|
mp_float_t ndarray_get_float_uint16(void *);
|
||||||
|
|
@ -23,4 +32,5 @@ void *ndarray_get_float_function(uint8_t );
|
||||||
uint8_t ndarray_upcast_dtype(uint8_t , uint8_t );
|
uint8_t ndarray_upcast_dtype(uint8_t , uint8_t );
|
||||||
void *ndarray_set_float_function(uint8_t );
|
void *ndarray_set_float_function(uint8_t );
|
||||||
|
|
||||||
|
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue