implemented second half of diag

This commit is contained in:
Zoltán Vörös 2020-11-02 22:30:15 +01:00
parent e15c13feaa
commit dc4c9d692a
4 changed files with 33 additions and 22 deletions

View file

@ -126,8 +126,8 @@ STATIC const mp_map_elem_t ulab_globals_table[] = {
#if ULAB_CREATE_HAS_CONCATENATE
{ MP_ROM_QSTR(MP_QSTR_concatenate), (mp_obj_t)&create_concatenate_obj },
#endif
#if ULAB_CREATE_HAS_DIAGONAL
{ MP_ROM_QSTR(MP_QSTR_diagonal), (mp_obj_t)&create_diagonal_obj },
#if ULAB_CREATE_HAS_DIAG
{ MP_ROM_QSTR(MP_QSTR_diag), (mp_obj_t)&create_diag_obj },
#endif
#if ULAB_MAX_DIMS > 1
#if ULAB_CREATE_HAS_EYE

View file

@ -123,7 +123,7 @@
// module constant
#define ULAB_CREATE_HAS_ARANGE (1)
#define ULAB_CREATE_HAS_CONCATENATE (1)
#define ULAB_CREATE_HAS_DIAGONAL (1)
#define ULAB_CREATE_HAS_DIAG (1)
#define ULAB_CREATE_HAS_EYE (1)
#define ULAB_CREATE_HAS_FULL (1)
#define ULAB_CREATE_HAS_LINSPACE (1)

View file

@ -271,21 +271,21 @@ mp_obj_t create_concatenate(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
MP_DEFINE_CONST_FUN_OBJ_KW(create_concatenate_obj, 1, create_concatenate);
#endif
#if ULAB_CREATE_HAS_DIAGONAL
//| def diagonal(a: ulab.array, *, offset: int = 0) -> ulab.array:
#if ULAB_CREATE_HAS_DIAG
//| def diag(a: ulab.array, *, k: int = 0) -> ulab.array:
//| """
//| .. param: a
//| an ndarray
//| .. param: offset
//| .. param: k
//| Offset of the diagonal from the main diagonal. Can be positive or negative.
//|
//| Return specified diagonals."""
//| ...
//|
mp_obj_t create_diagonal(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
mp_obj_t create_diag(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = mp_const_none } },
{ MP_QSTR_offset, MP_ARG_KW_ONLY | MP_ARG_INT, { .u_int = 0 } },
{ MP_QSTR_k, MP_ARG_KW_ONLY | MP_ARG_INT, { .u_int = 0 } },
};
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
@ -295,22 +295,33 @@ mp_obj_t create_diagonal(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_a
mp_raise_TypeError(translate("input must be an ndarray"));
}
ndarray_obj_t *source = MP_OBJ_TO_PTR(args[0].u_obj);
if(source->ndim != 2) {
if(source->ndim == 1) { // return a rank-2 tensor with the prescribed diagonal
ndarray_obj_t *target = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, source->len, source->len), source->dtype);
uint8_t *sarray = (uint8_t *)source->array;
uint8_t *tarray = (uint8_t *)target->array;
for(size_t i=0; i < source->len; i++) {
memcpy(tarray, sarray, source->itemsize);
sarray += source->strides[ULAB_MAX_DIMS - 1];
tarray += (source->len + 1) * target->itemsize;
}
return MP_OBJ_FROM_PTR(target);
}
if(source->ndim > 2) {
mp_raise_TypeError(translate("input must be a tensor of rank 2"));
}
int32_t offset = args[1].u_int;
int32_t k = args[1].u_int;
size_t len = 0;
uint8_t *sarray = (uint8_t *)source->array;
if(offset < 0) { // move the pointer "vertically"
sarray -= offset * source->strides[ULAB_MAX_DIMS - 2];
if(-offset < (int32_t)source->shape[ULAB_MAX_DIMS - 2]) {
len = MIN(source->shape[ULAB_MAX_DIMS - 2] + offset, source->shape[ULAB_MAX_DIMS - 1]);
if(k < 0) { // move the pointer "vertically"
if(-k < (int32_t)source->shape[ULAB_MAX_DIMS - 2]) {
sarray -= k * source->strides[ULAB_MAX_DIMS - 2];
len = MIN(source->shape[ULAB_MAX_DIMS - 2] + k, source->shape[ULAB_MAX_DIMS - 1]);
}
} else { // move the pointer "horizontally"
if(offset < (int32_t)source->shape[ULAB_MAX_DIMS - 1]) {
len = MIN(source->shape[ULAB_MAX_DIMS - 1] - offset, source->shape[ULAB_MAX_DIMS - 2]);
if(k < (int32_t)source->shape[ULAB_MAX_DIMS - 1]) {
sarray += k * source->strides[ULAB_MAX_DIMS - 1];
len = MIN(source->shape[ULAB_MAX_DIMS - 1] - k, source->shape[ULAB_MAX_DIMS - 2]);
}
sarray += offset * source->strides[ULAB_MAX_DIMS - 1];
}
if(len == 0) {
@ -329,8 +340,8 @@ mp_obj_t create_diagonal(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_a
return MP_OBJ_FROM_PTR(target);
}
MP_DEFINE_CONST_FUN_OBJ_KW(create_diagonal_obj, 1, create_diagonal);
#endif /* ULAB_CREATE_HAS_DIAGONAL */
MP_DEFINE_CONST_FUN_OBJ_KW(create_diag_obj, 1, create_diag);
#endif /* ULAB_CREATE_HAS_DIAG */
#if ULAB_MAX_DIMS > 1
#if ULAB_CREATE_HAS_EYE

View file

@ -25,9 +25,9 @@ mp_obj_t create_concatenate(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(create_concatenate_obj);
#endif
#if ULAB_CREATE_HAS_DIAGONAL
mp_obj_t create_diagonal(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(create_diagonal_obj);
#if ULAB_CREATE_HAS_DIAG
mp_obj_t create_diag(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(create_diag_obj);
#endif
#if ULAB_MAX_DIMS > 1