add normal distribution
This commit is contained in:
parent
8f5e329700
commit
12f4647d41
3 changed files with 90 additions and 0 deletions
|
|
@ -8,6 +8,8 @@
|
|||
* Copyright (c) 2024 Zoltán Vörös
|
||||
*/
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include "py/builtin.h"
|
||||
#include "py/obj.h"
|
||||
#include "py/runtime.h"
|
||||
|
|
@ -19,6 +21,9 @@ ULAB_DEFINE_FLOAT_CONST(random_one, MICROPY_FLOAT_CONST(1.0), 0x3f800000UL, 0x3f
|
|||
|
||||
// methods of the Generator object
|
||||
static const mp_rom_map_elem_t random_generator_locals_dict_table[] = {
|
||||
#if ULAB_NUMPY_RANDOM_HAS_NORMAL
|
||||
{ MP_ROM_QSTR(MP_QSTR_normal), MP_ROM_PTR(&random_normal_obj) },
|
||||
#endif
|
||||
#if ULAB_NUMPY_RANDOM_HAS_RANDOM
|
||||
{ MP_ROM_QSTR(MP_QSTR_random), MP_ROM_PTR(&random_random_obj) },
|
||||
#endif
|
||||
|
|
@ -118,6 +123,86 @@ static inline uint64_t pcg32_next64(uint64_t *state) {
|
|||
}
|
||||
#endif
|
||||
|
||||
#if ULAB_NUMPY_RANDOM_HAS_NORMAL
|
||||
static mp_obj_t random_normal(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||
static const mp_arg_t allowed_args[] = {
|
||||
{ MP_QSTR_, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
|
||||
{ MP_QSTR_loc, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = ULAB_REFERENCE_FLOAT_CONST(random_zero) } },
|
||||
{ MP_QSTR_scale, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = ULAB_REFERENCE_FLOAT_CONST(random_one) } },
|
||||
{ MP_QSTR_size, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
|
||||
};
|
||||
|
||||
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
||||
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
|
||||
|
||||
random_generator_obj_t *self = MP_OBJ_TO_PTR(args[0].u_obj);
|
||||
mp_float_t loc = mp_obj_get_float(args[1].u_obj);
|
||||
mp_float_t scale = mp_obj_get_float(args[2].u_obj);
|
||||
mp_obj_t size = args[3].u_obj;
|
||||
|
||||
ndarray_obj_t *ndarray = NULL;
|
||||
mp_float_t u, v, value;
|
||||
|
||||
if(size != mp_const_none) {
|
||||
if(mp_obj_is_int(size)) {
|
||||
ndarray = ndarray_new_linear_array((size_t)mp_obj_get_int(size), NDARRAY_FLOAT);
|
||||
} else if(mp_obj_is_type(size, &mp_type_tuple)) {
|
||||
mp_obj_tuple_t *_shape = MP_OBJ_TO_PTR(size);
|
||||
if(_shape->len > ULAB_MAX_DIMS) {
|
||||
mp_raise_ValueError(MP_ERROR_TEXT("maximum number of dimensions is " MP_STRINGIFY(ULAB_MAX_DIMS)));
|
||||
}
|
||||
ndarray = ndarray_new_ndarray_from_tuple(_shape, NDARRAY_FLOAT);
|
||||
} else { // input type not supported
|
||||
mp_raise_TypeError(MP_ERROR_TEXT("shape must be None, and integer or a tuple of integers"));
|
||||
}
|
||||
} else {
|
||||
// return single value
|
||||
#if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT
|
||||
uint32_t x = pcg32_next(&self->state);
|
||||
u = (float)(int32_t)(x >> 8) * 0x1.0p-24f;
|
||||
x = pcg32_next(&self->state);
|
||||
v = (float)(int32_t)(x >> 8) * 0x1.0p-24f;
|
||||
#else
|
||||
uint64_t x = pcg32_next64(&self->state);
|
||||
u = (double)(int64_t)(x >> 11) * 0x1.0p-53;
|
||||
x = pcg32_next64(&self->state);
|
||||
v = (double)(int64_t)(x >> 11) * 0x1.0p-53;
|
||||
#endif
|
||||
mp_float_t sqrt_log = MICROPY_FLOAT_C_FUN(sqrt)(-MICROPY_FLOAT_CONST(2.0) * MICROPY_FLOAT_C_FUN(log)(u));
|
||||
value = sqrt_log * MICROPY_FLOAT_C_FUN(cos)(MICROPY_FLOAT_CONST(2.0) * MP_PI * v);
|
||||
return mp_obj_new_float(loc + scale * value);
|
||||
}
|
||||
|
||||
mp_float_t *array = (mp_float_t *)ndarray->array;
|
||||
|
||||
// numpy's random supports only dense output arrays, so we can simply
|
||||
// loop through the elements in a linear fashion
|
||||
for(size_t i = 0; i < ndarray->len; i = i + 2) {
|
||||
#if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT
|
||||
uint32_t x = pcg32_next(&self->state);
|
||||
u = (float)(int32_t)(x >> 8) * 0x1.0p-24f;
|
||||
x = pcg32_next(&self->state);
|
||||
v = (float)(int32_t)(x >> 8) * 0x1.0p-24f;
|
||||
#else
|
||||
uint64_t x = pcg32_next64(&self->state);
|
||||
u = (double)(int64_t)(x >> 11) * 0x1.0p-53;
|
||||
x = pcg32_next64(&self->state);
|
||||
v = (double)(int64_t)(x >> 11) * 0x1.0p-53;
|
||||
#endif
|
||||
mp_float_t sqrt_log = MICROPY_FLOAT_C_FUN(sqrt)(-MICROPY_FLOAT_CONST(2.0) * MICROPY_FLOAT_C_FUN(log)(u));
|
||||
value = sqrt_log * MICROPY_FLOAT_C_FUN(cos)(MICROPY_FLOAT_CONST(2.0) * MP_PI * v);
|
||||
*array++ = loc + scale * value;
|
||||
if((i & 1) == 0) {
|
||||
value = sqrt_log * MICROPY_FLOAT_C_FUN(sin)(MICROPY_FLOAT_CONST(2.0) * MP_PI * v);
|
||||
*array++ = loc + scale * value;
|
||||
}
|
||||
}
|
||||
return MP_OBJ_FROM_PTR(ndarray);
|
||||
}
|
||||
|
||||
MP_DEFINE_CONST_FUN_OBJ_KW(random_normal_obj, 1, random_normal);
|
||||
#endif /* ULAB_NUMPY_RANDOM_HAS_NORMAL */
|
||||
|
||||
#if ULAB_NUMPY_RANDOM_HAS_RANDOM
|
||||
static mp_obj_t random_random(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||
static const mp_arg_t allowed_args[] = {
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ mp_obj_t random_generator_make_new(const mp_obj_type_t *, size_t , size_t , cons
|
|||
void random_generator_print(const mp_print_t *, mp_obj_t , mp_print_kind_t );
|
||||
|
||||
|
||||
MP_DECLARE_CONST_FUN_OBJ_KW(random_normal_obj);
|
||||
MP_DECLARE_CONST_FUN_OBJ_KW(random_random_obj);
|
||||
MP_DECLARE_CONST_FUN_OBJ_KW(random_uniform_obj);
|
||||
|
||||
|
|
|
|||
|
|
@ -702,6 +702,10 @@
|
|||
#define ULAB_NUMPY_HAS_RANDOM_MODULE (1)
|
||||
#endif
|
||||
|
||||
#ifndef ULAB_NUMPY_RANDOM_HAS_NORMAL
|
||||
#define ULAB_NUMPY_RANDOM_HAS_NORMAL (1)
|
||||
#endif
|
||||
|
||||
#ifndef ULAB_NUMPY_RANDOM_HAS_RANDOM
|
||||
#define ULAB_NUMPY_RANDOM_HAS_RANDOM (1)
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue