implement axis keyword of transpose
This commit is contained in:
parent
068da5fc96
commit
261e606fc7
3 changed files with 100 additions and 11 deletions
|
|
@ -1874,28 +1874,110 @@ mp_obj_t ndarray_unary_op(mp_unary_op_t op, mp_obj_t self_in) {
|
|||
#endif /* NDARRAY_HAS_UNARY_OPS */
|
||||
|
||||
#if NDARRAY_HAS_TRANSPOSE
|
||||
mp_obj_t ndarray_transpose(mp_obj_t self_in) {
|
||||
#if ULAB_MAX_DIMS == 1
|
||||
// We have to implement the T property separately, for the property can't take keyword arguments
|
||||
|
||||
#if ULAB_MAX_DIMS == 1
|
||||
// isolating the one-dimensional case saves space, because the transpose is sort of meaningless
|
||||
mp_obj_t ndarray_T(mp_obj_t self_in) {
|
||||
return self_in;
|
||||
#endif
|
||||
// TODO: check, what happens to the offset here, if we have a view
|
||||
}
|
||||
#else
|
||||
mp_obj_t ndarray_T(mp_obj_t self_in) {
|
||||
// without argument, simply return a view with axes in reverse order
|
||||
ndarray_obj_t *self = MP_OBJ_TO_PTR(self_in);
|
||||
if(self->ndim == 1) {
|
||||
return self_in;
|
||||
}
|
||||
size_t *shape = m_new(size_t, self->ndim);
|
||||
int32_t *strides = m_new(int32_t, self->ndim);
|
||||
for(uint8_t i=0; i < self->ndim; i++) {
|
||||
for(uint8_t i = 0; i < self->ndim; i++) {
|
||||
shape[ULAB_MAX_DIMS - 1 - i] = self->shape[ULAB_MAX_DIMS - self->ndim + i];
|
||||
strides[ULAB_MAX_DIMS - 1 - i] = self->strides[ULAB_MAX_DIMS - self->ndim + i];
|
||||
}
|
||||
// TODO: I am not sure ndarray_new_view is OK here...
|
||||
// should be deep copy...
|
||||
ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0);
|
||||
return MP_OBJ_FROM_PTR(ndarray);
|
||||
}
|
||||
#endif /* ULAB_MAX_DIMS == 1 */
|
||||
|
||||
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_T_obj, ndarray_T);
|
||||
|
||||
# if ULAB_MAX_DIMS == 1
|
||||
// again, nothing to do, if there is only one dimension, though, the arguments might still have to be parsed...
|
||||
mp_obj_t ndarray_transpose(mp_obj_t self_in) {
|
||||
return self_in;
|
||||
}
|
||||
|
||||
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose);
|
||||
#else
|
||||
mp_obj_t ndarray_transpose(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
|
||||
static const mp_arg_t allowed_args[] = {
|
||||
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
|
||||
{ MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
|
||||
};
|
||||
|
||||
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
|
||||
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
|
||||
|
||||
ndarray_obj_t *self = MP_OBJ_TO_PTR(args[0].u_obj);
|
||||
|
||||
if(self->ndim == 1) {
|
||||
return args[0].u_obj;
|
||||
}
|
||||
|
||||
size_t *shape = m_new(size_t, self->ndim);
|
||||
int32_t *strides = m_new(int32_t, self->ndim);
|
||||
uint8_t *order = m_new(uint8_t, self->ndim);
|
||||
|
||||
mp_obj_t axis = args[1].u_obj;
|
||||
|
||||
if(axis == mp_const_none) {
|
||||
// simply swap the order of the axes
|
||||
for(uint8_t i = 0; i < self->ndim; i++) {
|
||||
order[i] = self->ndim - 1 - i;
|
||||
}
|
||||
} else {
|
||||
if(!mp_obj_is_type(axis, &mp_type_tuple)) {
|
||||
mp_raise_TypeError(MP_ERROR_TEXT("keyword argument must be tuple of integers"));
|
||||
}
|
||||
// start with the straight array, and then swap only those specified in the argument
|
||||
for(uint8_t i = 0; i < self->ndim; i++) {
|
||||
order[i] = i;
|
||||
}
|
||||
|
||||
mp_obj_tuple_t *axes = MP_OBJ_TO_PTR(axis);
|
||||
|
||||
if(axes->len > self->ndim) {
|
||||
mp_raise_ValueError(MP_ERROR_TEXT("too many axes specified"));
|
||||
}
|
||||
|
||||
for(uint8_t i = 0; i < axes->len; i++) {
|
||||
int32_t ax = mp_obj_get_int(axes->items[i]);
|
||||
if((ax >= self->ndim) || (ax < 0)) {
|
||||
mp_raise_ValueError(MP_ERROR_TEXT("axis index out of bounds"));
|
||||
} else {
|
||||
order[i] = (uint8_t)ax;
|
||||
// TODO: check that no two identical numbers appear in the tuple
|
||||
for(uint8_t j = 0; j < i; j++) {
|
||||
if(order[i] == order[j]) {
|
||||
mp_raise_ValueError(MP_ERROR_TEXT("repeated indices"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint8_t axis_offset = ULAB_MAX_DIMS - self->ndim;
|
||||
for(uint8_t i = 0; i < self->ndim; i++) {
|
||||
shape[axis_offset + i] = self->shape[axis_offset + order[i]];
|
||||
strides[axis_offset + i] = self->strides[axis_offset + order[i]];
|
||||
}
|
||||
|
||||
ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0);
|
||||
return MP_OBJ_FROM_PTR(ndarray);
|
||||
}
|
||||
|
||||
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose);
|
||||
MP_DEFINE_CONST_FUN_OBJ_KW(ndarray_transpose_obj, 1, ndarray_transpose);
|
||||
#endif /* ULAB_MAX_DIMS == 1 */
|
||||
#endif /* NDARRAY_HAS_TRANSPOSE */
|
||||
|
||||
#if ULAB_MAX_DIMS > 1
|
||||
|
|
|
|||
|
|
@ -265,9 +265,16 @@ MP_DECLARE_CONST_FUN_OBJ_1(ndarray_tolist_obj);
|
|||
#endif
|
||||
|
||||
#if NDARRAY_HAS_TRANSPOSE
|
||||
mp_obj_t ndarray_T(mp_obj_t );
|
||||
MP_DECLARE_CONST_FUN_OBJ_1(ndarray_T_obj);
|
||||
#if ULAB_MAX_DIMS == 1
|
||||
mp_obj_t ndarray_transpose(mp_obj_t );
|
||||
MP_DECLARE_CONST_FUN_OBJ_1(ndarray_transpose_obj);
|
||||
#endif
|
||||
#else
|
||||
mp_obj_t ndarray_transpose(size_t , const mp_obj_t *, mp_map_t *);
|
||||
MP_DECLARE_CONST_FUN_OBJ_KW(ndarray_transpose_obj);
|
||||
#endif /* ULAB_MAX_DIMS == 1 */
|
||||
#endif /* NDARRAY_HAS_TRANSPOSE */
|
||||
|
||||
#if ULAB_NUMPY_HAS_NDINFO
|
||||
mp_obj_t ndarray_info(mp_obj_t );
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ void ndarray_properties_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) {
|
|||
#endif
|
||||
#if NDARRAY_HAS_TRANSPOSE
|
||||
case MP_QSTR_T:
|
||||
dest[0] = ndarray_transpose(self_in);
|
||||
dest[0] = ndarray_T(self_in);
|
||||
break;
|
||||
#endif
|
||||
#if ULAB_SUPPORTS_COMPLEX
|
||||
|
|
|
|||
Loading…
Reference in a new issue