diff --git a/code/numpy/linalg/linalg.c b/code/numpy/linalg/linalg.c index 7dbc257..e62a092 100644 --- a/code/numpy/linalg/linalg.c +++ b/code/numpy/linalg/linalg.c @@ -28,21 +28,6 @@ //| """Linear algebra functions""" //| -#if ULAB_MAX_DIMS > 1 -static ndarray_obj_t *linalg_object_is_square(mp_obj_t obj) { - // Returns an ndarray, if the object is a square ndarray, - // raises the appropriate exception otherwise - if(!MP_OBJ_IS_TYPE(obj, &ulab_ndarray_type)) { - mp_raise_TypeError(translate("size is defined for ndarrays only")); - } - ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(obj); - if((ndarray->shape[ULAB_MAX_DIMS - 1] != ndarray->shape[ULAB_MAX_DIMS - 2]) || (ndarray->ndim != 2)) { - mp_raise_ValueError(translate("input must be square matrix")); - } - return ndarray; -} -#endif - #if ULAB_MAX_DIMS > 1 //| def cholesky(A: ulab.array) -> ulab.array: //| """ @@ -55,7 +40,7 @@ static ndarray_obj_t *linalg_object_is_square(mp_obj_t obj) { //| static mp_obj_t linalg_cholesky(mp_obj_t oin) { - ndarray_obj_t *ndarray = linalg_object_is_square(oin); + ndarray_obj_t *ndarray = tools_object_is_square(oin); ndarray_obj_t *L = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, ndarray->shape[ULAB_MAX_DIMS - 1], ndarray->shape[ULAB_MAX_DIMS - 1]), NDARRAY_FLOAT); mp_float_t *Larray = (mp_float_t *)L->array; @@ -121,7 +106,7 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_cholesky_obj, linalg_cholesky); //| static mp_obj_t linalg_det(mp_obj_t oin) { - ndarray_obj_t *ndarray = linalg_object_is_square(oin); + ndarray_obj_t *ndarray = tools_object_is_square(oin); uint8_t *array = (uint8_t *)ndarray->array; size_t N = ndarray->shape[ULAB_MAX_DIMS - 1]; mp_float_t *tmp = m_new(mp_float_t, N * N); @@ -193,7 +178,7 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_det_obj, linalg_det); //| static mp_obj_t linalg_eig(mp_obj_t oin) { - ndarray_obj_t *in = linalg_object_is_square(oin); + ndarray_obj_t *in = tools_object_is_square(oin); uint8_t *iarray = (uint8_t *)in->array; size_t S = in->shape[ULAB_MAX_DIMS - 1]; mp_float_t *array = m_new(mp_float_t, S*S); @@ -254,7 +239,7 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_eig_obj, linalg_eig); //| ... //| static mp_obj_t linalg_inv(mp_obj_t o_in) { - ndarray_obj_t *ndarray = linalg_object_is_square(o_in); + ndarray_obj_t *ndarray = tools_object_is_square(o_in); uint8_t *array = (uint8_t *)ndarray->array; size_t N = ndarray->shape[ULAB_MAX_DIMS - 1]; ndarray_obj_t *inverted = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, N, N), NDARRAY_FLOAT); @@ -378,34 +363,6 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k MP_DEFINE_CONST_FUN_OBJ_KW(linalg_norm_obj, 1, linalg_norm); // MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm); -#if ULAB_MAX_DIMS > 1 -#if ULAB_LINALG_HAS_TRACE - -//| def trace(m: ulab.array) -> float: -//| """ -//| :param m: a square matrix -//| -//| Compute the trace of the matrix, the sum of its diagonal elements.""" -//| ... -//| - -static mp_obj_t linalg_trace(mp_obj_t oin) { - ndarray_obj_t *ndarray = linalg_object_is_square(oin); - mp_float_t trace = 0.0; - for(size_t i=0; i < ndarray->shape[ULAB_MAX_DIMS - 1]; i++) { - int32_t pos = i * (ndarray->strides[ULAB_MAX_DIMS - 1] + ndarray->strides[ULAB_MAX_DIMS - 2]); - trace += ndarray_get_float_index(ndarray->array, ndarray->dtype, pos/ndarray->itemsize); - } - if(ndarray->dtype == NDARRAY_FLOAT) { - return mp_obj_new_float(trace); - } - return mp_obj_new_int_from_float(trace); -} - -MP_DEFINE_CONST_FUN_OBJ_1(linalg_trace_obj, linalg_trace); -#endif -#endif - STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = { { MP_OBJ_NEW_QSTR(MP_QSTR___name__), MP_OBJ_NEW_QSTR(MP_QSTR_linalg) }, #if ULAB_MAX_DIMS > 1 @@ -421,9 +378,6 @@ STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = { #if ULAB_LINALG_HAS_INV { MP_ROM_QSTR(MP_QSTR_inv), (mp_obj_t)&linalg_inv_obj }, #endif - #if ULAB_LINALG_HAS_TRACE - { MP_ROM_QSTR(MP_QSTR_trace), (mp_obj_t)&linalg_trace_obj }, - #endif #endif #if ULAB_LINALG_HAS_NORM { MP_ROM_QSTR(MP_QSTR_norm), (mp_obj_t)&linalg_norm_obj }, diff --git a/code/numpy/linalg/linalg.h b/code/numpy/linalg/linalg.h index d873b93..fc1867f 100644 --- a/code/numpy/linalg/linalg.h +++ b/code/numpy/linalg/linalg.h @@ -22,6 +22,5 @@ MP_DECLARE_CONST_FUN_OBJ_1(linalg_cholesky_obj); MP_DECLARE_CONST_FUN_OBJ_1(linalg_det_obj); MP_DECLARE_CONST_FUN_OBJ_1(linalg_eig_obj); MP_DECLARE_CONST_FUN_OBJ_1(linalg_inv_obj); -MP_DECLARE_CONST_FUN_OBJ_1(linalg_trace_obj); MP_DECLARE_CONST_FUN_OBJ_KW(linalg_norm_obj); #endif diff --git a/code/numpy/numpy.c b/code/numpy/numpy.c index 7242d3f..804ccaf 100644 --- a/code/numpy/numpy.c +++ b/code/numpy/numpy.c @@ -173,6 +173,9 @@ static const mp_rom_map_elem_t ulab_numpy_globals_table[] = { #if ULAB_NUMPY_HAS_DOT { MP_OBJ_NEW_QSTR(MP_QSTR_dot), (mp_obj_t)&transform_dot_obj }, #endif + #if ULAB_NUMPY_HAS_TRACE + { MP_ROM_QSTR(MP_QSTR_trace), (mp_obj_t)&stats_trace_obj }, + #endif #if ULAB_NUMPY_HAS_FLIP { MP_OBJ_NEW_QSTR(MP_QSTR_flip), (mp_obj_t)&numerical_flip_obj }, #endif diff --git a/code/numpy/stats/stats.c b/code/numpy/stats/stats.c index b863602..8022ebe 100644 --- a/code/numpy/stats/stats.c +++ b/code/numpy/stats/stats.c @@ -22,3 +22,31 @@ #include "../../ulab.h" #include "../../ulab_tools.h" #include "stats.h" + +#if ULAB_MAX_DIMS > 1 +#if ULAB_NUMPY_HAS_TRACE + +//| def trace(m: ulab.array) -> float: +//| """ +//| :param m: a square matrix +//| +//| Compute the trace of the matrix, the sum of its diagonal elements.""" +//| ... +//| + +static mp_obj_t stats_trace(mp_obj_t oin) { + ndarray_obj_t *ndarray = tools_object_is_square(oin); + mp_float_t trace = 0.0; + for(size_t i=0; i < ndarray->shape[ULAB_MAX_DIMS - 1]; i++) { + int32_t pos = i * (ndarray->strides[ULAB_MAX_DIMS - 1] + ndarray->strides[ULAB_MAX_DIMS - 2]); + trace += ndarray_get_float_index(ndarray->array, ndarray->dtype, pos/ndarray->itemsize); + } + if(ndarray->dtype == NDARRAY_FLOAT) { + return mp_obj_new_float(trace); + } + return mp_obj_new_int_from_float(trace); +} + +MP_DEFINE_CONST_FUN_OBJ_1(stats_trace_obj, stats_trace); +#endif +#endif diff --git a/code/numpy/stats/stats.h b/code/numpy/stats/stats.h index 4fd519d..e5fab5f 100644 --- a/code/numpy/stats/stats.h +++ b/code/numpy/stats/stats.h @@ -15,4 +15,6 @@ #include "../../ulab.h" #include "../../ndarray.h" +MP_DECLARE_CONST_FUN_OBJ_1(stats_trace_obj); + #endif diff --git a/code/ulab.c b/code/ulab.c index 1908a34..445e60b 100644 --- a/code/ulab.c +++ b/code/ulab.c @@ -33,7 +33,7 @@ #include "user/user.h" -#define ULAB_VERSION 2.3.5 +#define ULAB_VERSION 2.3.6 #define xstr(s) str(s) #define str(s) #s #define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D) diff --git a/code/ulab.h b/code/ulab.h index 1043932..88b30b9 100644 --- a/code/ulab.h +++ b/code/ulab.h @@ -353,10 +353,6 @@ #define ULAB_LINALG_HAS_NORM (1) #endif -#ifndef ULAB_LINALG_HAS_TRACE -#define ULAB_LINALG_HAS_TRACE (1) -#endif - // the FFT module; functions of the fft module still have // to be defined separately #ifndef ULAB_NUMPY_HAS_FFT_MODULE @@ -447,6 +443,10 @@ #define ULAB_NUMPY_HAS_SUM (1) #endif +#ifndef ULAB_NUMPY_HAS_TRACE +#define ULAB_NUMPY_HAS_TRACE (1) +#endif + #ifndef ULAB_NUMPY_HAS_TRAPZ #define ULAB_NUMPY_HAS_TRAPZ (1) #endif diff --git a/code/ulab_tools.c b/code/ulab_tools.c index 0665a0c..9663d3d 100644 --- a/code/ulab_tools.c +++ b/code/ulab_tools.c @@ -212,3 +212,19 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) { return _shape_strides; } + + +#if ULAB_MAX_DIMS > 1 +ndarray_obj_t *tools_object_is_square(mp_obj_t obj) { + // Returns an ndarray, if the object is a square ndarray, + // raises the appropriate exception otherwise + if(!MP_OBJ_IS_TYPE(obj, &ulab_ndarray_type)) { + mp_raise_TypeError(translate("size is defined for ndarrays only")); + } + ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(obj); + if((ndarray->shape[ULAB_MAX_DIMS - 1] != ndarray->shape[ULAB_MAX_DIMS - 2]) || (ndarray->ndim != 2)) { + mp_raise_ValueError(translate("input must be square matrix")); + } + return ndarray; +} +#endif diff --git a/code/ulab_tools.h b/code/ulab_tools.h index 23e6d7b..378e4f0 100644 --- a/code/ulab_tools.h +++ b/code/ulab_tools.h @@ -33,4 +33,5 @@ uint8_t ndarray_upcast_dtype(uint8_t , uint8_t ); void *ndarray_set_float_function(uint8_t ); shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t ); +ndarray_obj_t *tools_object_is_square(mp_obj_t ); #endif diff --git a/tests/numpy/linalg.py b/tests/numpy/linalg.py index 7ae1059..7d64c2e 100644 --- a/tests/numpy/linalg.py +++ b/tests/numpy/linalg.py @@ -91,8 +91,3 @@ result = (np.linalg.norm(a,axis=1)) # fails on low tolerance ref_result = np.array([2.236068, 7.071068, 10.24695, 9.797959]) for i in range(4): print(math.isclose(result[i], ref_result[i], rel_tol=1E-6, abs_tol=1E-6)) - -if use_ulab: - print(np.linalg.trace(np.eye(3))) -else: - print(np.trace(np.eye(3))) diff --git a/tests/numpy/linalg.py.exp b/tests/numpy/linalg.py.exp index 57965a1..d5d1d99 100644 --- a/tests/numpy/linalg.py.exp +++ b/tests/numpy/linalg.py.exp @@ -51,4 +51,3 @@ True True True True -3.0