moved trace to numpy
This commit is contained in:
parent
701ad767c8
commit
317943b970
11 changed files with 59 additions and 62 deletions
|
|
@ -28,21 +28,6 @@
|
||||||
//| """Linear algebra functions"""
|
//| """Linear algebra functions"""
|
||||||
//|
|
//|
|
||||||
|
|
||||||
#if ULAB_MAX_DIMS > 1
|
|
||||||
static ndarray_obj_t *linalg_object_is_square(mp_obj_t obj) {
|
|
||||||
// Returns an ndarray, if the object is a square ndarray,
|
|
||||||
// raises the appropriate exception otherwise
|
|
||||||
if(!MP_OBJ_IS_TYPE(obj, &ulab_ndarray_type)) {
|
|
||||||
mp_raise_TypeError(translate("size is defined for ndarrays only"));
|
|
||||||
}
|
|
||||||
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(obj);
|
|
||||||
if((ndarray->shape[ULAB_MAX_DIMS - 1] != ndarray->shape[ULAB_MAX_DIMS - 2]) || (ndarray->ndim != 2)) {
|
|
||||||
mp_raise_ValueError(translate("input must be square matrix"));
|
|
||||||
}
|
|
||||||
return ndarray;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if ULAB_MAX_DIMS > 1
|
#if ULAB_MAX_DIMS > 1
|
||||||
//| def cholesky(A: ulab.array) -> ulab.array:
|
//| def cholesky(A: ulab.array) -> ulab.array:
|
||||||
//| """
|
//| """
|
||||||
|
|
@ -55,7 +40,7 @@ static ndarray_obj_t *linalg_object_is_square(mp_obj_t obj) {
|
||||||
//|
|
//|
|
||||||
|
|
||||||
static mp_obj_t linalg_cholesky(mp_obj_t oin) {
|
static mp_obj_t linalg_cholesky(mp_obj_t oin) {
|
||||||
ndarray_obj_t *ndarray = linalg_object_is_square(oin);
|
ndarray_obj_t *ndarray = tools_object_is_square(oin);
|
||||||
ndarray_obj_t *L = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, ndarray->shape[ULAB_MAX_DIMS - 1], ndarray->shape[ULAB_MAX_DIMS - 1]), NDARRAY_FLOAT);
|
ndarray_obj_t *L = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, ndarray->shape[ULAB_MAX_DIMS - 1], ndarray->shape[ULAB_MAX_DIMS - 1]), NDARRAY_FLOAT);
|
||||||
mp_float_t *Larray = (mp_float_t *)L->array;
|
mp_float_t *Larray = (mp_float_t *)L->array;
|
||||||
|
|
||||||
|
|
@ -121,7 +106,7 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_cholesky_obj, linalg_cholesky);
|
||||||
//|
|
//|
|
||||||
|
|
||||||
static mp_obj_t linalg_det(mp_obj_t oin) {
|
static mp_obj_t linalg_det(mp_obj_t oin) {
|
||||||
ndarray_obj_t *ndarray = linalg_object_is_square(oin);
|
ndarray_obj_t *ndarray = tools_object_is_square(oin);
|
||||||
uint8_t *array = (uint8_t *)ndarray->array;
|
uint8_t *array = (uint8_t *)ndarray->array;
|
||||||
size_t N = ndarray->shape[ULAB_MAX_DIMS - 1];
|
size_t N = ndarray->shape[ULAB_MAX_DIMS - 1];
|
||||||
mp_float_t *tmp = m_new(mp_float_t, N * N);
|
mp_float_t *tmp = m_new(mp_float_t, N * N);
|
||||||
|
|
@ -193,7 +178,7 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_det_obj, linalg_det);
|
||||||
//|
|
//|
|
||||||
|
|
||||||
static mp_obj_t linalg_eig(mp_obj_t oin) {
|
static mp_obj_t linalg_eig(mp_obj_t oin) {
|
||||||
ndarray_obj_t *in = linalg_object_is_square(oin);
|
ndarray_obj_t *in = tools_object_is_square(oin);
|
||||||
uint8_t *iarray = (uint8_t *)in->array;
|
uint8_t *iarray = (uint8_t *)in->array;
|
||||||
size_t S = in->shape[ULAB_MAX_DIMS - 1];
|
size_t S = in->shape[ULAB_MAX_DIMS - 1];
|
||||||
mp_float_t *array = m_new(mp_float_t, S*S);
|
mp_float_t *array = m_new(mp_float_t, S*S);
|
||||||
|
|
@ -254,7 +239,7 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_eig_obj, linalg_eig);
|
||||||
//| ...
|
//| ...
|
||||||
//|
|
//|
|
||||||
static mp_obj_t linalg_inv(mp_obj_t o_in) {
|
static mp_obj_t linalg_inv(mp_obj_t o_in) {
|
||||||
ndarray_obj_t *ndarray = linalg_object_is_square(o_in);
|
ndarray_obj_t *ndarray = tools_object_is_square(o_in);
|
||||||
uint8_t *array = (uint8_t *)ndarray->array;
|
uint8_t *array = (uint8_t *)ndarray->array;
|
||||||
size_t N = ndarray->shape[ULAB_MAX_DIMS - 1];
|
size_t N = ndarray->shape[ULAB_MAX_DIMS - 1];
|
||||||
ndarray_obj_t *inverted = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, N, N), NDARRAY_FLOAT);
|
ndarray_obj_t *inverted = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, N, N), NDARRAY_FLOAT);
|
||||||
|
|
@ -378,34 +363,6 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
|
||||||
MP_DEFINE_CONST_FUN_OBJ_KW(linalg_norm_obj, 1, linalg_norm);
|
MP_DEFINE_CONST_FUN_OBJ_KW(linalg_norm_obj, 1, linalg_norm);
|
||||||
// MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm);
|
// MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm);
|
||||||
|
|
||||||
#if ULAB_MAX_DIMS > 1
|
|
||||||
#if ULAB_LINALG_HAS_TRACE
|
|
||||||
|
|
||||||
//| def trace(m: ulab.array) -> float:
|
|
||||||
//| """
|
|
||||||
//| :param m: a square matrix
|
|
||||||
//|
|
|
||||||
//| Compute the trace of the matrix, the sum of its diagonal elements."""
|
|
||||||
//| ...
|
|
||||||
//|
|
|
||||||
|
|
||||||
static mp_obj_t linalg_trace(mp_obj_t oin) {
|
|
||||||
ndarray_obj_t *ndarray = linalg_object_is_square(oin);
|
|
||||||
mp_float_t trace = 0.0;
|
|
||||||
for(size_t i=0; i < ndarray->shape[ULAB_MAX_DIMS - 1]; i++) {
|
|
||||||
int32_t pos = i * (ndarray->strides[ULAB_MAX_DIMS - 1] + ndarray->strides[ULAB_MAX_DIMS - 2]);
|
|
||||||
trace += ndarray_get_float_index(ndarray->array, ndarray->dtype, pos/ndarray->itemsize);
|
|
||||||
}
|
|
||||||
if(ndarray->dtype == NDARRAY_FLOAT) {
|
|
||||||
return mp_obj_new_float(trace);
|
|
||||||
}
|
|
||||||
return mp_obj_new_int_from_float(trace);
|
|
||||||
}
|
|
||||||
|
|
||||||
MP_DEFINE_CONST_FUN_OBJ_1(linalg_trace_obj, linalg_trace);
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = {
|
STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = {
|
||||||
{ MP_OBJ_NEW_QSTR(MP_QSTR___name__), MP_OBJ_NEW_QSTR(MP_QSTR_linalg) },
|
{ MP_OBJ_NEW_QSTR(MP_QSTR___name__), MP_OBJ_NEW_QSTR(MP_QSTR_linalg) },
|
||||||
#if ULAB_MAX_DIMS > 1
|
#if ULAB_MAX_DIMS > 1
|
||||||
|
|
@ -421,9 +378,6 @@ STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = {
|
||||||
#if ULAB_LINALG_HAS_INV
|
#if ULAB_LINALG_HAS_INV
|
||||||
{ MP_ROM_QSTR(MP_QSTR_inv), (mp_obj_t)&linalg_inv_obj },
|
{ MP_ROM_QSTR(MP_QSTR_inv), (mp_obj_t)&linalg_inv_obj },
|
||||||
#endif
|
#endif
|
||||||
#if ULAB_LINALG_HAS_TRACE
|
|
||||||
{ MP_ROM_QSTR(MP_QSTR_trace), (mp_obj_t)&linalg_trace_obj },
|
|
||||||
#endif
|
|
||||||
#endif
|
#endif
|
||||||
#if ULAB_LINALG_HAS_NORM
|
#if ULAB_LINALG_HAS_NORM
|
||||||
{ MP_ROM_QSTR(MP_QSTR_norm), (mp_obj_t)&linalg_norm_obj },
|
{ MP_ROM_QSTR(MP_QSTR_norm), (mp_obj_t)&linalg_norm_obj },
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,5 @@ MP_DECLARE_CONST_FUN_OBJ_1(linalg_cholesky_obj);
|
||||||
MP_DECLARE_CONST_FUN_OBJ_1(linalg_det_obj);
|
MP_DECLARE_CONST_FUN_OBJ_1(linalg_det_obj);
|
||||||
MP_DECLARE_CONST_FUN_OBJ_1(linalg_eig_obj);
|
MP_DECLARE_CONST_FUN_OBJ_1(linalg_eig_obj);
|
||||||
MP_DECLARE_CONST_FUN_OBJ_1(linalg_inv_obj);
|
MP_DECLARE_CONST_FUN_OBJ_1(linalg_inv_obj);
|
||||||
MP_DECLARE_CONST_FUN_OBJ_1(linalg_trace_obj);
|
|
||||||
MP_DECLARE_CONST_FUN_OBJ_KW(linalg_norm_obj);
|
MP_DECLARE_CONST_FUN_OBJ_KW(linalg_norm_obj);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -173,6 +173,9 @@ static const mp_rom_map_elem_t ulab_numpy_globals_table[] = {
|
||||||
#if ULAB_NUMPY_HAS_DOT
|
#if ULAB_NUMPY_HAS_DOT
|
||||||
{ MP_OBJ_NEW_QSTR(MP_QSTR_dot), (mp_obj_t)&transform_dot_obj },
|
{ MP_OBJ_NEW_QSTR(MP_QSTR_dot), (mp_obj_t)&transform_dot_obj },
|
||||||
#endif
|
#endif
|
||||||
|
#if ULAB_NUMPY_HAS_TRACE
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_trace), (mp_obj_t)&stats_trace_obj },
|
||||||
|
#endif
|
||||||
#if ULAB_NUMPY_HAS_FLIP
|
#if ULAB_NUMPY_HAS_FLIP
|
||||||
{ MP_OBJ_NEW_QSTR(MP_QSTR_flip), (mp_obj_t)&numerical_flip_obj },
|
{ MP_OBJ_NEW_QSTR(MP_QSTR_flip), (mp_obj_t)&numerical_flip_obj },
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -22,3 +22,31 @@
|
||||||
#include "../../ulab.h"
|
#include "../../ulab.h"
|
||||||
#include "../../ulab_tools.h"
|
#include "../../ulab_tools.h"
|
||||||
#include "stats.h"
|
#include "stats.h"
|
||||||
|
|
||||||
|
#if ULAB_MAX_DIMS > 1
|
||||||
|
#if ULAB_NUMPY_HAS_TRACE
|
||||||
|
|
||||||
|
//| def trace(m: ulab.array) -> float:
|
||||||
|
//| """
|
||||||
|
//| :param m: a square matrix
|
||||||
|
//|
|
||||||
|
//| Compute the trace of the matrix, the sum of its diagonal elements."""
|
||||||
|
//| ...
|
||||||
|
//|
|
||||||
|
|
||||||
|
static mp_obj_t stats_trace(mp_obj_t oin) {
|
||||||
|
ndarray_obj_t *ndarray = tools_object_is_square(oin);
|
||||||
|
mp_float_t trace = 0.0;
|
||||||
|
for(size_t i=0; i < ndarray->shape[ULAB_MAX_DIMS - 1]; i++) {
|
||||||
|
int32_t pos = i * (ndarray->strides[ULAB_MAX_DIMS - 1] + ndarray->strides[ULAB_MAX_DIMS - 2]);
|
||||||
|
trace += ndarray_get_float_index(ndarray->array, ndarray->dtype, pos/ndarray->itemsize);
|
||||||
|
}
|
||||||
|
if(ndarray->dtype == NDARRAY_FLOAT) {
|
||||||
|
return mp_obj_new_float(trace);
|
||||||
|
}
|
||||||
|
return mp_obj_new_int_from_float(trace);
|
||||||
|
}
|
||||||
|
|
||||||
|
MP_DEFINE_CONST_FUN_OBJ_1(stats_trace_obj, stats_trace);
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -15,4 +15,6 @@
|
||||||
#include "../../ulab.h"
|
#include "../../ulab.h"
|
||||||
#include "../../ndarray.h"
|
#include "../../ndarray.h"
|
||||||
|
|
||||||
|
MP_DECLARE_CONST_FUN_OBJ_1(stats_trace_obj);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@
|
||||||
|
|
||||||
#include "user/user.h"
|
#include "user/user.h"
|
||||||
|
|
||||||
#define ULAB_VERSION 2.3.5
|
#define ULAB_VERSION 2.3.6
|
||||||
#define xstr(s) str(s)
|
#define xstr(s) str(s)
|
||||||
#define str(s) #s
|
#define str(s) #s
|
||||||
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)
|
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)
|
||||||
|
|
|
||||||
|
|
@ -353,10 +353,6 @@
|
||||||
#define ULAB_LINALG_HAS_NORM (1)
|
#define ULAB_LINALG_HAS_NORM (1)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef ULAB_LINALG_HAS_TRACE
|
|
||||||
#define ULAB_LINALG_HAS_TRACE (1)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// the FFT module; functions of the fft module still have
|
// the FFT module; functions of the fft module still have
|
||||||
// to be defined separately
|
// to be defined separately
|
||||||
#ifndef ULAB_NUMPY_HAS_FFT_MODULE
|
#ifndef ULAB_NUMPY_HAS_FFT_MODULE
|
||||||
|
|
@ -447,6 +443,10 @@
|
||||||
#define ULAB_NUMPY_HAS_SUM (1)
|
#define ULAB_NUMPY_HAS_SUM (1)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifndef ULAB_NUMPY_HAS_TRACE
|
||||||
|
#define ULAB_NUMPY_HAS_TRACE (1)
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifndef ULAB_NUMPY_HAS_TRAPZ
|
#ifndef ULAB_NUMPY_HAS_TRAPZ
|
||||||
#define ULAB_NUMPY_HAS_TRAPZ (1)
|
#define ULAB_NUMPY_HAS_TRAPZ (1)
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -212,3 +212,19 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
|
||||||
|
|
||||||
return _shape_strides;
|
return _shape_strides;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#if ULAB_MAX_DIMS > 1
|
||||||
|
ndarray_obj_t *tools_object_is_square(mp_obj_t obj) {
|
||||||
|
// Returns an ndarray, if the object is a square ndarray,
|
||||||
|
// raises the appropriate exception otherwise
|
||||||
|
if(!MP_OBJ_IS_TYPE(obj, &ulab_ndarray_type)) {
|
||||||
|
mp_raise_TypeError(translate("size is defined for ndarrays only"));
|
||||||
|
}
|
||||||
|
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(obj);
|
||||||
|
if((ndarray->shape[ULAB_MAX_DIMS - 1] != ndarray->shape[ULAB_MAX_DIMS - 2]) || (ndarray->ndim != 2)) {
|
||||||
|
mp_raise_ValueError(translate("input must be square matrix"));
|
||||||
|
}
|
||||||
|
return ndarray;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -33,4 +33,5 @@ uint8_t ndarray_upcast_dtype(uint8_t , uint8_t );
|
||||||
void *ndarray_set_float_function(uint8_t );
|
void *ndarray_set_float_function(uint8_t );
|
||||||
|
|
||||||
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
|
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
|
||||||
|
ndarray_obj_t *tools_object_is_square(mp_obj_t );
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -91,8 +91,3 @@ result = (np.linalg.norm(a,axis=1)) # fails on low tolerance
|
||||||
ref_result = np.array([2.236068, 7.071068, 10.24695, 9.797959])
|
ref_result = np.array([2.236068, 7.071068, 10.24695, 9.797959])
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
print(math.isclose(result[i], ref_result[i], rel_tol=1E-6, abs_tol=1E-6))
|
print(math.isclose(result[i], ref_result[i], rel_tol=1E-6, abs_tol=1E-6))
|
||||||
|
|
||||||
if use_ulab:
|
|
||||||
print(np.linalg.trace(np.eye(3)))
|
|
||||||
else:
|
|
||||||
print(np.trace(np.eye(3)))
|
|
||||||
|
|
|
||||||
|
|
@ -51,4 +51,3 @@ True
|
||||||
True
|
True
|
||||||
True
|
True
|
||||||
True
|
True
|
||||||
3.0
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue