fix dot function
This commit is contained in:
parent
42212622ff
commit
3fb04aedac
3 changed files with 40 additions and 40 deletions
|
|
@ -45,50 +45,44 @@ mp_obj_t transform_dot(mp_obj_t _m1, mp_obj_t _m2) {
|
|||
mp_float_t (*func1)(void *) = ndarray_get_float_function(m1->dtype);
|
||||
mp_float_t (*func2)(void *) = ndarray_get_float_function(m2->dtype);
|
||||
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
if ((m1->ndim == 1) && (m2->ndim == 1)) {
|
||||
#endif
|
||||
// 2 vectors
|
||||
if (m1->len != m2->len) {
|
||||
mp_raise_ValueError(translate("vectors must have same lengths"));
|
||||
}
|
||||
mp_float_t dot = 0.0;
|
||||
for (size_t i=0; i < m1->len; i++) {
|
||||
dot += func1(array1) * func2(array2);
|
||||
array1 += m1->strides[ULAB_MAX_DIMS - 1];
|
||||
if(m1->shape[ULAB_MAX_DIMS - 1] != m2->shape[ULAB_MAX_DIMS - m2->ndim]) {
|
||||
mp_raise_ValueError(translate("dimensions do not match"));
|
||||
}
|
||||
uint8_t ndim = MIN(m1->ndim, m2->ndim);
|
||||
size_t shape1 = m1->ndim == 2 ? m1->shape[ULAB_MAX_DIMS - m1->ndim] : 1;
|
||||
size_t shape2 = m2->ndim == 2 ? m2->shape[ULAB_MAX_DIMS - 1] : 1;
|
||||
|
||||
size_t *shape = NULL;
|
||||
if(ndim == 2) { // matrix times matrix -> matrix
|
||||
shape = ndarray_shape_vector(0, 0, shape1, shape2);
|
||||
} else { // matrix times vector -> vector, vector times vector -> vector (size 1)
|
||||
shape = ndarray_shape_vector(0, 0, 0, shape1);
|
||||
}
|
||||
ndarray_obj_t *results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_FLOAT);
|
||||
mp_float_t *rarray = (mp_float_t *)results->array;
|
||||
|
||||
for(size_t i=0; i < shape1; i++) { // rows of m1
|
||||
for(size_t j=0; j < shape2; j++) { // columns of m2
|
||||
mp_float_t dot = 0.0;
|
||||
for(size_t k=0; k < m1->shape[ULAB_MAX_DIMS - 1]; k++) {
|
||||
// (i, k) * (k, j)
|
||||
dot += func1(array1) * func2(array2);
|
||||
array1 += m1->strides[ULAB_MAX_DIMS - 1];
|
||||
array2 += m2->strides[ULAB_MAX_DIMS - m2->ndim];
|
||||
}
|
||||
*rarray++ = dot;
|
||||
array1 -= m1->strides[ULAB_MAX_DIMS - 1] * m1->shape[ULAB_MAX_DIMS - 1];
|
||||
array2 -= m2->strides[ULAB_MAX_DIMS - m2->ndim] * m2->shape[ULAB_MAX_DIMS - m2->ndim];
|
||||
array2 += m2->strides[ULAB_MAX_DIMS - 1];
|
||||
}
|
||||
return mp_obj_new_float(dot);
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
array1 += m1->strides[ULAB_MAX_DIMS - m1->ndim];
|
||||
array2 = m2->array;
|
||||
}
|
||||
if((m1->ndim * m2->ndim) == 1) { // return a scalar, if product of two vectors
|
||||
return mp_obj_new_float(*(--rarray));
|
||||
} else {
|
||||
// 2 matrices
|
||||
if(m1->shape[ULAB_MAX_DIMS - 1] != m2->shape[ULAB_MAX_DIMS - 2]) {
|
||||
mp_raise_ValueError(translate("matrix dimensions do not match"));
|
||||
}
|
||||
size_t *shape = ndarray_shape_vector(0, 0, m1->shape[ULAB_MAX_DIMS - 2], m2->shape[ULAB_MAX_DIMS - 1]);
|
||||
ndarray_obj_t *results = ndarray_new_dense_ndarray(2, shape, NDARRAY_FLOAT);
|
||||
mp_float_t *rarray = (mp_float_t *)results->array;
|
||||
|
||||
for(size_t i=0; i < m1->shape[ULAB_MAX_DIMS - 2]; i++) { // rows of m1
|
||||
for(size_t j=0; j < m2->shape[ULAB_MAX_DIMS - 1]; j++) { // columns of m2
|
||||
mp_float_t dot = 0.0;
|
||||
for(size_t k=0; k < m1->shape[ULAB_MAX_DIMS - 1]; k++) {
|
||||
// (i, k) * (k, j)
|
||||
dot += func1(array1) * func2(array2);
|
||||
array1 += m1->strides[ULAB_MAX_DIMS - 1];
|
||||
array2 += m2->strides[ULAB_MAX_DIMS - 2];
|
||||
}
|
||||
*rarray++ = dot;
|
||||
array1 -= m1->strides[ULAB_MAX_DIMS - 1] * m1->shape[ULAB_MAX_DIMS - 1];
|
||||
array2 -= m2->strides[ULAB_MAX_DIMS - 2] * m2->shape[ULAB_MAX_DIMS - 1];
|
||||
array2 += m2->strides[ULAB_MAX_DIMS - 1];
|
||||
}
|
||||
array1 += m1->strides[ULAB_MAX_DIMS - 2];
|
||||
array2 = m2->array;
|
||||
}
|
||||
return MP_OBJ_FROM_PTR(results);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
MP_DEFINE_CONST_FUN_OBJ_2(transform_dot_obj, transform_dot);
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@
|
|||
|
||||
#include "user/user.h"
|
||||
|
||||
#define ULAB_VERSION 2.4.3
|
||||
#define ULAB_VERSION 2.4.5
|
||||
#define xstr(s) str(s)
|
||||
#define str(s) #s
|
||||
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
Tue, 23 Feb 2021
|
||||
|
||||
version 2.4.5
|
||||
|
||||
fix dot function
|
||||
|
||||
Sun, 21 Feb 2021
|
||||
|
||||
version 2.4.3
|
||||
|
|
|
|||
Loading…
Reference in a new issue