implemented second half of diag
This commit is contained in:
parent
e15c13feaa
commit
dc4c9d692a
4 changed files with 33 additions and 22 deletions
|
|
@ -126,8 +126,8 @@ STATIC const mp_map_elem_t ulab_globals_table[] = {
|
|||
#if ULAB_CREATE_HAS_CONCATENATE
|
||||
{ MP_ROM_QSTR(MP_QSTR_concatenate), (mp_obj_t)&create_concatenate_obj },
|
||||
#endif
|
||||
#if ULAB_CREATE_HAS_DIAGONAL
|
||||
{ MP_ROM_QSTR(MP_QSTR_diagonal), (mp_obj_t)&create_diagonal_obj },
|
||||
#if ULAB_CREATE_HAS_DIAG
|
||||
{ MP_ROM_QSTR(MP_QSTR_diag), (mp_obj_t)&create_diag_obj },
|
||||
#endif
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
#if ULAB_CREATE_HAS_EYE
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@
|
|||
// module constant
|
||||
#define ULAB_CREATE_HAS_ARANGE (1)
|
||||
#define ULAB_CREATE_HAS_CONCATENATE (1)
|
||||
#define ULAB_CREATE_HAS_DIAGONAL (1)
|
||||
#define ULAB_CREATE_HAS_DIAG (1)
|
||||
#define ULAB_CREATE_HAS_EYE (1)
|
||||
#define ULAB_CREATE_HAS_FULL (1)
|
||||
#define ULAB_CREATE_HAS_LINSPACE (1)
|
||||
|
|
|
|||
|
|
@ -271,21 +271,21 @@ mp_obj_t create_concatenate(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
|
|||
MP_DEFINE_CONST_FUN_OBJ_KW(create_concatenate_obj, 1, create_concatenate);
|
||||
#endif
|
||||
|
||||
#if ULAB_CREATE_HAS_DIAGONAL
|
||||
//| def diagonal(a: ulab.array, *, offset: int = 0) -> ulab.array:
|
||||
#if ULAB_CREATE_HAS_DIAG
|
||||
//| def diag(a: ulab.array, *, k: int = 0) -> ulab.array:
|
||||
//| """
|
||||
//| .. param: a
|
||||
//| an ndarray
|
||||
//| .. param: offset
|
||||
//| .. param: k
|
||||
//| Offset of the diagonal from the main diagonal. Can be positive or negative.
|
||||
//|
|
||||
//| Return specified diagonals."""
|
||||
//| ...
|
||||
//|
|
||||
mp_obj_t create_diagonal(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||
mp_obj_t create_diag(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||
static const mp_arg_t allowed_args[] = {
|
||||
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = mp_const_none } },
|
||||
{ MP_QSTR_offset, MP_ARG_KW_ONLY | MP_ARG_INT, { .u_int = 0 } },
|
||||
{ MP_QSTR_k, MP_ARG_KW_ONLY | MP_ARG_INT, { .u_int = 0 } },
|
||||
};
|
||||
|
||||
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
||||
|
|
@ -295,22 +295,33 @@ mp_obj_t create_diagonal(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_a
|
|||
mp_raise_TypeError(translate("input must be an ndarray"));
|
||||
}
|
||||
ndarray_obj_t *source = MP_OBJ_TO_PTR(args[0].u_obj);
|
||||
if(source->ndim != 2) {
|
||||
if(source->ndim == 1) { // return a rank-2 tensor with the prescribed diagonal
|
||||
ndarray_obj_t *target = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, source->len, source->len), source->dtype);
|
||||
uint8_t *sarray = (uint8_t *)source->array;
|
||||
uint8_t *tarray = (uint8_t *)target->array;
|
||||
for(size_t i=0; i < source->len; i++) {
|
||||
memcpy(tarray, sarray, source->itemsize);
|
||||
sarray += source->strides[ULAB_MAX_DIMS - 1];
|
||||
tarray += (source->len + 1) * target->itemsize;
|
||||
}
|
||||
return MP_OBJ_FROM_PTR(target);
|
||||
}
|
||||
if(source->ndim > 2) {
|
||||
mp_raise_TypeError(translate("input must be a tensor of rank 2"));
|
||||
}
|
||||
int32_t offset = args[1].u_int;
|
||||
int32_t k = args[1].u_int;
|
||||
size_t len = 0;
|
||||
uint8_t *sarray = (uint8_t *)source->array;
|
||||
if(offset < 0) { // move the pointer "vertically"
|
||||
sarray -= offset * source->strides[ULAB_MAX_DIMS - 2];
|
||||
if(-offset < (int32_t)source->shape[ULAB_MAX_DIMS - 2]) {
|
||||
len = MIN(source->shape[ULAB_MAX_DIMS - 2] + offset, source->shape[ULAB_MAX_DIMS - 1]);
|
||||
if(k < 0) { // move the pointer "vertically"
|
||||
if(-k < (int32_t)source->shape[ULAB_MAX_DIMS - 2]) {
|
||||
sarray -= k * source->strides[ULAB_MAX_DIMS - 2];
|
||||
len = MIN(source->shape[ULAB_MAX_DIMS - 2] + k, source->shape[ULAB_MAX_DIMS - 1]);
|
||||
}
|
||||
} else { // move the pointer "horizontally"
|
||||
if(offset < (int32_t)source->shape[ULAB_MAX_DIMS - 1]) {
|
||||
len = MIN(source->shape[ULAB_MAX_DIMS - 1] - offset, source->shape[ULAB_MAX_DIMS - 2]);
|
||||
if(k < (int32_t)source->shape[ULAB_MAX_DIMS - 1]) {
|
||||
sarray += k * source->strides[ULAB_MAX_DIMS - 1];
|
||||
len = MIN(source->shape[ULAB_MAX_DIMS - 1] - k, source->shape[ULAB_MAX_DIMS - 2]);
|
||||
}
|
||||
sarray += offset * source->strides[ULAB_MAX_DIMS - 1];
|
||||
}
|
||||
|
||||
if(len == 0) {
|
||||
|
|
@ -329,8 +340,8 @@ mp_obj_t create_diagonal(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_a
|
|||
return MP_OBJ_FROM_PTR(target);
|
||||
}
|
||||
|
||||
MP_DEFINE_CONST_FUN_OBJ_KW(create_diagonal_obj, 1, create_diagonal);
|
||||
#endif /* ULAB_CREATE_HAS_DIAGONAL */
|
||||
MP_DEFINE_CONST_FUN_OBJ_KW(create_diag_obj, 1, create_diag);
|
||||
#endif /* ULAB_CREATE_HAS_DIAG */
|
||||
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
#if ULAB_CREATE_HAS_EYE
|
||||
|
|
|
|||
|
|
@ -25,9 +25,9 @@ mp_obj_t create_concatenate(size_t , const mp_obj_t *, mp_map_t *);
|
|||
MP_DECLARE_CONST_FUN_OBJ_KW(create_concatenate_obj);
|
||||
#endif
|
||||
|
||||
#if ULAB_CREATE_HAS_DIAGONAL
|
||||
mp_obj_t create_diagonal(size_t , const mp_obj_t *, mp_map_t *);
|
||||
MP_DECLARE_CONST_FUN_OBJ_KW(create_diagonal_obj);
|
||||
#if ULAB_CREATE_HAS_DIAG
|
||||
mp_obj_t create_diag(size_t , const mp_obj_t *, mp_map_t *);
|
||||
MP_DECLARE_CONST_FUN_OBJ_KW(create_diag_obj);
|
||||
#endif
|
||||
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
|
|
|
|||
Loading…
Reference in a new issue