Merge pull request #491 from v923z/sosfilt-fix

fix scipy.signal.sosfilt
This commit is contained in:
Zoltán Vörös 2022-01-21 17:29:40 +01:00 committed by GitHub
commit 6dfdc44c4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 28 additions and 5 deletions

View file

@ -53,7 +53,7 @@ MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(signal_spectrogram_obj, 1, 2, signal_spectro
#endif /* ULAB_SCIPY_SIGNAL_HAS_SPECTROGRAM */
#if ULAB_SCIPY_SIGNAL_HAS_SOSFILT
#if ULAB_SCIPY_SIGNAL_HAS_SOSFILT & ULAB_MAX_DIMS > 1
static void signal_sosfilt_array(mp_float_t *x, const mp_float_t *coeffs, mp_float_t *zf, const size_t len) {
for(size_t i=0; i < len; i++) {
mp_float_t xn = *x;
@ -118,7 +118,7 @@ mp_obj_t signal_sosfilt(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_ar
mp_raise_TypeError(translate("zi must be an ndarray"));
} else {
ndarray_obj_t *zi = MP_OBJ_TO_PTR(args[2].u_obj);
if((zi->shape[ULAB_MAX_DIMS - 1] != lensos) || (zi->shape[ULAB_MAX_DIMS - 1] != 2)) {
if((zi->shape[ULAB_MAX_DIMS - 2] != lensos) || (zi->shape[ULAB_MAX_DIMS - 1] != 2)) {
mp_raise_ValueError(translate("zi must be of shape (n_section, 2)"));
}
if(zi->dtype != NDARRAY_FLOAT) {
@ -158,7 +158,7 @@ static const mp_rom_map_elem_t ulab_scipy_signal_globals_table[] = {
#if ULAB_SCIPY_SIGNAL_HAS_SPECTROGRAM
{ MP_OBJ_NEW_QSTR(MP_QSTR_spectrogram), (mp_obj_t)&signal_spectrogram_obj },
#endif
#if ULAB_SCIPY_SIGNAL_HAS_SOSFILT
#if ULAB_SCIPY_SIGNAL_HAS_SOSFILT & ULAB_MAX_DIMS > 1
{ MP_OBJ_NEW_QSTR(MP_QSTR_sosfilt), (mp_obj_t)&signal_sosfilt_obj },
#endif
};

View file

@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"
#define ULAB_VERSION 4.3.0
#define ULAB_VERSION 4.3.1
#define xstr(s) str(s)
#define str(s) #s

View file

@ -1,6 +1,12 @@
Wed, 19 Jan 2022
version 4.3.0
version 4.3.1
fix signal.sosfilt
Wed, 19 Jan 2022
version 4.3.0
implement numpy.save, numpy.load

13
tests/2d/scipy/sosfilt.py Normal file
View file

@ -0,0 +1,13 @@
try:
from ulab import numpy as np
from ulab import scipy as spy
except:
import numpy as np
import scipy as spy
x = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
sos = [[1, 2, 3, 1, 5, 6], [1, 2, 3, 1, 5, 6], [1, 2, 3, 1, 5, 6]]
zi = np.array([[1, 2], [3, 4], [5, 6]],dtype=np.float)
y, zo = spy.signal.sosfilt(sos, x, zi=zi)
print('y: ', y)
print('zo: ', zo)

View file

@ -0,0 +1,4 @@
y: array([9.0, -47.0, 224.0, -987.0000000000001, 4129.0, -16549.0, 64149.0, -241937.0, 892121.0, -3228165.0], dtype=float64)
zo: array([[37242.0, 74835.0],
[1026187.0, 1936542.0],
[10433318.0, 18382017.0]], dtype=float64)