median implemented for linear arrays

This commit is contained in:
Zoltán Vörös 2020-11-03 19:07:11 +01:00
parent adfa60fbc9
commit c84ea225bd
3 changed files with 24 additions and 7 deletions

View file

@ -802,9 +802,26 @@ mp_obj_t numerical_median(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
if(!MP_OBJ_IS_TYPE(args[0].u_obj, &ulab_ndarray_type)) {
mp_raise_TypeError(translate("median argument must be an ndarray"));
}
ndarray_obj_t *ndarray = numerical_sort_helper(args[0].u_obj, args[1].u_obj, 0); ndarray_obj_t *ndarray = numerical_sort_helper(args[0].u_obj, args[1].u_obj, 0);
return MP_OBJ_FROM_PTR(ndarray);
if(args[1].u_obj == mp_const_none) {
// at this point, the array holding the sorted values should be flat
uint8_t *array = (uint8_t *)ndarray->array;
size_t len = ndarray->len;
array += (len >> 1) * ndarray->itemsize;
mp_float_t median = ndarray_get_float_value(array, ndarray->dtype);
if(!(len & 0x01)) { // len is an even number
array += ndarray->itemsize;
median += ndarray_get_float_value(array, ndarray->dtype);
median *= 0.5;
}
return mp_obj_new_float(median);
}
return mp_const_none;
} }
MP_DEFINE_CONST_FUN_OBJ_KW(numerical_median_obj, 1, numerical_median); MP_DEFINE_CONST_FUN_OBJ_KW(numerical_median_obj, 1, numerical_median);

View file

@ -115,7 +115,7 @@ extern mp_obj_module_t ulab_numerical_module;
}\ }\
}) })
#define HEAPSORT1(ndarray, type, array, shape, index, increment, N)\ #define HEAPSORT1(type, array, increment, N)\
({\ ({\
type *_array = (type *)array;\ type *_array = (type *)array;\
type tmp;\ type tmp;\
@ -208,7 +208,7 @@ extern mp_obj_module_t ulab_numerical_module;
} while(0) } while(0)
#define HEAPSORT(ndarray, type, array, shape, strides, index, increment, N) do {\ #define HEAPSORT(ndarray, type, array, shape, strides, index, increment, N) do {\
HEAPSORT1((ndarray), type, (array), (shape), (index), (increment), (N));\ HEAPSORT1(type, (array), (increment), (N));\
} while(0) } while(0)
#define HEAP_ARGSORT(ndarray, type, array, shape, strides, index, increment, N, iarray, istrides, iincrement) do {\ #define HEAP_ARGSORT(ndarray, type, array, shape, strides, index, increment, N, iarray, istrides, iincrement) do {\
@ -271,7 +271,7 @@ extern mp_obj_module_t ulab_numerical_module;
#define HEAPSORT(ndarray, type, array, shape, strides, index, increment, N) do {\ #define HEAPSORT(ndarray, type, array, shape, strides, index, increment, N) do {\
size_t l = 0;\ size_t l = 0;\
do {\ do {\
HEAPSORT1((ndarray), type, (array), (shape), (index), (increment), (N));\ HEAPSORT1(type, (array), (increment), (N));\
(array) += (strides)[ULAB_MAX_DIMS - 1];\ (array) += (strides)[ULAB_MAX_DIMS - 1];\
l++;\ l++;\
} while(l < (shape)[ULAB_MAX_DIMS - 1]);\ } while(l < (shape)[ULAB_MAX_DIMS - 1]);\
@ -375,7 +375,7 @@ extern mp_obj_module_t ulab_numerical_module;
do {\ do {\
size_t l = 0;\ size_t l = 0;\
do {\ do {\
HEAPSORT1((ndarray), type, (array), (shape), (index), (increment), (N));\ HEAPSORT1(type, (array), (increment), (N));\
(array) += (strides)[ULAB_MAX_DIMS - 1];\ (array) += (strides)[ULAB_MAX_DIMS - 1];\
l++;\ l++;\
} while(l < (shape)[ULAB_MAX_DIMS - 1]);\ } while(l < (shape)[ULAB_MAX_DIMS - 1]);\
@ -523,7 +523,7 @@ extern mp_obj_module_t ulab_numerical_module;
do {\ do {\
size_t l = 0;\ size_t l = 0;\
do {\ do {\
HEAPSORT1((ndarray), type, (array), (shape), (index), (increment), (N));\ HEAPSORT1(type, (array), (increment), (N));\
(array) += (strides)[ULAB_MAX_DIMS - 1];\ (array) += (strides)[ULAB_MAX_DIMS - 1];\
l++;\ l++;\
} while(l < (shape)[ULAB_MAX_DIMS - 1]);\ } while(l < (shape)[ULAB_MAX_DIMS - 1]);\

View file

@ -32,7 +32,7 @@
#include "user/user.h" #include "user/user.h"
#include "vector/vectorise.h" #include "vector/vectorise.h"
#define ULAB_VERSION 1.1.3 #define ULAB_VERSION 1.2.0
#define xstr(s) str(s) #define xstr(s) str(s)
#define str(s) #s #define str(s) #s
#if ULAB_NUMPY_COMPATIBILITY #if ULAB_NUMPY_COMPATIBILITY