implement axis keyword of transpose

This commit is contained in:
Zoltán Vörös 2025-08-15 21:00:53 +02:00
parent 068da5fc96
commit 261e606fc7
3 changed files with 100 additions and 11 deletions

View file

@ -1874,11 +1874,16 @@ mp_obj_t ndarray_unary_op(mp_unary_op_t op, mp_obj_t self_in) {
#endif /* NDARRAY_HAS_UNARY_OPS */ #endif /* NDARRAY_HAS_UNARY_OPS */
#if NDARRAY_HAS_TRANSPOSE #if NDARRAY_HAS_TRANSPOSE
mp_obj_t ndarray_transpose(mp_obj_t self_in) { // We have to implement the T property separately, for the property can't take keyword arguments
#if ULAB_MAX_DIMS == 1 #if ULAB_MAX_DIMS == 1
// isolating the one-dimensional case saves space, because the transpose is sort of meaningless
mp_obj_t ndarray_T(mp_obj_t self_in) {
return self_in; return self_in;
#endif }
// TODO: check, what happens to the offset here, if we have a view #else
mp_obj_t ndarray_T(mp_obj_t self_in) {
// without argument, simply return a view with axes in reverse order
ndarray_obj_t *self = MP_OBJ_TO_PTR(self_in); ndarray_obj_t *self = MP_OBJ_TO_PTR(self_in);
if(self->ndim == 1) { if(self->ndim == 1) {
return self_in; return self_in;
@ -1889,13 +1894,90 @@ mp_obj_t ndarray_transpose(mp_obj_t self_in) {
shape[ULAB_MAX_DIMS - 1 - i] = self->shape[ULAB_MAX_DIMS - self->ndim + i]; shape[ULAB_MAX_DIMS - 1 - i] = self->shape[ULAB_MAX_DIMS - self->ndim + i];
strides[ULAB_MAX_DIMS - 1 - i] = self->strides[ULAB_MAX_DIMS - self->ndim + i]; strides[ULAB_MAX_DIMS - 1 - i] = self->strides[ULAB_MAX_DIMS - self->ndim + i];
} }
// TODO: I am not sure ndarray_new_view is OK here... ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0);
// should be deep copy... return MP_OBJ_FROM_PTR(ndarray);
}
#endif /* ULAB_MAX_DIMS == 1 */
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_T_obj, ndarray_T);
# if ULAB_MAX_DIMS == 1
// again, nothing to do, if there is only one dimension, though, the arguments might still have to be parsed...
mp_obj_t ndarray_transpose(mp_obj_t self_in) {
return self_in;
}
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose);
#else
mp_obj_t ndarray_transpose(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
};
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
ndarray_obj_t *self = MP_OBJ_TO_PTR(args[0].u_obj);
if(self->ndim == 1) {
return args[0].u_obj;
}
size_t *shape = m_new(size_t, self->ndim);
int32_t *strides = m_new(int32_t, self->ndim);
uint8_t *order = m_new(uint8_t, self->ndim);
mp_obj_t axis = args[1].u_obj;
if(axis == mp_const_none) {
// simply swap the order of the axes
for(uint8_t i = 0; i < self->ndim; i++) {
order[i] = self->ndim - 1 - i;
}
} else {
if(!mp_obj_is_type(axis, &mp_type_tuple)) {
mp_raise_TypeError(MP_ERROR_TEXT("keyword argument must be tuple of integers"));
}
// start with the straight array, and then swap only those specified in the argument
for(uint8_t i = 0; i < self->ndim; i++) {
order[i] = i;
}
mp_obj_tuple_t *axes = MP_OBJ_TO_PTR(axis);
if(axes->len > self->ndim) {
mp_raise_ValueError(MP_ERROR_TEXT("too many axes specified"));
}
for(uint8_t i = 0; i < axes->len; i++) {
int32_t ax = mp_obj_get_int(axes->items[i]);
if((ax >= self->ndim) || (ax < 0)) {
mp_raise_ValueError(MP_ERROR_TEXT("axis index out of bounds"));
} else {
order[i] = (uint8_t)ax;
// TODO: check that no two identical numbers appear in the tuple
for(uint8_t j = 0; j < i; j++) {
if(order[i] == order[j]) {
mp_raise_ValueError(MP_ERROR_TEXT("repeated indices"));
}
}
}
}
}
uint8_t axis_offset = ULAB_MAX_DIMS - self->ndim;
for(uint8_t i = 0; i < self->ndim; i++) {
shape[axis_offset + i] = self->shape[axis_offset + order[i]];
strides[axis_offset + i] = self->strides[axis_offset + order[i]];
}
ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0); ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0);
return MP_OBJ_FROM_PTR(ndarray); return MP_OBJ_FROM_PTR(ndarray);
} }
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose); MP_DEFINE_CONST_FUN_OBJ_KW(ndarray_transpose_obj, 1, ndarray_transpose);
#endif /* ULAB_MAX_DIMS == 1 */
#endif /* NDARRAY_HAS_TRANSPOSE */ #endif /* NDARRAY_HAS_TRANSPOSE */
#if ULAB_MAX_DIMS > 1 #if ULAB_MAX_DIMS > 1

View file

@ -265,9 +265,16 @@ MP_DECLARE_CONST_FUN_OBJ_1(ndarray_tolist_obj);
#endif #endif
#if NDARRAY_HAS_TRANSPOSE #if NDARRAY_HAS_TRANSPOSE
mp_obj_t ndarray_T(mp_obj_t );
MP_DECLARE_CONST_FUN_OBJ_1(ndarray_T_obj);
#if ULAB_MAX_DIMS == 1
mp_obj_t ndarray_transpose(mp_obj_t ); mp_obj_t ndarray_transpose(mp_obj_t );
MP_DECLARE_CONST_FUN_OBJ_1(ndarray_transpose_obj); MP_DECLARE_CONST_FUN_OBJ_1(ndarray_transpose_obj);
#endif #else
mp_obj_t ndarray_transpose(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(ndarray_transpose_obj);
#endif /* ULAB_MAX_DIMS == 1 */
#endif /* NDARRAY_HAS_TRANSPOSE */
#if ULAB_NUMPY_HAS_NDINFO #if ULAB_NUMPY_HAS_NDINFO
mp_obj_t ndarray_info(mp_obj_t ); mp_obj_t ndarray_info(mp_obj_t );

View file

@ -64,7 +64,7 @@ void ndarray_properties_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) {
#endif #endif
#if NDARRAY_HAS_TRANSPOSE #if NDARRAY_HAS_TRANSPOSE
case MP_QSTR_T: case MP_QSTR_T:
dest[0] = ndarray_transpose(self_in); dest[0] = ndarray_T(self_in);
break; break;
#endif #endif
#if ULAB_SUPPORTS_COMPLEX #if ULAB_SUPPORTS_COMPLEX