From a3ce0ce29a37889d3106c7613de6b3813dae4206 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20V=C3=B6r=C3=B6s?= Date: Thu, 22 Jul 2021 19:40:57 +0200 Subject: [PATCH] add qr implementation --- code/numpy/linalg/linalg.c | 138 ++++++++++++++++++++++++++++++++++++- code/numpy/linalg/linalg.h | 1 + code/ulab.c | 2 +- code/ulab.h | 4 ++ 4 files changed, 143 insertions(+), 2 deletions(-) diff --git a/code/numpy/linalg/linalg.c b/code/numpy/linalg/linalg.c index 697c5e4..37ac18a 100644 --- a/code/numpy/linalg/linalg.c +++ b/code/numpy/linalg/linalg.c @@ -364,7 +364,140 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k } MP_DEFINE_CONST_FUN_OBJ_KW(linalg_norm_obj, 1, linalg_norm); -// MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm); + +#if ULAB_MAX_DIMS > 1 +//| def qr(m: ulab.numpy.ndarray) -> Tuple[ulab.numpy.ndarray, ulab.numpy.ndarray]: +//| """ +//| :param m: a matrix +//| :return tuple (Q, R): +//| +//| Computes the QR decomposition of a matrix""" +//| ... +//| + +static mp_obj_t linalg_qr(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { + static const mp_arg_t allowed_args[] = { + { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = mp_const_none } }, + { MP_QSTR_mode, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_QSTR(MP_QSTR_complete) } }, + }; + + mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; + mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); + + + if(!mp_obj_is_type(args[0].u_obj, &ulab_ndarray_type)) { + mp_raise_TypeError(translate("operation is defined for ndarrays only")); + } + ndarray_obj_t *source = MP_OBJ_TO_PTR(args[0].u_obj); + if(source->ndim != 2) { + mp_raise_ValueError(translate("operation is defined for 2D arrays only")); + } + + size_t m = source->shape[ULAB_MAX_DIMS - 2]; // rows + size_t n = source->shape[ULAB_MAX_DIMS - 1]; // columns + + ndarray_obj_t *Q = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, m, m), NDARRAY_FLOAT); + ndarray_obj_t *R = ndarray_new_dense_ndarray(2, source->shape, NDARRAY_FLOAT); + + mp_float_t *qarray = (mp_float_t *)Q->array; + mp_float_t *rarray = (mp_float_t *)R->array; + + // simply copy the entries of source to a float array + mp_float_t (*func)(void *) = ndarray_get_float_function(source->dtype); + uint8_t *sarray = (uint8_t *)source->array; + + for(size_t i = 0; i < m; i++) { + for(size_t j = 0; j < n; j++) { + *rarray++ = func(sarray); + sarray += source->strides[ULAB_MAX_DIMS - 1]; + } + sarray -= n * source->strides[ULAB_MAX_DIMS - 1]; + sarray += source->strides[ULAB_MAX_DIMS - 2]; + } + rarray -= m * n; + + // start with the unit matrix + for(size_t i = 0; i < m; i++) { + qarray[i * (m + 1)] = 1.0; + } + + for(size_t j = 0; j < n; j++) { // columns + for(size_t i = m - 1; i > j; i--) { // rows + mp_float_t c, s; + // Givens matrix: note that numpy uses a strange form of the rotation + // [[c s], + // [s -c]] + if(MICROPY_FLOAT_C_FUN(fabs)(rarray[i * n + j]) < LINALG_EPSILON) { // r[i, j] + c = (rarray[(i - 1) * n + j] >= 0.0) ? 1.0 : -1.0; // r[i-1, j] + s = 0.0; + } else if(MICROPY_FLOAT_C_FUN(fabs)(rarray[(i - 1) * n + j]) < LINALG_EPSILON) { // r[i-1, j] + c = 0.0; + s = (rarray[i * n + j] >= 0.0) ? -1.0 : 1.0; // r[i, j] + } else { + mp_float_t t, u; + if(MICROPY_FLOAT_C_FUN(fabs)(rarray[(i - 1) * n + j]) > MICROPY_FLOAT_C_FUN(fabs)(rarray[i * n + j])) { // r[i-1, j], r[i, j] + t = rarray[i * n + j] / rarray[(i - 1) * n + j]; // r[i, j]/r[i-1, j] + u = MICROPY_FLOAT_C_FUN(sqrt)(1 + t * t); + c = -1.0 / u; + s = c * t; + } else { + t = rarray[(i - 1) * n + j] / rarray[i * n + j]; // r[i-1, j]/r[i, j] + u = MICROPY_FLOAT_C_FUN(sqrt)(1 + t * t); + s = -1.0 / u; + c = s * t; + } + } + + mp_float_t r1, r2; + // update R: multiply with the rotation matrix from the left + for(size_t k = 0; k < n; k++) { + r1 = rarray[(i - 1) * n + k]; // r[i-1, k] + r2 = rarray[i * n + k]; // r[i, k] + rarray[(i - 1) * n + k] = c * r1 + s * r2; // r[i-1, k] + rarray[i * n + k] = s * r1 - c * r2; // r[i, k] + } + + // update Q: multiply with the transpose of the rotation matrix from the right + for(size_t k = 0; k < m; k++) { + r1 = qarray[k * m + (i - 1)]; + r2 = qarray[k * m + i]; + qarray[k * m + (i - 1)] = c * r1 + s * r2; + qarray[k * m + i] = s * r1 - c * r2; + } + } + } + + mp_obj_tuple_t *tuple = MP_OBJ_TO_PTR(mp_obj_new_tuple(2, NULL)); + GET_STR_DATA_LEN(args[1].u_obj, mode, len); + if(memcmp(mode, "complete", 8) == 0) { + tuple->items[0] = MP_OBJ_FROM_PTR(Q); + tuple->items[1] = MP_OBJ_FROM_PTR(R); + } else if(memcmp(mode, "reduced", 7) == 0) { + size_t k = MAX(m, n) - MIN(m, n); + ndarray_obj_t *q = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, m, m - k), NDARRAY_FLOAT); + ndarray_obj_t *r = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, m - k, n), NDARRAY_FLOAT); + mp_float_t *qa = (mp_float_t *)q->array; + mp_float_t *ra = (mp_float_t *)r->array; + for(size_t i = 0; i < m; i++) { + memcpy(qa, qarray, (m - k) * q->itemsize); + qa += (m - k); + qarray += m; + } + for(size_t i = 0; i < m - k; i++) { + memcpy(ra, rarray, n * r->itemsize); + ra += n; + rarray += n; + } + tuple->items[0] = MP_OBJ_FROM_PTR(q); + tuple->items[1] = MP_OBJ_FROM_PTR(r); + } else { + mp_raise_ValueError(translate("mode must be complete, or reduced")); + } + return tuple; +} + +MP_DEFINE_CONST_FUN_OBJ_KW(linalg_qr_obj, 1, linalg_qr); +#endif STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = { { MP_OBJ_NEW_QSTR(MP_QSTR___name__), MP_OBJ_NEW_QSTR(MP_QSTR_linalg) }, @@ -381,6 +514,9 @@ STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = { #if ULAB_LINALG_HAS_INV { MP_ROM_QSTR(MP_QSTR_inv), (mp_obj_t)&linalg_inv_obj }, #endif + #if ULAB_LINALG_HAS_QR + { MP_ROM_QSTR(MP_QSTR_qr), (mp_obj_t)&linalg_qr_obj }, + #endif #endif #if ULAB_LINALG_HAS_NORM { MP_ROM_QSTR(MP_QSTR_norm), (mp_obj_t)&linalg_norm_obj }, diff --git a/code/numpy/linalg/linalg.h b/code/numpy/linalg/linalg.h index fc1867f..cf6b0ad 100644 --- a/code/numpy/linalg/linalg.h +++ b/code/numpy/linalg/linalg.h @@ -23,4 +23,5 @@ MP_DECLARE_CONST_FUN_OBJ_1(linalg_det_obj); MP_DECLARE_CONST_FUN_OBJ_1(linalg_eig_obj); MP_DECLARE_CONST_FUN_OBJ_1(linalg_inv_obj); MP_DECLARE_CONST_FUN_OBJ_KW(linalg_norm_obj); +MP_DECLARE_CONST_FUN_OBJ_KW(linalg_qr_obj); #endif diff --git a/code/ulab.c b/code/ulab.c index 072d63b..5802a72 100644 --- a/code/ulab.c +++ b/code/ulab.c @@ -33,7 +33,7 @@ #include "user/user.h" #include "utils/utils.h" -#define ULAB_VERSION 3.2.0 +#define ULAB_VERSION 3.3.0 #define xstr(s) str(s) #define str(s) #s #define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D) diff --git a/code/ulab.h b/code/ulab.h index 94e5587..630faa0 100644 --- a/code/ulab.h +++ b/code/ulab.h @@ -371,6 +371,10 @@ #define ULAB_LINALG_HAS_NORM (1) #endif +#ifndef ULAB_LINALG_HAS_QR +#define ULAB_LINALG_HAS_QR (1) +#endif + // the FFT module; functions of the fft module still have // to be defined separately #ifndef ULAB_NUMPY_HAS_FFT_MODULE