fix the keyword handling in sqrt

This commit is contained in:
Zoltán Vörös 2023-05-27 21:47:55 +02:00
parent 8f0fca769a
commit 4524d6871f
2 changed files with 26 additions and 23 deletions

View file

@ -37,13 +37,21 @@
static mp_obj_t vector_generic_vector(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args, mp_float_t (*f)(mp_float_t)) {
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
{ MP_QSTR_out, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_out, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
// this keyword argument is not used; it's only here, so that functions that
// support the complex dtype can call vector_generic_vector directly
{ MP_QSTR_dtype, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
};
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
mp_obj_t o_in = args[0].u_obj;
// Return a single value, if o_in is not iterable
if(mp_obj_is_float(o_in) || mp_obj_is_int(o_in)) {
return mp_obj_new_float(f(mp_obj_get_float(o_in)));
}
mp_obj_t out = args[1].u_obj;
ndarray_obj_t *target = NULL;
@ -252,7 +260,7 @@ mp_obj_t vector_around(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
{ MP_QSTR_decimals, MP_ARG_KW_ONLY | MP_ARG_INT, {.u_int = 0 } },
#if ULAB_MATH_FUNCTIONS_OUT_KEYWORD
{ MP_QSTR_out, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }
{ MP_QSTR_out, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }
#endif
};
@ -317,11 +325,7 @@ mp_obj_t vector_around(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
return MP_OBJ_FROM_PTR(ndarray);
}
#if ULAB_MATH_FUNCTIONS_OUT_KEYWORD
MP_DEFINE_CONST_FUN_OBJ_KW(vector_around_obj, 2, vector_around);
#else
MP_DEFINE_CONST_FUN_OBJ_KW(vector_around_obj, 1, vector_around);
#endif /* ULAB_MATH_FUNCTIONS_OUT_KEYWORD */
#endif /* ULAB_NUMPY_HAS_AROUND */
#if ULAB_NUMPY_HAS_ATAN
@ -556,7 +560,7 @@ static mp_obj_t vector_exp(mp_obj_t o_in) {
// since the complex case is dissimilar to the rest, we've got to do the parsing of the keywords here
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
{ MP_QSTR_out, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_out, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
};
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
@ -822,13 +826,15 @@ MP_DEFINE_CONST_FUN_OBJ_1(vector_sinh_obj, vector_sinh);
//| ...
//|
#if ULAB_SUPPORTS_COMPLEX
#if ULAB_SUPPORTS_COMPLEX | ULAB_MATH_FUNCTIONS_OUT_KEYWORD
mp_obj_t vector_sqrt(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_dtype, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_INT(NDARRAY_FLOAT) } },
#if ULAB_MATH_FUNCTIONS_OUT_KEYWORD
{ MP_QSTR_out, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_out, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
#endif
#if ULAB_SUPPORTS_COMPLEX
{ MP_QSTR_dtype, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_INT(NDARRAY_FLOAT) } },
#endif
};
@ -836,11 +842,13 @@ mp_obj_t vector_sqrt(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
mp_obj_t o_in = args[0].u_obj;
uint8_t dtype = mp_obj_get_int(args[1].u_obj);
#if ULAB_SUPPORTS_COMPLEX
uint8_t dtype = mp_obj_get_int(args[2].u_obj);
if((dtype != NDARRAY_FLOAT) && (dtype != NDARRAY_COMPLEX)) {
mp_raise_TypeError(translate("dtype must be float, or complex"));
}
if(mp_obj_is_type(o_in, &mp_type_complex)) {
mp_float_t real, imag;
mp_obj_get_complex(o_in, &real, &imag);
@ -856,7 +864,7 @@ mp_obj_t vector_sqrt(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
if(dtype == NDARRAY_COMPLEX) {
#if ULAB_MATH_FUNCTIONS_OUT_KEYWORD
mp_obj_t out = args[2].u_obj;
mp_obj_t out = args[1].u_obj;
if(out != mp_const_none) {
mp_raise_ValueError(translate("out keyword is not supported for complex dtype"));
}
@ -964,21 +972,15 @@ mp_obj_t vector_sqrt(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
}
}
}
#if ULAB_MATH_FUNCTIONS_OUT_KEYWORD
#endif /* ULAB_SUPPORTS_COMPLEX */
return vector_generic_vector(n_args, pos_args, kw_args, MICROPY_FLOAT_C_FUN(sqrt));
#else
return vector_generic_vector(o_in, MICROPY_FLOAT_C_FUN(sqrt));
#endif /* ULAB_MATH_FUNCTIONS_OUT_KEYWORD */
}
#if ULAB_MATH_FUNCTIONS_OUT_KEYWORD
MP_DEFINE_CONST_FUN_OBJ_KW(vector_sqrt_obj, 2, vector_sqrt);
#else
MP_DEFINE_CONST_FUN_OBJ_KW(vector_sqrt_obj, 1, vector_sqrt);
#endif /* ULAB_MATH_FUNCTIONS_OUT_KEYWORD */
#else
MATH_FUN_1(sqrt, sqrt);
MP_DEFINE_CONST_FUN_OBJ_1(vector_sqrt_obj, vector_sqrt);
#endif /* ULAB_SUPPORTS_COMPLEX */
#endif /* ULAB_MATH_FUNCTIONS_OUT_KEYWORD | ULAB_SUPPORTS_COMPLEX */
#endif /* ULAB_NUMPY_HAS_SQRT */
#if ULAB_NUMPY_HAS_TAN

View file

@ -75,11 +75,12 @@ MP_DECLARE_CONST_FUN_OBJ_1(vector_tanh_obj);
MP_DECLARE_CONST_FUN_OBJ_2(vector_arctan2_obj);
MP_DECLARE_CONST_FUN_OBJ_KW(vector_around_obj);
#if ULAB_SUPPORTS_COMPLEX
#if ULAB_SUPPORTS_COMPLEX | ULAB_MATH_FUNCTIONS_OUT_KEYWORD
MP_DECLARE_CONST_FUN_OBJ_KW(vector_sqrt_obj);
#else
MP_DECLARE_CONST_FUN_OBJ_1(vector_sqrt_obj);
#endif
MP_DECLARE_CONST_FUN_OBJ_KW(vector_vectorize_obj);
typedef struct _vectorized_function_obj_t {