diff --git a/code/numpy/compare.c b/code/numpy/compare.c index f72fde8..42f015b 100644 --- a/code/numpy/compare.c +++ b/code/numpy/compare.c @@ -6,7 +6,7 @@ * * The MIT License (MIT) * - * Copyright (c) 2020-2021 Zoltán Vörös + * Copyright (c) 2020-2025 Zoltán Vörös * 2020 Jeff Epler for Adafruit Industries */ @@ -27,14 +27,69 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { static const mp_arg_t allowed_args[] = { { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } , - { MP_QSTR_weights, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, - { MP_QSTR_minlength, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, + { MP_QSTR_weights, MP_ARG_OBJ | MP_ARG_KW_ONLY, { .u_rom_obj = MP_ROM_NONE } }, + { MP_QSTR_minlength, MP_ARG_OBJ | MP_ARG_KW_ONLY, { .u_rom_obj = MP_ROM_NONE } }, }; mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); - return mp_const_none; + if(!mp_obj_is_type(args[0].u_obj, &ulab_ndarray_type)) { + mp_raise_TypeError(MP_ERROR_TEXT("input must be an ndarray")); + } + ndarray_obj_t *input = MP_OBJ_TO_PTR(args[0].u_obj); + + #if ULAB_MAX_DIMS > 1 + // no need to check anything, if the maximum number of dimensions is 1 + if(input->ndim != 1) { + mp_raise_ValueError(MP_ERROR_TEXT("object too deep for desired arrayy")); + } + #endif + if((input->dtype != NDARRAY_UINT8) && (input->dtype != NDARRAY_UINT16)) { + mp_raise_TypeError(MP_ERROR_TEXT("cannot cast array data from dtype")); + } + + // first find the maximum of the array, and figure out how long the result should be + uint16_t max = 0; + int32_t stride = input->strides[ULAB_MAX_DIMS - 1]; + if(input->dtype == NDARRAY_UINT8) { + uint8_t *iarray = (uint8_t *)input->array; + for(size_t i = 0; i < input->len; i++) { + if(*iarray > max) { + max = *iarray; + } + iarray += stride; + } + } else if(input->dtype == NDARRAY_UINT16) { + stride /= 2; + uint16_t *iarray = (uint16_t *)input->array; + for(size_t i = 0; i < input->len; i++) { + if(*iarray > max) { + max = *iarray; + } + iarray += stride; + } + } + ndarray_obj_t *result = ndarray_new_linear_array(max + 1, NDARRAY_UINT16); + + // now we can do the binning + uint16_t *rarray = (uint16_t *)result->array; + + if(input->dtype == NDARRAY_UINT8) { + uint8_t *iarray = (uint8_t *)input->array; + for(size_t i = 0; i < input->len; i++) { + rarray[*iarray] += 1; + iarray += stride; + } + } else if(input->dtype == NDARRAY_UINT16) { + uint16_t *iarray = (uint16_t *)input->array; + for(size_t i = 0; i < input->len; i++) { + rarray[*iarray] += 1; + iarray += stride; + } + } + + return MP_OBJ_FROM_PTR(result); } MP_DEFINE_CONST_FUN_OBJ_KW(compare_bincount_obj, 1, compare_bincount); diff --git a/code/numpy/compare.h b/code/numpy/compare.h index 668903e..4169a64 100644 --- a/code/numpy/compare.h +++ b/code/numpy/compare.h @@ -6,7 +6,7 @@ * * The MIT License (MIT) * - * Copyright (c) 2020-2021 Zoltán Vörös + * Copyright (c) 2020-2025 Zoltán Vörös */ #ifndef _COMPARE_