fix vectorize (#568)
This commit is contained in:
parent
42172c6575
commit
e68bb707b2
5 changed files with 109 additions and 6 deletions
|
|
@ -759,12 +759,51 @@ static mp_obj_t vector_vectorized_function_call(mp_obj_t self_in, size_t n_args,
|
|||
if(mp_obj_is_type(args[0], &ulab_ndarray_type)) {
|
||||
ndarray_obj_t *source = MP_OBJ_TO_PTR(args[0]);
|
||||
COMPLEX_DTYPE_NOT_IMPLEMENTED(source->dtype)
|
||||
|
||||
ndarray_obj_t *ndarray = ndarray_new_dense_ndarray(source->ndim, source->shape, self->otypes);
|
||||
for(size_t i=0; i < source->len; i++) {
|
||||
avalue[0] = mp_binary_get_val_array(source->dtype, source->array, i);
|
||||
fvalue = MP_OBJ_TYPE_GET_SLOT(self->type, call)(self->fun, 1, 0, avalue);
|
||||
ndarray_set_value(self->otypes, ndarray->array, i, fvalue);
|
||||
}
|
||||
uint8_t *sarray = (uint8_t *)source->array;
|
||||
uint8_t *narray = (uint8_t *)ndarray->array;
|
||||
|
||||
#if ULAB_MAX_DIMS > 3
|
||||
size_t i = 0;
|
||||
do {
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 2
|
||||
size_t j = 0;
|
||||
do {
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
size_t k = 0;
|
||||
do {
|
||||
#endif
|
||||
size_t l = 0;
|
||||
do {
|
||||
avalue[0] = mp_binary_get_val_array(source->dtype, sarray, 0);
|
||||
fvalue = MP_OBJ_TYPE_GET_SLOT(self->type, call)(self->fun, 1, 0, avalue);
|
||||
ndarray_set_value(self->otypes, narray, 0, fvalue);
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 1];
|
||||
narray += ndarray->itemsize;
|
||||
l++;
|
||||
} while(l < source->shape[ULAB_MAX_DIMS - 1]);
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
sarray -= source->strides[ULAB_MAX_DIMS - 1] * source->shape[ULAB_MAX_DIMS - 1];
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 2];
|
||||
k++;
|
||||
} while(k < source->shape[ULAB_MAX_DIMS - 2]);
|
||||
#endif /* ULAB_MAX_DIMS > 1 */
|
||||
#if ULAB_MAX_DIMS > 2
|
||||
sarray -= source->strides[ULAB_MAX_DIMS - 2] * source->shape[ULAB_MAX_DIMS - 2];
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 3];
|
||||
j++;
|
||||
} while(j < source->shape[ULAB_MAX_DIMS - 3]);
|
||||
#endif /* ULAB_MAX_DIMS > 2 */
|
||||
#if ULAB_MAX_DIMS > 3
|
||||
sarray -= source->strides[ULAB_MAX_DIMS - 3] * source->shape[ULAB_MAX_DIMS - 3];
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 4];
|
||||
i++;
|
||||
} while(i < source->shape[ULAB_MAX_DIMS - 4]);
|
||||
#endif /* ULAB_MAX_DIMS > 3 */
|
||||
|
||||
return MP_OBJ_FROM_PTR(ndarray);
|
||||
} else if(mp_obj_is_type(args[0], &mp_type_tuple) || mp_obj_is_type(args[0], &mp_type_list) ||
|
||||
mp_obj_is_type(args[0], &mp_type_range)) { // i.e., the input is a generic iterable
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@
|
|||
#include "user/user.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
#define ULAB_VERSION 6.0.1
|
||||
#define ULAB_VERSION 6.0.2
|
||||
#define xstr(s) str(s)
|
||||
#define str(s) #s
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
Tue, 3 Jan 2023
|
||||
|
||||
version 6.0.2
|
||||
|
||||
fix vectorize
|
||||
|
||||
Sat, 5 Nov 2022
|
||||
|
||||
version 6.0.1
|
||||
|
|
|
|||
18
tests/2d/numpy/vectorize.py
Normal file
18
tests/2d/numpy/vectorize.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
try:
|
||||
from ulab import numpy as np
|
||||
except:
|
||||
import numpy as np
|
||||
|
||||
|
||||
dtypes = (np.uint8, np.int8, np.uint16, np.int16, np.float)
|
||||
|
||||
square = np.vectorize(lambda n: n*n)
|
||||
|
||||
for dtype in dtypes:
|
||||
a = np.array(range(9), dtype=dtype).reshape((3, 3))
|
||||
print(a)
|
||||
print(square(a))
|
||||
|
||||
b = a[:,2]
|
||||
print(square(b))
|
||||
print()
|
||||
40
tests/2d/numpy/vectorize.py.exp
Normal file
40
tests/2d/numpy/vectorize.py.exp
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
array([[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8]], dtype=uint8)
|
||||
array([[0.0, 1.0, 4.0],
|
||||
[9.0, 16.0, 25.0],
|
||||
[36.0, 49.0, 64.0]], dtype=float64)
|
||||
array([4.0, 25.0, 64.0], dtype=float64)
|
||||
|
||||
array([[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8]], dtype=int8)
|
||||
array([[0.0, 1.0, 4.0],
|
||||
[9.0, 16.0, 25.0],
|
||||
[36.0, 49.0, 64.0]], dtype=float64)
|
||||
array([4.0, 25.0, 64.0], dtype=float64)
|
||||
|
||||
array([[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8]], dtype=uint16)
|
||||
array([[0.0, 1.0, 4.0],
|
||||
[9.0, 16.0, 25.0],
|
||||
[36.0, 49.0, 64.0]], dtype=float64)
|
||||
array([4.0, 25.0, 64.0], dtype=float64)
|
||||
|
||||
array([[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8]], dtype=int16)
|
||||
array([[0.0, 1.0, 4.0],
|
||||
[9.0, 16.0, 25.0],
|
||||
[36.0, 49.0, 64.0]], dtype=float64)
|
||||
array([4.0, 25.0, 64.0], dtype=float64)
|
||||
|
||||
array([[0.0, 1.0, 2.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[6.0, 7.0, 8.0]], dtype=float64)
|
||||
array([[0.0, 1.0, 4.0],
|
||||
[9.0, 16.0, 25.0],
|
||||
[36.0, 49.0, 64.0]], dtype=float64)
|
||||
array([4.0, 25.0, 64.0], dtype=float64)
|
||||
|
||||
Loading…
Reference in a new issue