fix in-place assignment from slices (#524)

* fix in-place assignment from slices
This commit is contained in:
Zoltán Vörös 2022-05-17 21:25:20 +02:00 committed by GitHub
parent 53bc8d6b0e
commit d438344943
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 223 deletions

View file

@ -37,7 +37,7 @@ readlinkf_posix() {
}
NPROC=$(python -c 'import multiprocessing; print(multiprocessing.cpu_count())')
HERE="$(dirname -- "$(readlinkf_posix -- "${0}")" )"
[ -e circuitpython/py/py.mk ] || (git clone --no-recurse-submodules --depth 100 --branch main https://github.com/adafruit/circuitpython && cd circuitpython && git submodule update --init lib/uzlib tools)
[ -e circuitpython/py/py.mk ] || (git clone --branch main https://github.com/adafruit/circuitpython && cd circuitpython && make fetch-submodules && git submodule update --init lib/uzlib tools)
rm -rf circuitpython/extmod/ulab; ln -s "$HERE" circuitpython/extmod/ulab
dims=${1-2}
make -C circuitpython/mpy-cross -j$NPROC

View file

@ -1187,139 +1187,66 @@ void ndarray_assign_view(ndarray_obj_t *view, ndarray_obj_t *values) {
int32_t *rstrides = m_new0(int32_t, ULAB_MAX_DIMS);
if(!ndarray_can_broadcast(view, values, &ndim, shape, lstrides, rstrides)) {
mp_raise_ValueError(translate("operands could not be broadcast together"));
m_del(size_t, shape, ULAB_MAX_DIMS);
m_del(int32_t, lstrides, ULAB_MAX_DIMS);
m_del(int32_t, rstrides, ULAB_MAX_DIMS);
}
} else {
uint8_t *rarray = (uint8_t *)values->array;
ndarray_obj_t *ndarray = ndarray_copy_view_convert_type(values, view->dtype);
// re-calculate rstrides, since the copy operation might have changed the directions of the strides
ndarray_can_broadcast(view, ndarray, &ndim, shape, lstrides, rstrides);
uint8_t *rarray = (uint8_t *)ndarray->array;
#if ULAB_SUPPORTS_COMPLEX
if(values->dtype == NDARRAY_COMPLEX) {
if(view->dtype != NDARRAY_COMPLEX) {
mp_raise_TypeError(translate("cannot convert complex to dtype"));
} else {
uint8_t *larray = (uint8_t *)view->array;
#if ULAB_MAX_DIMS > 3
size_t i = 0;
uint8_t *larray = (uint8_t *)view->array;
#if ULAB_MAX_DIMS > 3
size_t i = 0;
do {
#endif
#if ULAB_MAX_DIMS > 2
size_t j = 0;
do {
#endif
#if ULAB_MAX_DIMS > 2
size_t j = 0;
#if ULAB_MAX_DIMS > 1
size_t k = 0;
do {
#endif
#if ULAB_MAX_DIMS > 1
size_t k = 0;
size_t l = 0;
do {
#endif
size_t l = 0;
do {
memcpy(larray, rarray, view->itemsize);
larray += lstrides[ULAB_MAX_DIMS - 1];
rarray += rstrides[ULAB_MAX_DIMS - 1];
l++;
} while(l < view->shape[ULAB_MAX_DIMS - 1]);
#if ULAB_MAX_DIMS > 1
larray -= lstrides[ULAB_MAX_DIMS - 1] * view->shape[ULAB_MAX_DIMS-1];
larray += lstrides[ULAB_MAX_DIMS - 2];
rarray -= rstrides[ULAB_MAX_DIMS - 1] * view->shape[ULAB_MAX_DIMS-1];
rarray += rstrides[ULAB_MAX_DIMS - 2];
k++;
} while(k < view->shape[ULAB_MAX_DIMS - 2]);
#endif
#if ULAB_MAX_DIMS > 2
larray -= lstrides[ULAB_MAX_DIMS - 2] * view->shape[ULAB_MAX_DIMS-2];
larray += lstrides[ULAB_MAX_DIMS - 3];
rarray -= rstrides[ULAB_MAX_DIMS - 2] * view->shape[ULAB_MAX_DIMS-2];
rarray += rstrides[ULAB_MAX_DIMS - 3];
j++;
} while(j < view->shape[ULAB_MAX_DIMS - 3]);
memcpy(larray, rarray, view->itemsize);
larray += lstrides[ULAB_MAX_DIMS - 1];
rarray += rstrides[ULAB_MAX_DIMS - 1];
l++;
} while(l < view->shape[ULAB_MAX_DIMS - 1]);
#if ULAB_MAX_DIMS > 1
larray -= lstrides[ULAB_MAX_DIMS - 1] * view->shape[ULAB_MAX_DIMS-1];
larray += lstrides[ULAB_MAX_DIMS - 2];
rarray -= rstrides[ULAB_MAX_DIMS - 1] * view->shape[ULAB_MAX_DIMS-1];
rarray += rstrides[ULAB_MAX_DIMS - 2];
k++;
} while(k < view->shape[ULAB_MAX_DIMS - 2]);
#endif
#if ULAB_MAX_DIMS > 3
larray -= lstrides[ULAB_MAX_DIMS - 3] * view->shape[ULAB_MAX_DIMS-3];
larray += lstrides[ULAB_MAX_DIMS - 4];
rarray -= rstrides[ULAB_MAX_DIMS - 3] * view->shape[ULAB_MAX_DIMS-3];
rarray += rstrides[ULAB_MAX_DIMS - 4];
i++;
} while(i < view->shape[ULAB_MAX_DIMS - 4]);
#if ULAB_MAX_DIMS > 2
larray -= lstrides[ULAB_MAX_DIMS - 2] * view->shape[ULAB_MAX_DIMS-2];
larray += lstrides[ULAB_MAX_DIMS - 3];
rarray -= rstrides[ULAB_MAX_DIMS - 2] * view->shape[ULAB_MAX_DIMS-2];
rarray += rstrides[ULAB_MAX_DIMS - 3];
j++;
} while(j < view->shape[ULAB_MAX_DIMS - 3]);
#endif
}
return;
}
#endif
// since in ASSIGNMENT_LOOP the array has a type, we have to divide the strides by the itemsize
for(uint8_t i=0; i < ULAB_MAX_DIMS; i++) {
lstrides[i] /= view->itemsize;
#if ULAB_SUPPORTS_COMPLEX
if(view->dtype == NDARRAY_COMPLEX) {
lstrides[i] *= 2;
}
#if ULAB_MAX_DIMS > 3
larray -= lstrides[ULAB_MAX_DIMS - 3] * view->shape[ULAB_MAX_DIMS-3];
larray += lstrides[ULAB_MAX_DIMS - 4];
rarray -= rstrides[ULAB_MAX_DIMS - 3] * view->shape[ULAB_MAX_DIMS-3];
rarray += rstrides[ULAB_MAX_DIMS - 4];
i++;
} while(i < view->shape[ULAB_MAX_DIMS - 4]);
#endif
}
if(view->dtype == NDARRAY_UINT8) {
if(values->dtype == NDARRAY_UINT8) {
ASSIGNMENT_LOOP(view, uint8_t, uint8_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_INT8) {
ASSIGNMENT_LOOP(view, uint8_t, int8_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_UINT16) {
ASSIGNMENT_LOOP(view, uint8_t, uint16_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_INT16) {
ASSIGNMENT_LOOP(view, uint8_t, int16_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_FLOAT) {
ASSIGNMENT_LOOP(view, uint8_t, mp_float_t, lstrides, rarray, rstrides);
}
} else if(view->dtype == NDARRAY_INT8) {
if(values->dtype == NDARRAY_UINT8) {
ASSIGNMENT_LOOP(view, int8_t, uint8_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_INT8) {
ASSIGNMENT_LOOP(view, int8_t, int8_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_UINT16) {
ASSIGNMENT_LOOP(view, int8_t, uint16_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_INT16) {
ASSIGNMENT_LOOP(view, int8_t, int16_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_FLOAT) {
ASSIGNMENT_LOOP(view, int8_t, mp_float_t, lstrides, rarray, rstrides);
}
} else if(view->dtype == NDARRAY_UINT16) {
if(values->dtype == NDARRAY_UINT8) {
ASSIGNMENT_LOOP(view, uint16_t, uint8_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_INT8) {
ASSIGNMENT_LOOP(view, uint16_t, int8_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_UINT16) {
ASSIGNMENT_LOOP(view, uint16_t, uint16_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_INT16) {
ASSIGNMENT_LOOP(view, uint16_t, int16_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_FLOAT) {
ASSIGNMENT_LOOP(view, uint16_t, mp_float_t, lstrides, rarray, rstrides);
}
} else if(view->dtype == NDARRAY_INT16) {
if(values->dtype == NDARRAY_UINT8) {
ASSIGNMENT_LOOP(view, int16_t, uint8_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_INT8) {
ASSIGNMENT_LOOP(view, int16_t, int8_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_UINT16) {
ASSIGNMENT_LOOP(view, int16_t, uint16_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_INT16) {
ASSIGNMENT_LOOP(view, int16_t, int16_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_FLOAT) {
ASSIGNMENT_LOOP(view, int16_t, mp_float_t, lstrides, rarray, rstrides);
}
} else { // the dtype must be an mp_float_t or complex now
if(values->dtype == NDARRAY_UINT8) {
ASSIGNMENT_LOOP(view, mp_float_t, uint8_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_INT8) {
ASSIGNMENT_LOOP(view, mp_float_t, int8_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_UINT16) {
ASSIGNMENT_LOOP(view, mp_float_t, uint16_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_INT16) {
ASSIGNMENT_LOOP(view, mp_float_t, int16_t, lstrides, rarray, rstrides);
} else if(values->dtype == NDARRAY_FLOAT) {
ASSIGNMENT_LOOP(view, mp_float_t, mp_float_t, lstrides, rarray, rstrides);
}
}
m_del(size_t, shape, ULAB_MAX_DIMS);
m_del(int32_t, lstrides, ULAB_MAX_DIMS);
m_del(int32_t, rstrides, ULAB_MAX_DIMS);
return;
}
static mp_obj_t ndarray_from_boolean_index(ndarray_obj_t *ndarray, ndarray_obj_t *index) {

View file

@ -646,105 +646,4 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
#endif /* ULAB_MAX_DIMS == 4 */
#endif /* ULAB_HAS_FUNCTION_ITERATOR */
#if ULAB_MAX_DIMS == 1
#define ASSIGNMENT_LOOP(results, type_left, type_right, lstrides, rarray, rstrides)\
type_left *larray = (type_left *)(results)->array;\
size_t l = 0;\
do {\
*larray = (type_left)(*((type_right *)(rarray)));\
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
#endif /* ULAB_MAX_DIMS == 1 */
#if ULAB_MAX_DIMS == 2
#define ASSIGNMENT_LOOP(results, type_left, type_right, lstrides, rarray, rstrides)\
type_left *larray = (type_left *)(results)->array;\
size_t k = 0;\
do {\
size_t l = 0;\
do {\
*larray = (type_left)(*((type_right *)(rarray)));\
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
k++;\
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
#endif /* ULAB_MAX_DIMS == 2 */
#if ULAB_MAX_DIMS == 3
#define ASSIGNMENT_LOOP(results, type_left, type_right, lstrides, rarray, rstrides)\
type_left *larray = (type_left *)(results)->array;\
size_t j = 0;\
do {\
size_t k = 0;\
do {\
size_t l = 0;\
do {\
*larray = (type_left)(*((type_right *)(rarray)));\
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
k++;\
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * results->shape[ULAB_MAX_DIMS-2];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * results->shape[ULAB_MAX_DIMS-2];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
j++;\
} while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
#endif /* ULAB_MAX_DIMS == 3 */
#if ULAB_MAX_DIMS == 4
#define ASSIGNMENT_LOOP(results, type_left, type_right, lstrides, rarray, rstrides)\
type_left *larray = (type_left *)(results)->array;\
size_t i = 0;\
do {\
size_t j = 0;\
do {\
size_t k = 0;\
do {\
size_t l = 0;\
do {\
*larray = (type_left)(*((type_right *)(rarray)));\
(larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
k++;\
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * results->shape[ULAB_MAX_DIMS-2];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * results->shape[ULAB_MAX_DIMS-2];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
j++;\
} while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
(larray) -= (lstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
(larray) += (lstrides)[ULAB_MAX_DIMS - 4];\
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
(rarray) += (rstrides)[ULAB_MAX_DIMS - 4];\
i++;\
} while(i < (results)->shape[ULAB_MAX_DIMS - 4]);\
#endif /* ULAB_MAX_DIMS == 4 */
#endif

View file

@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"
#define ULAB_VERSION 5.0.6
#define ULAB_VERSION 5.0.7
#define xstr(s) str(s)
#define str(s) #s

View file

@ -1,4 +1,11 @@
Mon, 16 May 2022
version 5.0.7
fix in-place assignment from slices
Thu, 14 Apr 2022
version 5.0.6
use m_new0 conditionally