removed redundant function pointers, added copyright note

This commit is contained in:
vikas-udupa 2021-05-05 18:47:15 -04:00
parent ba6409a7ad
commit e52fa96c23
3 changed files with 16 additions and 18 deletions

View file

@ -6,6 +6,8 @@
*
* The MIT License (MIT)
*
* Copyright (c) 2021 Vikas Udupa
*
*/
#include <stdlib.h>
@ -77,6 +79,8 @@ static mp_obj_t solve_triangular(size_t n_args, const mp_obj_t *pos_args, mp_map
mp_float_t (*get_b_ele)(void *) = ndarray_get_float_function(b->dtype);
uint8_t *temp_A = A_arr;
// check if input matrix A is singular
for (i = 0; i < A_rows; i++) {
if (MICROPY_FLOAT_C_FUN(fabs)(get_A_ele(A_arr)) < TOLERANCE)
mp_raise_ValueError(translate("input matrix is singular"));
@ -87,10 +91,7 @@ static mp_obj_t solve_triangular(size_t n_args, const mp_obj_t *pos_args, mp_map
A_arr = temp_A;
ndarray_obj_t *x = ndarray_new_dense_ndarray(b->ndim, b->shape, NDARRAY_FLOAT);
uint8_t *x_arr = (uint8_t *)x->array;
mp_float_t (*get_x_ele)(void *) = ndarray_get_float_function(NDARRAY_FLOAT);
void (*set_x_ele)(void *, mp_float_t) = ndarray_set_float_function(NDARRAY_FLOAT);
mp_float_t *x_arr = (mp_float_t *)x->array;
if (mp_obj_is_true(args[2].u_obj)) {
// Solve the lower triangular matrix by iterating each row of A.
@ -101,16 +102,16 @@ static mp_obj_t solve_triangular(size_t n_args, const mp_obj_t *pos_args, mp_map
for (i = 0; i < A_rows; i++) {
mp_float_t sum = 0.0;
for (j = 0; j < i; j++) {
sum += (get_A_ele(A_arr) * get_x_ele(x_arr));
sum += (get_A_ele(A_arr) * (*x_arr++));
A_arr += A->strides[ULAB_MAX_DIMS - 1];
x_arr += x->strides[ULAB_MAX_DIMS - 1];
}
sum = (get_b_ele(b_arr) - sum) / (get_A_ele(A_arr));
set_x_ele(x_arr, sum);
*x_arr = sum;
x_arr -= j;
A_arr -= A->strides[ULAB_MAX_DIMS - 1] * j;
A_arr += A->strides[ULAB_MAX_DIMS - 2];
x_arr -= x->strides[ULAB_MAX_DIMS - 1] * j;
b_arr += b->strides[ULAB_MAX_DIMS - 1];
}
} else {
@ -121,22 +122,21 @@ static mp_obj_t solve_triangular(size_t n_args, const mp_obj_t *pos_args, mp_map
A_arr += (A->strides[ULAB_MAX_DIMS - 2] * A_rows);
b_arr += (b->strides[ULAB_MAX_DIMS - 1] * A_cols);
x_arr += (x->strides[ULAB_MAX_DIMS - 1] * A_cols);
x_arr += A_cols;
for (i = A_rows - 1; i < A_rows; i--) {
mp_float_t sum = 0.0;
for (j = i + 1; j < A_cols; j++) {
sum += (get_A_ele(A_arr) * get_x_ele(x_arr));
sum += (get_A_ele(A_arr) * (*x_arr++));
A_arr += A->strides[ULAB_MAX_DIMS - 1];
x_arr += x->strides[ULAB_MAX_DIMS - 1];
}
x_arr -= (j - i);
A_arr -= (A->strides[ULAB_MAX_DIMS - 1] * (j - i));
x_arr -= (x->strides[ULAB_MAX_DIMS - 1] * (j - i));
b_arr -= b->strides[ULAB_MAX_DIMS - 1];
sum = (get_b_ele(b_arr) - sum) / get_A_ele(A_arr);
set_x_ele(x_arr, sum);
*x_arr = sum;
A_arr -= A->strides[ULAB_MAX_DIMS - 2];
}

View file

@ -6,6 +6,8 @@
*
* The MIT License (MIT)
*
* Copyright (c) 2021 Vikas Udupa
*
*/
#ifndef _SCIPY_LINALG_

View file

@ -1,13 +1,9 @@
import math
try:
from ulab import scipy
from ulab import scipy, numpy as np
except ImportError:
import scipy
try:
from ulab import numpy as np
except ImportError:
import numpy as np
A = np.array([[3, 0, 2, 6], [2, 1, 0, 1], [1, 0, 1, 4], [1, 2, 1, 8]])