implemented cho_solve function in scipy.linalg
This commit is contained in:
parent
d157fc2393
commit
b0679e6d16
7 changed files with 160 additions and 1 deletions
|
|
@ -147,6 +147,114 @@ static mp_obj_t solve_triangular(size_t n_args, const mp_obj_t *pos_args, mp_map
|
|||
|
||||
MP_DEFINE_CONST_FUN_OBJ_KW(linalg_solve_triangular_obj, 2, solve_triangular);
|
||||
|
||||
//| def cho_solve(L: ulab.numpy.ndarray, b: ulab.numpy.ndarray) -> ulab.numpy.ndarray:
|
||||
//| """
|
||||
//| :param ~ulab.numpy.ndarray L: the lower triangular, Cholesky factorization of A
|
||||
//| :param ~ulab.numpy.ndarray b: right-hand-side vector b
|
||||
//| :return: solution to the system A x = b. Shape of return matches b
|
||||
//| :raises TypeError: if L and b are not of type ndarray and are not dense
|
||||
//|
|
||||
//| Solve the linear equations A x = b, given the Cholesky factorization of A as input"""
|
||||
//| ...
|
||||
//|
|
||||
|
||||
static mp_obj_t cho_solve(mp_obj_t _L, mp_obj_t _b) {
|
||||
|
||||
if(!mp_obj_is_type(_L, &ulab_ndarray_type) || !mp_obj_is_type(_b, &ulab_ndarray_type)) {
|
||||
mp_raise_TypeError(translate("first two arguments must be ndarrays"));
|
||||
}
|
||||
|
||||
ndarray_obj_t *L = MP_OBJ_TO_PTR(_L);
|
||||
ndarray_obj_t *b = MP_OBJ_TO_PTR(_b);
|
||||
|
||||
if(!ndarray_is_dense(L) || !ndarray_is_dense(b)) {
|
||||
mp_raise_TypeError(translate("input must be a dense ndarray"));
|
||||
}
|
||||
|
||||
mp_float_t (*get_L_ele)(void *) = ndarray_get_float_function(L->dtype);
|
||||
mp_float_t (*get_b_ele)(void *) = ndarray_get_float_function(b->dtype);
|
||||
void (*set_L_ele)(void *, mp_float_t) = ndarray_set_float_function(L->dtype);
|
||||
|
||||
size_t L_rows = L->shape[ULAB_MAX_DIMS - 2];
|
||||
size_t L_cols = L->shape[ULAB_MAX_DIMS - 1];
|
||||
|
||||
// Obtain transpose of the input matrix L in L_t
|
||||
size_t L_t_shape[ULAB_MAX_DIMS];
|
||||
size_t L_t_rows = L_t_shape[ULAB_MAX_DIMS - 2] = L_cols;
|
||||
size_t L_t_cols = L_t_shape[ULAB_MAX_DIMS - 1] = L_rows;
|
||||
ndarray_obj_t *L_t = ndarray_new_dense_ndarray(L->ndim, L_t_shape, L->dtype);
|
||||
|
||||
uint8_t *L_arr = (uint8_t *)L->array;
|
||||
uint8_t *L_t_arr = (uint8_t *)L_t->array;
|
||||
uint8_t *b_arr = (uint8_t *)b->array;
|
||||
|
||||
size_t i, j;
|
||||
|
||||
uint8_t *L_ptr = L_arr;
|
||||
uint8_t *L_t_ptr = L_t_arr;
|
||||
for (i = 0; i < L_rows; i++) {
|
||||
for (j = 0; j < L_cols; j++) {
|
||||
set_L_ele(L_t_ptr, get_L_ele(L_ptr));
|
||||
L_t_ptr += L_t->strides[ULAB_MAX_DIMS - 2];
|
||||
L_ptr += L->strides[ULAB_MAX_DIMS - 1];
|
||||
}
|
||||
|
||||
L_t_ptr -= j * L_t->strides[ULAB_MAX_DIMS - 2];
|
||||
L_t_ptr += L_t->strides[ULAB_MAX_DIMS - 1];
|
||||
L_ptr -= j * L->strides[ULAB_MAX_DIMS - 1];
|
||||
L_ptr += L->strides[ULAB_MAX_DIMS - 2];
|
||||
}
|
||||
|
||||
ndarray_obj_t *x = ndarray_new_dense_ndarray(b->ndim, b->shape, NDARRAY_FLOAT);
|
||||
mp_float_t *x_arr = (mp_float_t *)x->array;
|
||||
|
||||
ndarray_obj_t *y = ndarray_new_dense_ndarray(b->ndim, b->shape, NDARRAY_FLOAT);
|
||||
mp_float_t *y_arr = (mp_float_t *)y->array;
|
||||
|
||||
// solve L y = b to obtain y, where L_t x = y
|
||||
for (i = 0; i < L_rows; i++) {
|
||||
mp_float_t sum = 0.0;
|
||||
for (j = 0; j < i; j++) {
|
||||
sum += (get_L_ele(L_arr) * (*y_arr++));
|
||||
L_arr += L->strides[ULAB_MAX_DIMS - 1];
|
||||
}
|
||||
|
||||
sum = (get_b_ele(b_arr) - sum) / (get_L_ele(L_arr));
|
||||
*y_arr = sum;
|
||||
|
||||
y_arr -= j;
|
||||
L_arr -= L->strides[ULAB_MAX_DIMS - 1] * j;
|
||||
L_arr += L->strides[ULAB_MAX_DIMS - 2];
|
||||
b_arr += b->strides[ULAB_MAX_DIMS - 1];
|
||||
}
|
||||
|
||||
// using y, solve L_t x = y to obtain x
|
||||
L_t_arr += (L_t->strides[ULAB_MAX_DIMS - 2] * L_t_rows);
|
||||
y_arr += L_t_cols;
|
||||
x_arr += L_t_cols;
|
||||
|
||||
for (i = L_t_rows - 1; i < L_t_rows; i--) {
|
||||
mp_float_t sum = 0.0;
|
||||
for (j = i + 1; j < L_t_cols; j++) {
|
||||
sum += (get_L_ele(L_t_arr) * (*x_arr++));
|
||||
L_t_arr += L_t->strides[ULAB_MAX_DIMS - 1];
|
||||
}
|
||||
|
||||
x_arr -= (j - i);
|
||||
L_t_arr -= (L_t->strides[ULAB_MAX_DIMS - 1] * (j - i));
|
||||
y_arr--;
|
||||
|
||||
sum = ((*y_arr) - sum) / get_L_ele(L_t_arr);
|
||||
*x_arr = sum;
|
||||
|
||||
L_t_arr -= L_t->strides[ULAB_MAX_DIMS - 2];
|
||||
}
|
||||
|
||||
return MP_OBJ_FROM_PTR(x);
|
||||
}
|
||||
|
||||
MP_DEFINE_CONST_FUN_OBJ_2(linalg_cho_solve_obj, cho_solve);
|
||||
|
||||
#endif
|
||||
|
||||
static const mp_rom_map_elem_t ulab_scipy_linalg_globals_table[] = {
|
||||
|
|
@ -155,6 +263,9 @@ static const mp_rom_map_elem_t ulab_scipy_linalg_globals_table[] = {
|
|||
#if ULAB_SCIPY_LINALG_HAS_SOLVE_TRIANGULAR
|
||||
{ MP_ROM_QSTR(MP_QSTR_solve_triangular), (mp_obj_t)&linalg_solve_triangular_obj },
|
||||
#endif
|
||||
#if ULAB_SCIPY_LINALG_HAS_CHO_SOLVE
|
||||
{ MP_ROM_QSTR(MP_QSTR_cho_solve), (mp_obj_t)&linalg_cho_solve_obj },
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -16,5 +16,6 @@
|
|||
extern mp_obj_module_t ulab_scipy_linalg_module;
|
||||
|
||||
MP_DECLARE_CONST_FUN_OBJ_KW(linalg_solve_triangular_obj);
|
||||
MP_DECLARE_CONST_FUN_OBJ_2(linalg_cho_solve_obj);
|
||||
|
||||
#endif /* _SCIPY_LINALG_ */
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@
|
|||
#include "user/user.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
#define ULAB_VERSION 2.7.1
|
||||
#define ULAB_VERSION 2.8.0
|
||||
#define xstr(s) str(s)
|
||||
#define str(s) #s
|
||||
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)
|
||||
|
|
|
|||
|
|
@ -565,6 +565,10 @@
|
|||
#define ULAB_SCIPY_HAS_LINALG_MODULE (1)
|
||||
#endif
|
||||
|
||||
#ifndef ULAB_SCIPY_LINALG_HAS_CHO_SOLVE
|
||||
#define ULAB_SCIPY_LINALG_HAS_CHO_SOLVE (1)
|
||||
#endif
|
||||
|
||||
#ifndef ULAB_SCIPY_LINALG_HAS_SOLVE_TRIANGULAR
|
||||
#define ULAB_SCIPY_LINALG_HAS_SOLVE_TRIANGULAR (1)
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
Sun, 16 May 2021
|
||||
|
||||
version 2.8.0
|
||||
|
||||
added cho_solve function in scipy.linalg module
|
||||
|
||||
Thu, 13 May 2021
|
||||
|
||||
version 2.7.1
|
||||
|
|
|
|||
29
tests/scipy/cho_solve.py
Normal file
29
tests/scipy/cho_solve.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
import math
|
||||
|
||||
try:
|
||||
from ulab import scipy, numpy as np
|
||||
except ImportError:
|
||||
import scipy
|
||||
import numpy as np
|
||||
|
||||
## test cholesky solve
|
||||
L = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 2, 1, 8]])
|
||||
b = np.array([4, 2, 4, 2])
|
||||
|
||||
# L needs to be a lower triangular matrix
|
||||
result = scipy.linalg.cho_solve(L, b)
|
||||
ref_result = np.array([-0.01388888888888906, -0.6458333333333331, 2.677083333333333, -0.01041666666666667])
|
||||
|
||||
for i in range(4):
|
||||
print(math.isclose(result[i], ref_result[i], rel_tol=1E-6, abs_tol=1E-6))
|
||||
|
||||
## test cholesky and cho_solve together
|
||||
C = np.array([[18, 22, 54, 42], [22, 70, 86, 62], [54, 86, 174, 134], [42, 62, 134, 106]])
|
||||
L = np.linalg.cholesky(C)
|
||||
|
||||
# L is a lower triangular matrix obtained by performing cholesky of positive-definite linear system
|
||||
result = scipy.linalg.cho_solve(L, b)
|
||||
ref_result = np.array([6.5625, 1.1875, -2.9375, 0.4375])
|
||||
|
||||
for i in range(4):
|
||||
print(math.isclose(result[i], ref_result[i], rel_tol=1E-6, abs_tol=1E-6))
|
||||
8
tests/scipy/cho_solve.py.exp
Normal file
8
tests/scipy/cho_solve.py.exp
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
True
|
||||
True
|
||||
True
|
||||
True
|
||||
True
|
||||
True
|
||||
True
|
||||
True
|
||||
Loading…
Reference in a new issue