extend convolve for the complex case

This commit is contained in:
Zoltán Vörös 2022-01-05 20:46:03 +01:00
parent 6a7d20dd58
commit 8efdec785e

View file

@ -41,8 +41,6 @@ mp_obj_t filter_convolve(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_a
ndarray_obj_t *a = MP_OBJ_TO_PTR(args[0].u_obj);
ndarray_obj_t *c = MP_OBJ_TO_PTR(args[1].u_obj);
COMPLEX_DTYPE_NOT_IMPLEMENTED(a->dtype)
COMPLEX_DTYPE_NOT_IMPLEMENTED(c->dtype)
// deal with linear arrays only
#if ULAB_MAX_DIMS > 1
if((a->ndim != 1) || (c->ndim != 1)) {
@ -56,30 +54,77 @@ mp_obj_t filter_convolve(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_a
}
int len = len_a + len_c - 1; // convolve mode "full"
ndarray_obj_t *out = ndarray_new_linear_array(len, NDARRAY_FLOAT);
mp_float_t *outptr = (mp_float_t *)out->array;
int32_t off = len_c - 1;
uint8_t dtype = NDARRAY_FLOAT;
#if ULAB_SUPPORTS_COMPLEX
if((a->dtype == NDARRAY_COMPLEX) || (c->dtype == NDARRAY_COMPLEX)) {
dtype = NDARRAY_COMPLEX;
}
#endif
ndarray_obj_t *ndarray = ndarray_new_linear_array(len, dtype);
mp_float_t *array = (mp_float_t *)ndarray->array;
uint8_t *aarray = (uint8_t *)a->array;
uint8_t *carray = (uint8_t *)c->array;
int32_t off = len_c - 1;
int32_t as = a->strides[ULAB_MAX_DIMS - 1] / a->itemsize;
int32_t cs = c->strides[ULAB_MAX_DIMS - 1] / c->itemsize;
for(int32_t k=-off; k < len-off; k++) {
mp_float_t accum = (mp_float_t)0.0;
#if ULAB_SUPPORTS_COMPLEX
if(dtype == NDARRAY_COMPLEX) {
mp_float_t a_real, a_imag;
mp_float_t c_real, c_imag = MICROPY_FLOAT_CONST(0.0);
for(int32_t k = -off; k < len-off; k++) {
mp_float_t accum_real = MICROPY_FLOAT_CONST(0.0);
mp_float_t accum_imag = MICROPY_FLOAT_CONST(0.0);
int32_t top_n = MIN(len_c, len_a - k);
int32_t bot_n = MAX(-k, 0);
for(int32_t n = bot_n; n < top_n; n++) {
int32_t idx_c = (len_c - n - 1) * cs;
int32_t idx_a = (n + k) * as;
if(a->dtype != NDARRAY_COMPLEX) {
a_real = ndarray_get_float_index(aarray, a->dtype, idx_a);
a_imag = MICROPY_FLOAT_CONST(0.0);
} else {
a_real = ndarray_get_float_index(aarray, NDARRAY_FLOAT, 2 * idx_a);
a_imag = ndarray_get_float_index(aarray, NDARRAY_FLOAT, 2 * idx_a + 1);
}
if(c->dtype != NDARRAY_COMPLEX) {
c_real = ndarray_get_float_index(carray, c->dtype, idx_c);
c_imag = MICROPY_FLOAT_CONST(0.0);
} else {
c_real = ndarray_get_float_index(carray, NDARRAY_FLOAT, 2 * idx_c);
c_imag = ndarray_get_float_index(carray, NDARRAY_FLOAT, 2 * idx_c + 1);
}
accum_real += a_real * c_real - a_imag * c_imag;
accum_imag += a_real * c_imag + a_imag * c_real;
}
*array++ = accum_real;
*array++ = accum_imag;
}
return MP_OBJ_FROM_PTR(ndarray);
}
#endif
for(int32_t k = -off; k < len-off; k++) {
mp_float_t accum = MICROPY_FLOAT_CONST(0.0);
int32_t top_n = MIN(len_c, len_a - k);
int32_t bot_n = MAX(-k, 0);
for(int32_t n=bot_n; n < top_n; n++) {
for(int32_t n = bot_n; n < top_n; n++) {
int32_t idx_c = (len_c - n - 1) * cs;
int32_t idx_a = (n + k) * as;
mp_float_t ai = ndarray_get_float_index(aarray, a->dtype, idx_a);
mp_float_t ci = ndarray_get_float_index(carray, c->dtype, idx_c);
accum += ai * ci;
}
*outptr++ = accum;
*array++ = accum;
}
return out;
return MP_OBJ_FROM_PTR(ndarray);
}
MP_DEFINE_CONST_FUN_OBJ_KW(filter_convolve_obj, 2, filter_convolve);