implemented solve_triangular function in numpy.linalg module along with two tests

This commit is contained in:
vikas-udupa 2021-04-29 09:20:17 -04:00
parent cd51012fa0
commit ab60c5c98c
7 changed files with 137 additions and 0 deletions

View file

@ -267,6 +267,46 @@ static mp_obj_t linalg_inv(mp_obj_t o_in) {
MP_DEFINE_CONST_FUN_OBJ_1(linalg_inv_obj, linalg_inv);
#endif
#if ULAB_MAX_DIMS > 1
//| def solve_triangular(A: ulab.numpy.ndarray, b: ulab.numpy.ndarray, lower: bool) -> ulab.numpy.ndarray:
//| """
//| :param ~ulab.numpy.ndarray A: a matrix
//| :param ~ulab.numpy.ndarray b: a vector
//| :param ~bool lower: if true, use only data contained in lower triangle of A, else use upper triangle of A
//| :return: solution to the system A x = b. Shape of return matches b
//| :raises TypeError: if A and b are not of type ndarray and are not dense
//|
//| Solve the equation A x = b for x, assuming A is a triangular matrix"""
//| ...
//|
static mp_obj_t solve_triangular(mp_obj_t _A, mp_obj_t _b, mp_obj_t _lower) {
if(!MP_OBJ_IS_TYPE(_A, &ulab_ndarray_type) || !MP_OBJ_IS_TYPE(_b, &ulab_ndarray_type)) {
mp_raise_TypeError(translate("first two arguments must be ndarrays"));
}
ndarray_obj_t *A = MP_OBJ_TO_PTR(_A);
ndarray_obj_t *b = MP_OBJ_TO_PTR(_b);
if(!ndarray_is_dense(A) || !ndarray_is_dense(b)) {
mp_raise_TypeError(translate("input must be a dense ndarray"));
}
ndarray_obj_t *x;
// if lower is true, solve using lower triangle of A
// else solve using upper triangle of A
x = (mp_obj_is_true(_lower) ? solve_lower_triangular(A, b) : solve_upper_triangular(A, b));
return MP_OBJ_FROM_PTR(x);
}
MP_DEFINE_CONST_FUN_OBJ_3(linalg_solve_triangular_obj, solve_triangular);
#endif
//| def norm(x: ulab.array) -> float:
//| """
//| :param ~ulab.array x: a vector or a matrix
@ -378,6 +418,9 @@ STATIC const mp_rom_map_elem_t ulab_linalg_globals_table[] = {
#if ULAB_LINALG_HAS_INV
{ MP_ROM_QSTR(MP_QSTR_inv), (mp_obj_t)&linalg_inv_obj },
#endif
#if ULAB_LINALG_HAS_SOLVE_TRIANGULAR
{ MP_ROM_QSTR(MP_QSTR_solve_triangular), (mp_obj_t)&linalg_solve_triangular_obj },
#endif
#endif
#if ULAB_LINALG_HAS_NORM
{ MP_ROM_QSTR(MP_QSTR_norm), (mp_obj_t)&linalg_norm_obj },

View file

@ -22,5 +22,6 @@ MP_DECLARE_CONST_FUN_OBJ_1(linalg_cholesky_obj);
MP_DECLARE_CONST_FUN_OBJ_1(linalg_det_obj);
MP_DECLARE_CONST_FUN_OBJ_1(linalg_eig_obj);
MP_DECLARE_CONST_FUN_OBJ_1(linalg_inv_obj);
MP_DECLARE_CONST_FUN_OBJ_3(linalg_solve_triangular_obj);
MP_DECLARE_CONST_FUN_OBJ_KW(linalg_norm_obj);
#endif

View file

@ -169,3 +169,65 @@ size_t linalg_jacobi_rotations(mp_float_t *array, mp_float_t *eigvectors, size_t
return iterations;
}
/*
* This function solves the equation A x = b for x, where A is considered as
* a lower triangular matrix
*/
ndarray_obj_t *solve_lower_triangular(ndarray_obj_t *A, ndarray_obj_t *b) {
mp_float_t *A_arr = (mp_float_t *)A->array;
mp_float_t *b_arr = (mp_float_t *)b->array;
size_t A_rows = A->shape[ULAB_MAX_DIMS - 2];
size_t A_cols = A->shape[ULAB_MAX_DIMS - 1];
size_t i, j;
ndarray_obj_t *x = ndarray_new_dense_ndarray(b->ndim, b->shape, b->dtype);
mp_float_t *x_arr = (mp_float_t *)x->array;
// Solve the lower triangular matrix by iterating each row of A.
// Start by finding the first unknown using the first row.
// On finding this unknown, find the second unknown using the second row.
// Continue the same till the last unknown is found using the last row.
for (i = 0; i < A_rows; i++) {
mp_float_t sum = 0.0;
for (j = 0; j < i; j++) {
sum += (*(A_arr + ((i * A_cols) + j)))
* (*(x_arr + j));
}
*(x_arr + i) = ((*(b_arr + i)) - sum) / (*(A_arr + ((i * A_cols) + i)));
}
return x;
}
/*
* This function solves the equation A x = b for x, where A is considered as
* an upper triangular matrix
*/
ndarray_obj_t *solve_upper_triangular(ndarray_obj_t *A, ndarray_obj_t *b) {
mp_float_t *A_arr = (mp_float_t *)A->array;
mp_float_t *b_arr = (mp_float_t *)b->array;
size_t A_rows = A->shape[ULAB_MAX_DIMS - 2];
size_t A_cols = A->shape[ULAB_MAX_DIMS - 1];
size_t i, j;
ndarray_obj_t *x = ndarray_new_dense_ndarray(b->ndim, b->shape, b->dtype);
mp_float_t *x_arr = (mp_float_t *)x->array;
// Solve the upper triangular matrix by iterating each row of A.
// Start by finding the last unknown using the last row.
// On finding this unknown, find the last-but-one unknown using the last-but-one row.
// Continue the same till the first unknown is found using the first row.
for (i = A_rows - 1; i < A_rows; i--) {
mp_float_t sum = 0.0;
for (j = i + 1; j < A_cols; j++) {
sum += (*(A_arr + ((i * A_cols) + j)))
* (*(x_arr + j));
}
*(x_arr + i) = ((*(b_arr + i)) - sum) / (*(A_arr + ((i * A_cols) + i)));
}
return x;
}

View file

@ -11,6 +11,8 @@
#ifndef _TOOLS_TOOLS_
#define _TOOLS_TOOLS_
#include <ndarray.h>
#ifndef LINALG_EPSILON
#if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT
#define LINALG_EPSILON MICROPY_FLOAT_CONST(1.2e-7)
@ -23,6 +25,8 @@
bool linalg_invert_matrix(mp_float_t *, size_t );
size_t linalg_jacobi_rotations(mp_float_t *, mp_float_t *, size_t );
ndarray_obj_t *solve_lower_triangular(ndarray_obj_t *A, ndarray_obj_t *b);
ndarray_obj_t *solve_upper_triangular(ndarray_obj_t *A, ndarray_obj_t *b);
#endif /* _TOOLS_TOOLS_ */

View file

@ -353,6 +353,10 @@
#define ULAB_LINALG_HAS_INV (1)
#endif
#ifndef ULAB_LINALG_HAS_SOLVE_TRIANGULAR
#define ULAB_LINALG_HAS_SOLVE_TRIANGULAR (1)
#endif
#ifndef ULAB_LINALG_HAS_NORM
#define ULAB_LINALG_HAS_NORM (1)
#endif

View file

@ -91,3 +91,18 @@ result = (np.linalg.norm(a,axis=1)) # fails on low tolerance
ref_result = np.array([2.236068, 7.071068, 10.24695, 9.797959])
for i in range(4):
print(math.isclose(result[i], ref_result[i], rel_tol=1E-6, abs_tol=1E-6))
A = np.array([[3, 0, 2, 6], [2, 1, 0, 1], [1, 0, 1, 4], [1, 2, 1, 8]])
b = np.array([4, 2, 4, 2])
# forward substitution
result = np.linalg.solve_triangular(A, b, True)
ref_result = np.array([1.333333333, -0.666666666, 2.666666666, -0.083333333])
for i in range(4):
print(math.isclose(result[i], ref_result[i], rel_tol=1E-6, abs_tol=1E-6))
# backward substitution
result = np.linalg.solve_triangular(A, b, False)
ref_result = np.array([-1.166666666, 1.75, 3.0, 0.25])
for i in range(4):
print(math.isclose(result[i], ref_result[i], rel_tol=1E-6, abs_tol=1E-6))

View file

@ -51,3 +51,11 @@ True
True
True
True
True
True
True
True
True
True
True
True