add keyword handling
This commit is contained in:
parent
03c8655b06
commit
ec7caa8c27
1 changed files with 75 additions and 20 deletions
|
|
@ -42,7 +42,7 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
|
|||
#if ULAB_MAX_DIMS > 1
|
||||
// no need to check anything, if the maximum number of dimensions is 1
|
||||
if(input->ndim != 1) {
|
||||
mp_raise_ValueError(MP_ERROR_TEXT("object too deep for desired arrayy"));
|
||||
mp_raise_ValueError(MP_ERROR_TEXT("object too deep for desired array"));
|
||||
}
|
||||
#endif
|
||||
if((input->dtype != NDARRAY_UINT8) && (input->dtype != NDARRAY_UINT16)) {
|
||||
|
|
@ -50,13 +50,13 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
|
|||
}
|
||||
|
||||
// first find the maximum of the array, and figure out how long the result should be
|
||||
uint16_t max = 0;
|
||||
size_t length = 0;
|
||||
int32_t stride = input->strides[ULAB_MAX_DIMS - 1];
|
||||
if(input->dtype == NDARRAY_UINT8) {
|
||||
uint8_t *iarray = (uint8_t *)input->array;
|
||||
for(size_t i = 0; i < input->len; i++) {
|
||||
if(*iarray > max) {
|
||||
max = *iarray;
|
||||
if(*iarray > length) {
|
||||
length = *iarray;
|
||||
}
|
||||
iarray += stride;
|
||||
}
|
||||
|
|
@ -64,17 +64,40 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
|
|||
stride /= 2;
|
||||
uint16_t *iarray = (uint16_t *)input->array;
|
||||
for(size_t i = 0; i < input->len; i++) {
|
||||
if(*iarray > max) {
|
||||
max = *iarray;
|
||||
if(*iarray > length) {
|
||||
length = *iarray;
|
||||
}
|
||||
iarray += stride;
|
||||
}
|
||||
}
|
||||
ndarray_obj_t *result = ndarray_new_linear_array(max + 1, NDARRAY_UINT16);
|
||||
length += 1;
|
||||
|
||||
if(args[2].u_obj != mp_const_none) {
|
||||
int32_t minlength = mp_obj_get_int(args[2].u_obj);
|
||||
if(minlength < 0) {
|
||||
mp_raise_ValueError(MP_ERROR_TEXT("minlength must not be negative"));
|
||||
}
|
||||
if((size_t)minlength > length) {
|
||||
length = minlength;
|
||||
}
|
||||
}
|
||||
|
||||
ndarray_obj_t *result = NULL;
|
||||
ndarray_obj_t *weights = NULL;
|
||||
|
||||
if(args[1].u_obj == mp_const_none) {
|
||||
result = ndarray_new_linear_array(length, NDARRAY_UINT16);
|
||||
} else {
|
||||
if(!mp_obj_is_type(args[1].u_obj, &ulab_ndarray_type)) {
|
||||
mp_raise_TypeError(MP_ERROR_TEXT("input must be an ndarray"));
|
||||
}
|
||||
weights = MP_OBJ_TO_PTR(args[1].u_obj);
|
||||
result = ndarray_new_linear_array(length, NDARRAY_FLOAT);
|
||||
}
|
||||
|
||||
// now we can do the binning
|
||||
if(result->dtype == NDARRAY_UINT16) {
|
||||
uint16_t *rarray = (uint16_t *)result->array;
|
||||
|
||||
if(input->dtype == NDARRAY_UINT8) {
|
||||
uint8_t *iarray = (uint8_t *)input->array;
|
||||
for(size_t i = 0; i < input->len; i++) {
|
||||
|
|
@ -88,7 +111,39 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
|
|||
iarray += stride;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mp_float_t *rarray = (mp_float_t *)result->array;
|
||||
if(input->dtype == NDARRAY_UINT8) {
|
||||
uint8_t *iarray = (uint8_t *)input->array;
|
||||
for(size_t i = 0; i < input->len; i++) {
|
||||
rarray[*iarray] += MICROPY_FLOAT_CONST(1.0);
|
||||
iarray += stride;
|
||||
}
|
||||
} else if(input->dtype == NDARRAY_UINT16) {
|
||||
uint16_t *iarray = (uint16_t *)input->array;
|
||||
for(size_t i = 0; i < input->len; i++) {
|
||||
rarray[*iarray] += MICROPY_FLOAT_CONST(1.0);
|
||||
iarray += stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(weights != NULL) {
|
||||
mp_float_t (*get_weights)(void *) = ndarray_get_float_function(weights->dtype);
|
||||
mp_float_t *rarray = (mp_float_t *)result->array;
|
||||
uint8_t *warray = (uint8_t *)weights->array;
|
||||
|
||||
size_t fill_length = result->len;
|
||||
if(weights->len < result->len) {
|
||||
fill_length = weights->len;
|
||||
}
|
||||
|
||||
for(size_t i = 0; i < fill_length; i++) {
|
||||
*rarray = *rarray * get_weights(warray);
|
||||
rarray++;
|
||||
warray += weights->strides[ULAB_MAX_DIMS - 1];
|
||||
}
|
||||
}
|
||||
return MP_OBJ_FROM_PTR(result);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue