Polyval handles non-array as second argument (#601)

* Factorize polynomial evaluation

* Polyval handles non-array as second argument

---------

Co-authored-by: Zoltán Vörös <zvoros@gmail.com>
This commit is contained in:
HugoNumworks 2023-06-27 21:13:53 +02:00 committed by GitHub
parent 319df10cfe
commit 112d4f82d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 17 deletions

View file

@ -145,9 +145,18 @@ MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(poly_polyfit_obj, 2, 3, poly_polyfit);
#if ULAB_NUMPY_HAS_POLYVAL
static mp_float_t poly_eval(mp_float_t x, mp_float_t *p, uint8_t plen) {
mp_float_t y = p[0];
for(uint8_t j=0; j < plen-1; j++) {
y *= x;
y += p[j+1];
}
return y;
}
mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
if(!ndarray_object_is_array_like(o_p) || !ndarray_object_is_array_like(o_x)) {
mp_raise_TypeError(translate("inputs are not iterable"));
if(!ndarray_object_is_array_like(o_p)) {
mp_raise_TypeError(translate("input is not iterable"));
}
#if ULAB_SUPPORTS_COMPLEX
ndarray_obj_t *input;
@ -171,6 +180,10 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
i++;
}
if(!ndarray_object_is_array_like(o_x)) {
return mp_obj_new_float(poly_eval(mp_obj_get_float(o_x), p, plen));
}
// polynomials are going to be of type float, except, when both
// the coefficients and the independent variable are integers
ndarray_obj_t *ndarray;
@ -198,13 +211,7 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
#endif
size_t l = 0;
do {
mp_float_t y = p[0];
mp_float_t _x = func(sarray);
for(uint8_t m=0; m < plen-1; m++) {
y *= _x;
y += p[m+1];
}
*array++ = y;
*array++ = poly_eval(func(sarray), p, plen);
sarray += source->strides[ULAB_MAX_DIMS - 1];
l++;
} while(l < source->shape[ULAB_MAX_DIMS - 1]);
@ -233,13 +240,7 @@ mp_obj_t poly_polyval(mp_obj_t o_p, mp_obj_t o_x) {
mp_obj_iter_buf_t x_buf;
mp_obj_t x_item, x_iterable = mp_getiter(o_x, &x_buf);
while ((x_item = mp_iternext(x_iterable)) != MP_OBJ_STOP_ITERATION) {
mp_float_t _x = mp_obj_get_float(x_item);
mp_float_t y = p[0];
for(uint8_t j=0; j < plen-1; j++) {
y *= _x;
y += p[j+1];
}
*array++ = y;
*array++ = poly_eval(mp_obj_get_float(x_item), p, plen);
}
}
m_del(mp_float_t, p, plen);

View file

@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"
#define ULAB_VERSION 6.3.2
#define ULAB_VERSION 6.3.3
#define xstr(s) str(s)
#define str(s) #s