add qr implementation
This commit is contained in:
parent
90cb0959b9
commit
a3ce0ce29a
4 changed files with 143 additions and 2 deletions
|
|
@ -364,7 +364,140 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
|
||||||
}
|
}
|
||||||
|
|
||||||
MP_DEFINE_CONST_FUN_OBJ_KW(linalg_norm_obj, 1, linalg_norm);
|
MP_DEFINE_CONST_FUN_OBJ_KW(linalg_norm_obj, 1, linalg_norm);
|
||||||
// MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm);
|
|
||||||
|
#if ULAB_MAX_DIMS > 1
|
||||||
|
//| def qr(m: ulab.numpy.ndarray) -> Tuple[ulab.numpy.ndarray, ulab.numpy.ndarray]:
|
||||||
|
//| """
|
||||||
|
//| :param m: a matrix
|
||||||
|
//| :return tuple (Q, R):
|
||||||
|
//|
|
||||||
|
//| Computes the QR decomposition of a matrix"""
|
||||||
|
//| ...
|
||||||
|
//|
|
||||||
|
|
||||||
|
static mp_obj_t linalg_qr(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||||
|
static const mp_arg_t allowed_args[] = {
|
||||||
|
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = mp_const_none } },
|
||||||
|
{ MP_QSTR_mode, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_QSTR(MP_QSTR_complete) } },
|
||||||
|
};
|
||||||
|
|
||||||
|
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
||||||
|
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
|
||||||
|
|
||||||
|
|
||||||
|
if(!mp_obj_is_type(args[0].u_obj, &ulab_ndarray_type)) {
|
||||||
|
mp_raise_TypeError(translate("operation is defined for ndarrays only"));
|
||||||
|
}
|
||||||
|
ndarray_obj_t *source = MP_OBJ_TO_PTR(args[0].u_obj);
|
||||||
|
if(source->ndim != 2) {
|
||||||
|
mp_raise_ValueError(translate("operation is defined for 2D arrays only"));
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t m = source->shape[ULAB_MAX_DIMS - 2]; // rows
|
||||||
|
size_t n = source->shape[ULAB_MAX_DIMS - 1]; // columns
|
||||||
|
|
||||||
|
ndarray_obj_t *Q = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, m, m), NDARRAY_FLOAT);
|
||||||
|
ndarray_obj_t *R = ndarray_new_dense_ndarray(2, source->shape, NDARRAY_FLOAT);
|
||||||
|
|
||||||
|
mp_float_t *qarray = (mp_float_t *)Q->array;
|
||||||
|
mp_float_t *rarray = (mp_float_t *)R->array;
|
||||||
|
|
||||||
|
// simply copy the entries of source to a float array
|
||||||
|
mp_float_t (*func)(void *) = ndarray_get_float_function(source->dtype);
|
||||||
|
uint8_t *sarray = (uint8_t *)source->array;
|
||||||
|
|
||||||
|
for(size_t i = 0; i < m; i++) {
|
||||||
|
for(size_t j = 0; j < n; j++) {
|
||||||
|
*rarray++ = func(sarray);
|
||||||
|
sarray += source->strides[ULAB_MAX_DIMS - 1];
|
||||||
|
}
|
||||||
|
sarray -= n * source->strides[ULAB_MAX_DIMS - 1];
|
||||||
|
sarray += source->strides[ULAB_MAX_DIMS - 2];
|
||||||
|
}
|
||||||
|
rarray -= m * n;
|
||||||
|
|
||||||
|
// start with the unit matrix
|
||||||
|
for(size_t i = 0; i < m; i++) {
|
||||||
|
qarray[i * (m + 1)] = 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
for(size_t j = 0; j < n; j++) { // columns
|
||||||
|
for(size_t i = m - 1; i > j; i--) { // rows
|
||||||
|
mp_float_t c, s;
|
||||||
|
// Givens matrix: note that numpy uses a strange form of the rotation
|
||||||
|
// [[c s],
|
||||||
|
// [s -c]]
|
||||||
|
if(MICROPY_FLOAT_C_FUN(fabs)(rarray[i * n + j]) < LINALG_EPSILON) { // r[i, j]
|
||||||
|
c = (rarray[(i - 1) * n + j] >= 0.0) ? 1.0 : -1.0; // r[i-1, j]
|
||||||
|
s = 0.0;
|
||||||
|
} else if(MICROPY_FLOAT_C_FUN(fabs)(rarray[(i - 1) * n + j]) < LINALG_EPSILON) { // r[i-1, j]
|
||||||
|
c = 0.0;
|
||||||
|
s = (rarray[i * n + j] >= 0.0) ? -1.0 : 1.0; // r[i, j]
|
||||||
|
} else {
|
||||||
|
mp_float_t t, u;
|
||||||
|
if(MICROPY_FLOAT_C_FUN(fabs)(rarray[(i - 1) * n + j]) > MICROPY_FLOAT_C_FUN(fabs)(rarray[i * n + j])) { // r[i-1, j], r[i, j]
|
||||||
|
t = rarray[i * n + j] / rarray[(i - 1) * n + j]; // r[i, j]/r[i-1, j]
|
||||||
|
u = MICROPY_FLOAT_C_FUN(sqrt)(1 + t * t);
|
||||||
|
c = -1.0 / u;
|
||||||
|
s = c * t;
|
||||||
|
} else {
|
||||||
|
t = rarray[(i - 1) * n + j] / rarray[i * n + j]; // r[i-1, j]/r[i, j]
|
||||||
|
u = MICROPY_FLOAT_C_FUN(sqrt)(1 + t * t);
|
||||||
|
s = -1.0 / u;
|
||||||
|
c = s * t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp_float_t r1, r2;
|
||||||
|
// update R: multiply with the rotation matrix from the left
|
||||||
|
for(size_t k = 0; k < n; k++) {
|
||||||
|
r1 = rarray[(i - 1) * n + k]; // r[i-1, k]
|
||||||
|
r2 = rarray[i * n + k]; // r[i, k]
|
||||||
|
rarray[(i - 1) * n + k] = c * r1 + s * r2; // r[i-1, k]
|
||||||
|
rarray[i * n + k] = s * r1 - c * r2; // r[i, k]
|
||||||
|
}
|
||||||
|
|
||||||
|
// update Q: multiply with the transpose of the rotation matrix from the right
|
||||||
|
for(size_t k = 0; k < m; k++) {
|
||||||
|
r1 = qarray[k * m + (i - 1)];
|
||||||
|
r2 = qarray[k * m + i];
|
||||||
|
qarray[k * m + (i - 1)] = c * r1 + s * r2;
|
||||||
|
qarray[k * m + i] = s * r1 - c * r2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp_obj_tuple_t *tuple = MP_OBJ_TO_PTR(mp_obj_new_tuple(2, NULL));
|
||||||
|
GET_STR_DATA_LEN(args[1].u_obj, mode, len);
|
||||||
|
if(memcmp(mode, "complete", 8) == 0) {
|
||||||
|
tuple->items[0] = MP_OBJ_FROM_PTR(Q);
|
||||||
|
tuple->items[1] = MP_OBJ_FROM_PTR(R);
|
||||||
|
} else if(memcmp(mode, "reduced", 7) == 0) {
|
||||||
|
size_t k = MAX(m, n) - MIN(m, n);
|
||||||
|
ndarray_obj_t *q = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, m, m - k), NDARRAY_FLOAT);
|
||||||
|
ndarray_obj_t *r = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, m - k, n), NDARRAY_FLOAT);
|
||||||
|
mp_float_t *qa = (mp_float_t *)q->array;
|
||||||
|
mp_float_t *ra = (mp_float_t *)r->array;
|
||||||
|
for(size_t i = 0; i < m; i++) {
|
||||||
|
memcpy(qa, qarray, (m - k) * q->itemsize);
|
||||||
|
qa += (m - k);
|
||||||
|
qarray += m;
|
||||||
|
}
|
||||||
|
for(size_t i = 0; i < m - k; i++) {
|
||||||
|
memcpy(ra, rarray, n * r->itemsize);
|
||||||
|
ra += n;
|
||||||
|
rarray += n;
|
||||||
|
}
|
||||||
|
tuple->items[0] = MP_OBJ_FROM_PTR(q);
|
||||||
|
tuple->items[1] = MP_OBJ_FROM_PTR(r);
|
||||||
|
} else {
|
||||||
|
mp_raise_ValueError(translate("mode must be complete, or reduced"));
|
||||||
|
}
|
||||||
|
return tuple;
|
||||||
|
}
|
||||||
|
|
||||||
|
MP_DEFINE_CONST_FUN_OBJ_KW(linalg_qr_obj, 1, linalg_qr);
|
||||||
|
#endif
|
||||||
|
|
||||||
STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = {
|
STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = {
|
||||||
{ MP_OBJ_NEW_QSTR(MP_QSTR___name__), MP_OBJ_NEW_QSTR(MP_QSTR_linalg) },
|
{ MP_OBJ_NEW_QSTR(MP_QSTR___name__), MP_OBJ_NEW_QSTR(MP_QSTR_linalg) },
|
||||||
|
|
@ -381,6 +514,9 @@ STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = {
|
||||||
#if ULAB_LINALG_HAS_INV
|
#if ULAB_LINALG_HAS_INV
|
||||||
{ MP_ROM_QSTR(MP_QSTR_inv), (mp_obj_t)&linalg_inv_obj },
|
{ MP_ROM_QSTR(MP_QSTR_inv), (mp_obj_t)&linalg_inv_obj },
|
||||||
#endif
|
#endif
|
||||||
|
#if ULAB_LINALG_HAS_QR
|
||||||
|
{ MP_ROM_QSTR(MP_QSTR_qr), (mp_obj_t)&linalg_qr_obj },
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
#if ULAB_LINALG_HAS_NORM
|
#if ULAB_LINALG_HAS_NORM
|
||||||
{ MP_ROM_QSTR(MP_QSTR_norm), (mp_obj_t)&linalg_norm_obj },
|
{ MP_ROM_QSTR(MP_QSTR_norm), (mp_obj_t)&linalg_norm_obj },
|
||||||
|
|
|
||||||
|
|
@ -23,4 +23,5 @@ MP_DECLARE_CONST_FUN_OBJ_1(linalg_det_obj);
|
||||||
MP_DECLARE_CONST_FUN_OBJ_1(linalg_eig_obj);
|
MP_DECLARE_CONST_FUN_OBJ_1(linalg_eig_obj);
|
||||||
MP_DECLARE_CONST_FUN_OBJ_1(linalg_inv_obj);
|
MP_DECLARE_CONST_FUN_OBJ_1(linalg_inv_obj);
|
||||||
MP_DECLARE_CONST_FUN_OBJ_KW(linalg_norm_obj);
|
MP_DECLARE_CONST_FUN_OBJ_KW(linalg_norm_obj);
|
||||||
|
MP_DECLARE_CONST_FUN_OBJ_KW(linalg_qr_obj);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@
|
||||||
#include "user/user.h"
|
#include "user/user.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
|
|
||||||
#define ULAB_VERSION 3.2.0
|
#define ULAB_VERSION 3.3.0
|
||||||
#define xstr(s) str(s)
|
#define xstr(s) str(s)
|
||||||
#define str(s) #s
|
#define str(s) #s
|
||||||
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)
|
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)
|
||||||
|
|
|
||||||
|
|
@ -371,6 +371,10 @@
|
||||||
#define ULAB_LINALG_HAS_NORM (1)
|
#define ULAB_LINALG_HAS_NORM (1)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifndef ULAB_LINALG_HAS_QR
|
||||||
|
#define ULAB_LINALG_HAS_QR (1)
|
||||||
|
#endif
|
||||||
|
|
||||||
// the FFT module; functions of the fft module still have
|
// the FFT module; functions of the fft module still have
|
||||||
// to be defined separately
|
// to be defined separately
|
||||||
#ifndef ULAB_NUMPY_HAS_FFT_MODULE
|
#ifndef ULAB_NUMPY_HAS_FFT_MODULE
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue