implemented binary operators

This commit is contained in:
Zoltán Vörös 2019-09-11 19:51:48 +02:00
parent 5997dd2a5c
commit d9f9a26197
5 changed files with 562 additions and 316 deletions

View file

@ -19,6 +19,6 @@
mp_obj_t linalg_transpose(mp_obj_t ); mp_obj_t linalg_transpose(mp_obj_t );
mp_obj_t linalg_reshape(mp_obj_t , mp_obj_t ); mp_obj_t linalg_reshape(mp_obj_t , mp_obj_t );
mp_obj_t linalg_inv(mp_obj_t ); mp_obj_t linalg_inv(mp_obj_t );
mp_obj_t linalg_multiply(mp_obj_t , mp_obj_t ); mp_obj_t linalg_dot(mp_obj_t , mp_obj_t );
#endif #endif

View file

@ -94,15 +94,15 @@ void ndarray_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t ki
} }
// TODO: print typecode // TODO: print typecode
if(self->data->typecode == NDARRAY_UINT8) { if(self->data->typecode == NDARRAY_UINT8) {
printf(", dtype='uint8')\n"); printf(", dtype='uint8')");
} else if(self->data->typecode == NDARRAY_INT8) { } else if(self->data->typecode == NDARRAY_INT8) {
printf(", dtype='int8')\n"); printf(", dtype='int8')");
} if(self->data->typecode == NDARRAY_UINT16) { } if(self->data->typecode == NDARRAY_UINT16) {
printf(", dtype='uint16')\n"); printf(", dtype='uint16')");
} if(self->data->typecode == NDARRAY_INT16) { } if(self->data->typecode == NDARRAY_INT16) {
printf(", dtype='int16')\n"); printf(", dtype='int16')");
} if(self->data->typecode == NDARRAY_FLOAT) { } if(self->data->typecode == NDARRAY_FLOAT) {
printf(", dtype='float')\n"); printf(", dtype='float')");
} }
} }
@ -346,127 +346,190 @@ mp_obj_t ndarray_rawsize(mp_obj_t self_in) {
} }
// Binary operations // Binary operations
/* STATIC uint8_t upcasting(uint8_t type_left, uint8_t type_right) {
STATIC uint8_t upcasting(ndarray_obj_t lhs, ndarray_obj_t rhs) {
// returns the upcast typecode // returns the upcast typecode
// what we have to establish is, whether either of sides has a type code that is // Now we have to collect 25 cases. Perhaps there is a more elegant solution for this
// 'larger' than the other side if(type_left == type_right) {
uint8_t typecode_l, typecode_r; // 5 cases
switch(lhs->data->typecode) { return type_left;
case 'b': } else if((type_left == NDARRAY_FLOAT) || (type_right == NDARRAY_FLOAT)) {
typecode_l = (0x01 << 0); // 8 cases ('f' AND 'f' has already been accounted for)
case 'B': return NDARRAY_FLOAT;
typecode_l = (0x01 << 1); } else if(((type_left == NDARRAY_UINT8) && (type_right == NDARRAY_INT8)) ||
case 'i': ((type_left == NDARRAY_INT8) && (type_right == NDARRAY_UINT8)) ||
typecode_l = (0x01 << 2); ((type_left == NDARRAY_UINT8) && (type_right == NDARRAY_INT16)) ||
case 'I': ((type_left == NDARRAY_INT16) && (type_right == NDARRAY_UINT8)) ||
typecode_l = (0x01 << 3); ((type_left == NDARRAY_UINT8) && (type_right == NDARRAY_UINT16)) ||
case 'f': ((type_left == NDARRAY_UINT16) && (type_right == NDARRAY_UINT8)) ||
typecode_l = (0x01 << 4); ((type_left == NDARRAY_INT8) && (type_right == NDARRAY_UINT16)) ||
} ((type_left == NDARRAY_UINT16) && (type_right == NDARRAY_INT8)) ) {
switch(rhs->data->typecode) { // 8 cases
case 'b': return NDARRAY_UINT16;
typecode_r = (0x01 << 0); } else if ( ((type_left == NDARRAY_INT8) && (type_right == NDARRAY_INT16)) ||
case 'B': ((type_left == NDARRAY_INT16) && (type_right == NDARRAY_INT8)) ) {
typecode_r = (0x01 << 1); // 2 cases
case 'i': return NDARRAY_INT16;
typecode_r = (0x01 << 2); } else if ( ((type_left == NDARRAY_INT16) && (type_right == NDARRAY_UINT16)) ||
case 'I': ((type_left == NDARRAY_UINT16) && (type_right == NDARRAY_INT16)) ) {
typecode_r = (0x01 << 3); // 2 cases
case 'f': return NDARRAY_FLOAT;
typecode_r = (0x01 << 4);
}
// Now we have to collect 25 cases
if((typecode_l | typecode_r) == (0x01 << 0)) { // 2 cases
return 'b';
} else if((typecode_l | typecode_r) == (0x01 << 1)) { // 2 cases
return 'B';
} else if((typecode_l | typecode_r) == (0x01 << 2)) { // 2 casaes
return 'i';
} else if((typecode_l | typecode_r) == (0x01 << 3)) { // 2 cases
return 'I';
} else if((typecode_l | typecode_r) >= (0x01 << 4)) { // 10 cases
return 'f';
} else if((typecode_l | typecode_r) == ((0x01 << 0) | (0x01 << 1)) {
return 'i';
} }
return NDARRAY_FLOAT; // we are never going to reach this statement, but we have to make the compiler happy
} }
mp_obj_t ulab_ndarray_binary_op_helper(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { mp_obj_t ndarray_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
// TODO: support scalar operations ndarray_obj_t *ol = MP_OBJ_TO_PTR(lhs);
uint8_t typecode;
float value;
// First, the right hand side is a native micropython object, i.e, an integer, or a float
if (MP_OBJ_IS_TYPE(rhs, &mp_type_int) || MP_OBJ_IS_TYPE(rhs, &mp_type_float)) { if (MP_OBJ_IS_TYPE(rhs, &mp_type_int) || MP_OBJ_IS_TYPE(rhs, &mp_type_float)) {
return MP_OBJ_NULL; // op not supported // we have to split the two cases here...
} else if(MP_OBJ_IS_TYPE(rhs, &ulab_ndarray_type)) { if(MP_OBJ_IS_TYPE(rhs, &mp_type_int)) {
// At this point, the operands should have the same shape typecode = upcasting(ol->data->typecode, NDARRAY_INT16);
ndarray_obj_t *ol = MP_OBJ_TO_PTR(lhs); } else {
typecode = upcasting(ol->data->typecode, NDARRAY_FLOAT);
}
if(MP_OBJ_IS_TYPE(rhs, &mp_type_int)) {
value = (float)mp_obj_get_int(rhs);
} else {
value = mp_obj_get_float(rhs);
}
if((op == MP_BINARY_OP_ADD) || (op == MP_BINARY_OP_MULTIPLY) ||
(op == MP_BINARY_OP_SUBTRACT) || (op == MP_BINARY_OP_TRUE_DIVIDE)) {
ndarray_obj_t *out = create_new_ndarray(ol->m, ol->n, typecode);
if(op == MP_BINARY_OP_SUBTRACT) value *= -1.0;
if(op == MP_BINARY_OP_TRUE_DIVIDE) value = 1.0/value;
if(typecode == NDARRAY_INT16) {
int16_t *outdata = (int16_t *)out->data->items;
if((op == MP_BINARY_OP_ADD) || (op == MP_BINARY_OP_SUBTRACT)) {
for(size_t i=0; i < ol->data->len; i++) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
}
} else if((op == MP_BINARY_OP_MULTIPLY) || (op == MP_BINARY_OP_TRUE_DIVIDE)) {
for(size_t i=0; i < ol->data->len; i++) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
}
}
} else if(typecode == NDARRAY_FLOAT) {
float *outdata = (float *)out->data->items;
if((op == MP_BINARY_OP_ADD) || (op == MP_BINARY_OP_SUBTRACT)) {
for(size_t i=0; i < ol->data->len; i++) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
}
} else if((op == MP_BINARY_OP_MULTIPLY) || (op == MP_BINARY_OP_TRUE_DIVIDE)) {
for(size_t i=0; i < ol->data->len; i++) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
}
}
}
return MP_OBJ_FROM_PTR(out);
} else {
return MP_OBJ_NULL; // op not supported
}
} else if(MP_OBJ_IS_TYPE(rhs, &ulab_ndarray_type)) { // next, the ndarray stuff
ndarray_obj_t *or = MP_OBJ_TO_PTR(rhs); ndarray_obj_t *or = MP_OBJ_TO_PTR(rhs);
ndarray_obj_t *array;
if((ol->m != or->m) || (ol->n != or->n)) { if((ol->m != or->m) || (ol->n != or->n)) {
mp_raise_ValueError("operands could not be broadcast together"); mp_raise_ValueError("operands could not be broadcast together");
} }
// do not convert types, if they are identical // At this point, the operands should have the same shape
// do not convert either, if the left hand side is a float typecode = upcasting(or->data->typecode, ol->data->typecode);
if((ol->data->typecode == or->data->typecode) || ol->data->typecode == NDARRAY_FLOAT) { if(op == MP_BINARY_OP_EQUAL) {
array = ulab_ndarray_copy(ol); // Two arrays are equal, if their shape, typecode, and elements are equal
} else { if((ol->m != or->m) || (ol->n != or->n) || (ol->data->typecode != or->data->typecode)) {
// the types are not equal, we have to do some conversion here
if(or->data->typecode == NDARRAY_FLOAT) {
array = ulab_ndarray_copy(ol);
} else if((ol->data->typecode == NDARRAY_INT16) || (or->data->typecode == NDARRAY_INT16)) {
array = create_new_ndarray(ol->m, ol->n, NDARRAY_INT16);
} else if((ol->data->typecode == NDARRAY_UINT16) || (or->data->typecode == NDARRAY_UINT16)) {
array = create_new_ndarray(ol->m, ol->n, NDARRAY_INT16);
}
}
switch(op) {
case MP_BINARY_OP_ADD:
for(size_t i=0; i < ol->data->len; i++) {
}
return MP_OBJ_FROM_PTR(array);
break;
default:
break;
}
}
}
STATIC mp_obj_t ulab_ndarray_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
ndarray_obj_t *ol = MP_OBJ_TO_PTR(lhs);
ndarray_obj_t *or = MP_OBJ_TO_PTR(rhs);
// for in-place operations, we won't need this!!!
ndarray_obj_t *array = ulab_ndarray_copy(ol);
switch (op) {
case MP_BINARY_OP_EQUAL:
if(!MP_OBJ_IS_TYPE(rhs, &ulab_ndarray_type)) {
return mp_const_false; return mp_const_false;
} else { } else {
// Two arrays are equal, if their shape, typecode, and elements are equal size_t i = ol->bytes;
if((ol->m != or->m) || (ol->n != or->n) || (ol->data->typecode != or->data->typecode)) { uint8_t *l = (uint8_t *)ol->data->items;
return mp_const_false; uint8_t *r = (uint8_t *)or->data->items;
} else { while(i) { // At this point, we can simply compare the bytes, the type is irrelevant
size_t i = ol->bytes; if(*l++ != *r++) {
uint8_t *l = (uint8_t *)ol->data->items; return mp_const_false;
uint8_t *r = (uint8_t *)or->data->items; }
while(i) { // At this point, we can simply compare the bytes, the types is irrelevant i--;
if(*l++ != *r++) { }
return mp_const_false; return mp_const_true;
} }
i--; } else if((op == MP_BINARY_OP_ADD) || (op == MP_BINARY_OP_SUBTRACT) ||
(op == MP_BINARY_OP_TRUE_DIVIDE) || (op == MP_BINARY_OP_MULTIPLY)) {
// for in-place operations, we won't need this!!!
typecode = upcasting(or->data->typecode, ol->data->typecode);
ndarray_obj_t *out = create_new_ndarray(ol->m, ol->n, typecode);
if(typecode == NDARRAY_UINT8) {
uint8_t *outdata = (uint8_t *)out->data->items;
for(size_t i=0; i < ol->data->len; i++) {
value = ndarray_get_float_value(or->data->items, or->data->typecode, i);
if(op == MP_BINARY_OP_ADD) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
} else if(op == MP_BINARY_OP_SUBTRACT) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value;
} else if(op == MP_BINARY_OP_MULTIPLY) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
} else if(op == MP_BINARY_OP_TRUE_DIVIDE) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;
}
}
} else if(typecode == NDARRAY_INT8) {
int8_t *outdata = (int8_t *)out->data->items;
for(size_t i=0; i < ol->data->len; i++) {
value = ndarray_get_float_value(or->data->items, or->data->typecode, i);
if(op == MP_BINARY_OP_ADD) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
} else if(op == MP_BINARY_OP_SUBTRACT) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value;
} else if(op == MP_BINARY_OP_MULTIPLY) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
} else if(op == MP_BINARY_OP_TRUE_DIVIDE) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;
}
}
} else if(typecode == NDARRAY_UINT16) {
uint16_t *outdata = (uint16_t *)out->data->items;
for(size_t i=0; i < ol->data->len; i++) {
value = ndarray_get_float_value(or->data->items, or->data->typecode, i);
if(op == MP_BINARY_OP_ADD) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
} else if(op == MP_BINARY_OP_SUBTRACT) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value;
} else if(op == MP_BINARY_OP_MULTIPLY) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
} else if(op == MP_BINARY_OP_TRUE_DIVIDE) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;
}
}
} else if(typecode == NDARRAY_INT16) {
int16_t *outdata = (int16_t *)out->data->items;
for(size_t i=0; i < ol->data->len; i++) {
value = ndarray_get_float_value(or->data->items, or->data->typecode, i);
if(op == MP_BINARY_OP_ADD) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
} else if(op == MP_BINARY_OP_SUBTRACT) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value;
} else if(op == MP_BINARY_OP_MULTIPLY) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
} else if(op == MP_BINARY_OP_TRUE_DIVIDE) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;
}
}
} else if(typecode == NDARRAY_FLOAT) {
float *outdata = (float *)out->data->items;
for(size_t i=0; i < ol->data->len; i++) {
value = ndarray_get_float_value(or->data->items, or->data->typecode, i);
if(op == MP_BINARY_OP_ADD) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
} else if(op == MP_BINARY_OP_SUBTRACT) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value;
} else if(op == MP_BINARY_OP_MULTIPLY) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
} else if(op == MP_BINARY_OP_TRUE_DIVIDE) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;
} }
return mp_const_true;
} }
} }
break; return MP_OBJ_FROM_PTR(out);
case MP_BINARY_OP_ADD: } else {
case MP_BINARY_OP_MULTIPLY: return MP_OBJ_NULL; // op not supported
return MP_OBJ_FROM_PTR(array); }
break; } else {
mp_raise_TypeError("wrong operand type on the right hand side");
default:
return MP_OBJ_NULL; // op not supported
} }
} }
*/

View file

@ -44,10 +44,13 @@ void ndarray_print_row(const mp_print_t *, mp_obj_array_t *, size_t , size_t );
void ndarray_print(const mp_print_t *, mp_obj_t , mp_print_kind_t ); void ndarray_print(const mp_print_t *, mp_obj_t , mp_print_kind_t );
void ndarray_assign_elements(mp_obj_array_t *, mp_obj_t , uint8_t , size_t *); void ndarray_assign_elements(mp_obj_array_t *, mp_obj_t , uint8_t , size_t *);
ndarray_obj_t *create_new_ndarray(size_t , size_t , uint8_t ); ndarray_obj_t *create_new_ndarray(size_t , size_t , uint8_t );
mp_obj_t ndarray_copy(mp_obj_t ); mp_obj_t ndarray_copy(mp_obj_t );
mp_obj_t ndarray_make_new(const mp_obj_type_t *, size_t , size_t , const mp_obj_t *); mp_obj_t ndarray_make_new(const mp_obj_type_t *, size_t , size_t , const mp_obj_t *);
mp_obj_t ndarray_subscr(mp_obj_t , mp_obj_t , mp_obj_t ); mp_obj_t ndarray_subscr(mp_obj_t , mp_obj_t , mp_obj_t );
mp_obj_t ndarray_getiter(mp_obj_t , mp_obj_iter_buf_t *); mp_obj_t ndarray_getiter(mp_obj_t , mp_obj_iter_buf_t *);
mp_obj_t ndarray_binary_op(mp_binary_op_t , mp_obj_t , mp_obj_t );
mp_obj_t ndarray_shape(mp_obj_t ); mp_obj_t ndarray_shape(mp_obj_t );
mp_obj_t ndarray_size(mp_obj_t , mp_obj_t ); mp_obj_t ndarray_size(mp_obj_t , mp_obj_t );

View file

@ -79,11 +79,11 @@ STATIC const mp_rom_map_elem_t ulab_ndarray_locals_dict_table[] = {
// { MP_ROM_QSTR(MP_QSTR_get), MP_ROM_PTR(&ulab_ndarray_get_obj) }, // { MP_ROM_QSTR(MP_QSTR_get), MP_ROM_PTR(&ulab_ndarray_get_obj) },
// { MP_ROM_QSTR(MP_QSTR_dot), MP_ROM_PTR(&ulab_ndarray_dot_obj) }, // { MP_ROM_QSTR(MP_QSTR_dot), MP_ROM_PTR(&ulab_ndarray_dot_obj) },
// class constants // class constants
// { MP_ROM_QSTR(MP_QSTR_uint8), MP_ROM_INT(NDARRAY_UINT8) }, { MP_ROM_QSTR(MP_QSTR_uint8), MP_ROM_INT(NDARRAY_UINT8) },
// { MP_ROM_QSTR(MP_QSTR_int8), MP_ROM_INT(NDARRAY_INT8) }, { MP_ROM_QSTR(MP_QSTR_int8), MP_ROM_INT(NDARRAY_INT8) },
// { MP_ROM_QSTR(MP_QSTR_uint16), MP_ROM_INT(NDARRAY_UINT16) }, { MP_ROM_QSTR(MP_QSTR_uint16), MP_ROM_INT(NDARRAY_UINT16) },
// { MP_ROM_QSTR(MP_QSTR_int16), MP_ROM_INT(NDARRAY_INT16) }, { MP_ROM_QSTR(MP_QSTR_int16), MP_ROM_INT(NDARRAY_INT16) },
// { MP_ROM_QSTR(MP_QSTR_float), MP_ROM_INT(NDARRAY_FLOAT) }, { MP_ROM_QSTR(MP_QSTR_float), MP_ROM_INT(NDARRAY_FLOAT) },
}; };
STATIC MP_DEFINE_CONST_DICT(ulab_ndarray_locals_dict, ulab_ndarray_locals_dict_table); STATIC MP_DEFINE_CONST_DICT(ulab_ndarray_locals_dict, ulab_ndarray_locals_dict_table);
@ -96,7 +96,7 @@ const mp_obj_type_t ulab_ndarray_type = {
.subscr = ndarray_subscr, .subscr = ndarray_subscr,
.getiter = ndarray_getiter, .getiter = ndarray_getiter,
// .unary_op = ndarray_unary_op, // .unary_op = ndarray_unary_op,
// .binary_op = ndarray_binary_op, .binary_op = ndarray_binary_op,
.locals_dict = (mp_obj_dict_t*)&ulab_ndarray_locals_dict, .locals_dict = (mp_obj_dict_t*)&ulab_ndarray_locals_dict,
}; };

View file

@ -384,16 +384,16 @@
"source": [ "source": [
"## Iterators\n", "## Iterators\n",
"\n", "\n",
"Flattened `ndarray` objects can be iterated on:" "`ndarray` objects can be iterated on, and just as in numpy, matrices are iterated along their first axis, and they return `ndarray`s. "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 746, "execution_count": 169,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2019-08-11T16:37:48.301972Z", "end_time": "2019-09-06T18:28:03.881908Z",
"start_time": "2019-08-11T16:37:48.275062Z" "start_time": "2019-09-06T18:28:03.857148Z"
} }
}, },
"outputs": [ "outputs": [
@ -401,14 +401,13 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"1.0\n", "ndarray([[1.0, 2.0, 3.0, 4.0],\n",
"2.0\n", "\t [6.0, 7.0, 8.0, 9.0]], dtype='float')\n",
"3.0\n", "\n",
"4.0\n", "ndarray([1.0, 2.0, 3.0, 4.0], dtype='float')\n",
"6.0\n", "\n",
"7.0\n", "ndarray([6.0, 7.0, 8.0, 9.0], dtype='float')\n",
"8.0\n", "\n",
"9.0\n",
"\n", "\n",
"\n" "\n"
] ]
@ -420,10 +419,153 @@
"from ulab import ndarray\n", "from ulab import ndarray\n",
"\n", "\n",
"a = ndarray([[1, 2, 3, 4], [6, 7, 8, 9]])\n", "a = ndarray([[1, 2, 3, 4], [6, 7, 8, 9]])\n",
"print(a)\n",
"\n", "\n",
"for _a in a: print(_a)" "for _a in a: print(_a)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On the other hand, flat arrays return their elements:"
]
},
{
"cell_type": "code",
"execution_count": 172,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-06T18:29:47.050919Z",
"start_time": "2019-09-06T18:29:47.032888Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ndarray([1, 2, 3, 4, 6, 7, 8, 9], dtype='uint8')\n",
"\n",
"1\n",
"2\n",
"3\n",
"4\n",
"6\n",
"7\n",
"8\n",
"9\n",
"\n",
"\n"
]
}
],
"source": [
"%%micropython\n",
"\n",
"from ulab import ndarray\n",
"\n",
"a = ndarray([1, 2, 3, 4, 6, 7, 8, 9], dtype='uint8')\n",
"print(a)\n",
"\n",
"for _a in a: print(_a)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Upcasting"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are some unexpected results in numpy. E.g., it seems that the upcasting happens only *after* the calculation has been carried out. Besides, the sum of a signed and an unsigned character should be an unsigned integer. "
]
},
{
"cell_type": "code",
"execution_count": 235,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-11T17:11:59.818740Z",
"start_time": "2019-09-11T17:11:59.804800Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([145], dtype=int16)"
]
},
"execution_count": 235,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = array([200], dtype=int8)\n",
"b = array([201], dtype=uint8)\n",
"a + b"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When in an operation the `dtype` of two arrays is different, the results `dtype` will be decided by upcasting rules: \n",
"\n",
"1. Operations with two `ndarray`s of the same `dtype` preserve their `dtype`, even when the results overflow.\n",
"\n",
"2. if either of the operands is a float, the results is also a float\n",
"\n",
"3. \n",
" - `uint8` + `int8` => `uint16`, \n",
" - `uint8` + `int16` => `uint16`\n",
" - `uint8` + `uint16` => `uint16`\n",
" \n",
" - `int8` + `int16` => `int16`\n",
" - `int8` + `uint16` => `uint16`\n",
"\n",
" - `uint16` + `int16` => `float`\n",
" \n",
"4. When the right hand side of a binary operator is a micropython variable, `mp_obj_int`, or `mp_obj_float`, then the result will be promoted to `dtype` `float`. This is necessary, because a micropython integer can be 31 bites wide.\n",
"\n",
"Note that the rules of `numpy` are not very consistent: while upcasting is meant to preserve the accuracy of the computation, the sum of an `int8`, and a `uint8` is an `int16`. \n",
"\n",
"`numpy` is also inconsistent in how it represents `dtype`: as an argument, it is denoted by the constants `int8`, `uint8`, etc., while a string will be returned, if the user asks for the type of an array."
]
},
{
"cell_type": "code",
"execution_count": 222,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-10T18:19:51.449499Z",
"start_time": "2019-09-10T18:19:51.441482Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([6., 7., 8.])"
]
},
"execution_count": 222,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = array([1, 2, 3], dtype=uint8)\n",
"b = a + 5.0\n",
"b"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@ -433,11 +575,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 54, "execution_count": 255,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2019-09-05T19:48:52.953296Z", "end_time": "2019-09-11T17:38:05.528629Z",
"start_time": "2019-09-05T19:48:52.946225Z" "start_time": "2019-09-11T17:38:05.522729Z"
} }
}, },
"outputs": [ "outputs": [
@ -445,7 +587,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"written 1522 bytes to ndarray.h\n" "written 1591 bytes to ndarray.h\n"
] ]
} }
], ],
@ -488,10 +630,13 @@
"void ndarray_print(const mp_print_t *, mp_obj_t , mp_print_kind_t );\n", "void ndarray_print(const mp_print_t *, mp_obj_t , mp_print_kind_t );\n",
"void ndarray_assign_elements(mp_obj_array_t *, mp_obj_t , uint8_t , size_t *);\n", "void ndarray_assign_elements(mp_obj_array_t *, mp_obj_t , uint8_t , size_t *);\n",
"ndarray_obj_t *create_new_ndarray(size_t , size_t , uint8_t );\n", "ndarray_obj_t *create_new_ndarray(size_t , size_t , uint8_t );\n",
"\n",
"mp_obj_t ndarray_copy(mp_obj_t );\n", "mp_obj_t ndarray_copy(mp_obj_t );\n",
"mp_obj_t ndarray_make_new(const mp_obj_type_t *, size_t , size_t , const mp_obj_t *);\n", "mp_obj_t ndarray_make_new(const mp_obj_type_t *, size_t , size_t , const mp_obj_t *);\n",
"mp_obj_t ndarray_subscr(mp_obj_t , mp_obj_t , mp_obj_t );\n", "mp_obj_t ndarray_subscr(mp_obj_t , mp_obj_t , mp_obj_t );\n",
"mp_obj_t ndarray_getiter(mp_obj_t , mp_obj_iter_buf_t *);\n", "mp_obj_t ndarray_getiter(mp_obj_t , mp_obj_iter_buf_t *);\n",
"mp_obj_t ndarray_binary_op(mp_binary_op_t , mp_obj_t , mp_obj_t );\n",
"\n",
"\n", "\n",
"mp_obj_t ndarray_shape(mp_obj_t );\n", "mp_obj_t ndarray_shape(mp_obj_t );\n",
"mp_obj_t ndarray_size(mp_obj_t , mp_obj_t );\n", "mp_obj_t ndarray_size(mp_obj_t , mp_obj_t );\n",
@ -509,11 +654,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 163, "execution_count": 266,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2019-09-06T18:22:16.257859Z", "end_time": "2019-09-11T17:48:40.648892Z",
"start_time": "2019-09-06T18:22:16.232985Z" "start_time": "2019-09-11T17:48:40.628565Z"
} }
}, },
"outputs": [ "outputs": [
@ -521,7 +666,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"written 18547 bytes to ndarray.c\n" "written 25067 bytes to ndarray.c\n"
] ]
} }
], ],
@ -614,15 +759,15 @@
" }\n", " }\n",
" // TODO: print typecode\n", " // TODO: print typecode\n",
" if(self->data->typecode == NDARRAY_UINT8) {\n", " if(self->data->typecode == NDARRAY_UINT8) {\n",
" printf(\", dtype='uint8')\\n\");\n", " printf(\", dtype='uint8')\");\n",
" } else if(self->data->typecode == NDARRAY_INT8) {\n", " } else if(self->data->typecode == NDARRAY_INT8) {\n",
" printf(\", dtype='int8')\\n\");\n", " printf(\", dtype='int8')\");\n",
" } if(self->data->typecode == NDARRAY_UINT16) {\n", " } if(self->data->typecode == NDARRAY_UINT16) {\n",
" printf(\", dtype='uint16')\\n\");\n", " printf(\", dtype='uint16')\");\n",
" } if(self->data->typecode == NDARRAY_INT16) {\n", " } if(self->data->typecode == NDARRAY_INT16) {\n",
" printf(\", dtype='int16')\\n\");\n", " printf(\", dtype='int16')\");\n",
" } if(self->data->typecode == NDARRAY_FLOAT) {\n", " } if(self->data->typecode == NDARRAY_FLOAT) {\n",
" printf(\", dtype='float')\\n\");\n", " printf(\", dtype='float')\");\n",
" } \n", " } \n",
"}\n", "}\n",
"\n", "\n",
@ -866,130 +1011,193 @@
"}\n", "}\n",
"\n", "\n",
"// Binary operations\n", "// Binary operations\n",
"/*\n", "STATIC uint8_t upcasting(uint8_t type_left, uint8_t type_right) {\n",
"STATIC uint8_t upcasting(ndarray_obj_t lhs, ndarray_obj_t rhs) {\n",
" // returns the upcast typecode\n", " // returns the upcast typecode\n",
" // what we have to establish is, whether either of sides has a type code that is \n", " // Now we have to collect 25 cases. Perhaps there is a more elegant solution for this \n",
" // 'larger' than the other side\n", " if(type_left == type_right) { \n",
" uint8_t typecode_l, typecode_r;\n", " // 5 cases\n",
" switch(lhs->data->typecode) {\n", " return type_left;\n",
" case 'b':\n", " } else if((type_left == NDARRAY_FLOAT) || (type_right == NDARRAY_FLOAT)) { \n",
" typecode_l = (0x01 << 0);\n", " // 8 cases ('f' AND 'f' has already been accounted for) \n",
" case 'B':\n", " return NDARRAY_FLOAT;\n",
" typecode_l = (0x01 << 1);\n", " } else if(((type_left == NDARRAY_UINT8) && (type_right == NDARRAY_INT8)) || \n",
" case 'i':\n", " ((type_left == NDARRAY_INT8) && (type_right == NDARRAY_UINT8)) || \n",
" typecode_l = (0x01 << 2);\n", " ((type_left == NDARRAY_UINT8) && (type_right == NDARRAY_INT16)) || \n",
" case 'I':\n", " ((type_left == NDARRAY_INT16) && (type_right == NDARRAY_UINT8)) || \n",
" typecode_l = (0x01 << 3);\n", " ((type_left == NDARRAY_UINT8) && (type_right == NDARRAY_UINT16)) ||\n",
" case 'f':\n", " ((type_left == NDARRAY_UINT16) && (type_right == NDARRAY_UINT8)) || \n",
" typecode_l = (0x01 << 4);\n", " ((type_left == NDARRAY_INT8) && (type_right == NDARRAY_UINT16)) ||\n",
" }\n", " ((type_left == NDARRAY_UINT16) && (type_right == NDARRAY_INT8)) ) {\n",
" switch(rhs->data->typecode) {\n", " // 8 cases\n",
" case 'b':\n", " return NDARRAY_UINT16;\n",
" typecode_r = (0x01 << 0);\n", " } else if ( ((type_left == NDARRAY_INT8) && (type_right == NDARRAY_INT16)) ||\n",
" case 'B':\n", " ((type_left == NDARRAY_INT16) && (type_right == NDARRAY_INT8)) ) {\n",
" typecode_r = (0x01 << 1);\n", " // 2 cases\n",
" case 'i':\n", " return NDARRAY_INT16;\n",
" typecode_r = (0x01 << 2);\n", " } else if ( ((type_left == NDARRAY_INT16) && (type_right == NDARRAY_UINT16)) ||\n",
" case 'I':\n", " ((type_left == NDARRAY_UINT16) && (type_right == NDARRAY_INT16)) ) {\n",
" typecode_r = (0x01 << 3);\n", " // 2 cases\n",
" case 'f':\n", " return NDARRAY_FLOAT;\n",
" typecode_r = (0x01 << 4);\n",
" }\n",
" // Now we have to collect 25 cases \n",
" if((typecode_l | typecode_r) == (0x01 << 0)) { // 2 cases\n",
" return 'b';\n",
" } else if((typecode_l | typecode_r) == (0x01 << 1)) { // 2 cases\n",
" return 'B';\n",
" } else if((typecode_l | typecode_r) == (0x01 << 2)) { // 2 casaes\n",
" return 'i';\n",
" } else if((typecode_l | typecode_r) == (0x01 << 3)) { // 2 cases\n",
" return 'I';\n",
" } else if((typecode_l | typecode_r) >= (0x01 << 4)) { // 10 cases\n",
" return 'f';\n",
" } else if((typecode_l | typecode_r) == ((0x01 << 0) | (0x01 << 1)) {\n",
" return 'i';\n",
" }\n", " }\n",
" return NDARRAY_FLOAT; // we are never going to reach this statement, but we have to make the compiler happy\n",
"}\n", "}\n",
"\n", "\n",
"mp_obj_t ulab_ndarray_binary_op_helper(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {\n", "mp_obj_t ndarray_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {\n",
" // TODO: support scalar operations\n", " ndarray_obj_t *ol = MP_OBJ_TO_PTR(lhs);\n",
" uint8_t typecode;\n",
" float value;\n",
" // First, the right hand side is a native micropython object, i.e, an integer, or a float\n",
" if (MP_OBJ_IS_TYPE(rhs, &mp_type_int) || MP_OBJ_IS_TYPE(rhs, &mp_type_float)) {\n", " if (MP_OBJ_IS_TYPE(rhs, &mp_type_int) || MP_OBJ_IS_TYPE(rhs, &mp_type_float)) {\n",
" return MP_OBJ_NULL; // op not supported\n", " // we have to split the two cases here...\n",
" } else if(MP_OBJ_IS_TYPE(rhs, &ulab_ndarray_type)) {\n", " if(MP_OBJ_IS_TYPE(rhs, &mp_type_int)) {\n",
" // At this point, the operands should have the same shape\n", " typecode = upcasting(ol->data->typecode, NDARRAY_INT16);\n",
" ndarray_obj_t *ol = MP_OBJ_TO_PTR(lhs);\n", " } else {\n",
" typecode = upcasting(ol->data->typecode, NDARRAY_FLOAT); \n",
" }\n",
" if(MP_OBJ_IS_TYPE(rhs, &mp_type_int)) {\n",
" value = (float)mp_obj_get_int(rhs);\n",
" } else {\n",
" value = mp_obj_get_float(rhs);\n",
" }\n",
" if((op == MP_BINARY_OP_ADD) || (op == MP_BINARY_OP_MULTIPLY) || \n",
" (op == MP_BINARY_OP_SUBTRACT) || (op == MP_BINARY_OP_TRUE_DIVIDE)) {\n",
" ndarray_obj_t *out = create_new_ndarray(ol->m, ol->n, typecode);\n",
" if(op == MP_BINARY_OP_SUBTRACT) value *= -1.0;\n",
" if(op == MP_BINARY_OP_TRUE_DIVIDE) value = 1.0/value;\n",
" if(typecode == NDARRAY_INT16) {\n",
" int16_t *outdata = (int16_t *)out->data->items;\n",
" if((op == MP_BINARY_OP_ADD) || (op == MP_BINARY_OP_SUBTRACT)) {\n",
" for(size_t i=0; i < ol->data->len; i++) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;\n",
" }\n",
" } else if((op == MP_BINARY_OP_MULTIPLY) || (op == MP_BINARY_OP_TRUE_DIVIDE)) {\n",
" for(size_t i=0; i < ol->data->len; i++) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;\n",
" }\n",
" }\n",
" } else if(typecode == NDARRAY_FLOAT) {\n",
" float *outdata = (float *)out->data->items;\n",
" if((op == MP_BINARY_OP_ADD) || (op == MP_BINARY_OP_SUBTRACT)) {\n",
" for(size_t i=0; i < ol->data->len; i++) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;\n",
" }\n",
" } else if((op == MP_BINARY_OP_MULTIPLY) || (op == MP_BINARY_OP_TRUE_DIVIDE)) {\n",
" for(size_t i=0; i < ol->data->len; i++) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;\n",
" }\n",
" } \n",
" }\n",
" return MP_OBJ_FROM_PTR(out);\n",
" } else {\n",
" return MP_OBJ_NULL; // op not supported\n",
" }\n",
" } else if(MP_OBJ_IS_TYPE(rhs, &ulab_ndarray_type)) { // next, the ndarray stuff\n",
" ndarray_obj_t *or = MP_OBJ_TO_PTR(rhs);\n", " ndarray_obj_t *or = MP_OBJ_TO_PTR(rhs);\n",
" ndarray_obj_t *array;\n",
" if((ol->m != or->m) || (ol->n != or->n)) {\n", " if((ol->m != or->m) || (ol->n != or->n)) {\n",
" mp_raise_ValueError(\"operands could not be broadcast together\");\n", " mp_raise_ValueError(\"operands could not be broadcast together\");\n",
" }\n", " }\n",
" // do not convert types, if they are identical\n", " // At this point, the operands should have the same shape\n",
" // do not convert either, if the left hand side is a float\n", " typecode = upcasting(or->data->typecode, ol->data->typecode);\n",
" if((ol->data->typecode == or->data->typecode) || ol->data->typecode == NDARRAY_FLOAT) {\n", " if(op == MP_BINARY_OP_EQUAL) {\n",
" array = ulab_ndarray_copy(ol);\n", " // Two arrays are equal, if their shape, typecode, and elements are equal\n",
" } else {\n", " if((ol->m != or->m) || (ol->n != or->n) || (ol->data->typecode != or->data->typecode)) {\n",
" // the types are not equal, we have to do some conversion here\n",
" if(or->data->typecode == NDARRAY_FLOAT) {\n",
" array = ulab_ndarray_copy(ol);\n",
" } else if((ol->data->typecode == NDARRAY_INT16) || (or->data->typecode == NDARRAY_INT16)) {\n",
" array = create_new_ndarray(ol->m, ol->n, NDARRAY_INT16);\n",
" } else if((ol->data->typecode == NDARRAY_UINT16) || (or->data->typecode == NDARRAY_UINT16)) {\n",
" array = create_new_ndarray(ol->m, ol->n, NDARRAY_INT16);\n",
" }\n",
" }\n",
" switch(op) {\n",
" case MP_BINARY_OP_ADD:\n",
" for(size_t i=0; i < ol->data->len; i++) {\n",
"\n",
" }\n",
" return MP_OBJ_FROM_PTR(array);\n",
" break;\n",
" default:\n",
" break;\n",
" }\n",
"\n",
" }\n",
"}\n",
"\n",
"STATIC mp_obj_t ulab_ndarray_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {\n",
" ndarray_obj_t *ol = MP_OBJ_TO_PTR(lhs);\n",
" ndarray_obj_t *or = MP_OBJ_TO_PTR(rhs);\n",
"\n",
" // for in-place operations, we won't need this!!!\n",
" ndarray_obj_t *array = ulab_ndarray_copy(ol); \n",
" switch (op) {\n",
" case MP_BINARY_OP_EQUAL:\n",
" if(!MP_OBJ_IS_TYPE(rhs, &ulab_ndarray_type)) {\n",
" return mp_const_false;\n", " return mp_const_false;\n",
" } else {\n", " } else {\n",
" // Two arrays are equal, if their shape, typecode, and elements are equal\n", " size_t i = ol->bytes;\n",
" if((ol->m != or->m) || (ol->n != or->n) || (ol->data->typecode != or->data->typecode)) {\n", " uint8_t *l = (uint8_t *)ol->data->items;\n",
" return mp_const_false;\n", " uint8_t *r = (uint8_t *)or->data->items;\n",
" } else {\n", " while(i) { // At this point, we can simply compare the bytes, the type is irrelevant\n",
" size_t i = ol->bytes;\n", " if(*l++ != *r++) {\n",
" uint8_t *l = (uint8_t *)ol->data->items;\n", " return mp_const_false;\n",
" uint8_t *r = (uint8_t *)or->data->items;\n", " }\n",
" while(i) { // At this point, we can simply compare the bytes, the types is irrelevant\n", " i--;\n",
" if(*l++ != *r++) {\n", " }\n",
" return mp_const_false;\n", " return mp_const_true;\n",
" }\n", " }\n",
" i--;\n", " } else if((op == MP_BINARY_OP_ADD) || (op == MP_BINARY_OP_SUBTRACT) || \n",
" (op == MP_BINARY_OP_TRUE_DIVIDE) || (op == MP_BINARY_OP_MULTIPLY)) {\n",
" // for in-place operations, we won't need this!!!\n",
" typecode = upcasting(or->data->typecode, ol->data->typecode);\n",
" ndarray_obj_t *out = create_new_ndarray(ol->m, ol->n, typecode);\n",
" if(typecode == NDARRAY_UINT8) {\n",
" uint8_t *outdata = (uint8_t *)out->data->items;\n",
" for(size_t i=0; i < ol->data->len; i++) {\n",
" value = ndarray_get_float_value(or->data->items, or->data->typecode, i);\n",
" if(op == MP_BINARY_OP_ADD) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;\n",
" } else if(op == MP_BINARY_OP_SUBTRACT) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value; \n",
" } else if(op == MP_BINARY_OP_MULTIPLY) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;\n",
" } else if(op == MP_BINARY_OP_TRUE_DIVIDE) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;\n",
" }\n",
" }\n",
" } else if(typecode == NDARRAY_INT8) {\n",
" int8_t *outdata = (int8_t *)out->data->items;\n",
" for(size_t i=0; i < ol->data->len; i++) {\n",
" value = ndarray_get_float_value(or->data->items, or->data->typecode, i);\n",
" if(op == MP_BINARY_OP_ADD) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;\n",
" } else if(op == MP_BINARY_OP_SUBTRACT) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value; \n",
" } else if(op == MP_BINARY_OP_MULTIPLY) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;\n",
" } else if(op == MP_BINARY_OP_TRUE_DIVIDE) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;\n",
" }\n",
" } \n",
" } else if(typecode == NDARRAY_UINT16) {\n",
" uint16_t *outdata = (uint16_t *)out->data->items;\n",
" for(size_t i=0; i < ol->data->len; i++) {\n",
" value = ndarray_get_float_value(or->data->items, or->data->typecode, i);\n",
" if(op == MP_BINARY_OP_ADD) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;\n",
" } else if(op == MP_BINARY_OP_SUBTRACT) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value; \n",
" } else if(op == MP_BINARY_OP_MULTIPLY) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;\n",
" } else if(op == MP_BINARY_OP_TRUE_DIVIDE) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;\n",
" }\n",
" }\n",
" } else if(typecode == NDARRAY_INT16) {\n",
" int16_t *outdata = (int16_t *)out->data->items;\n",
" for(size_t i=0; i < ol->data->len; i++) {\n",
" value = ndarray_get_float_value(or->data->items, or->data->typecode, i);\n",
" if(op == MP_BINARY_OP_ADD) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;\n",
" } else if(op == MP_BINARY_OP_SUBTRACT) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value; \n",
" } else if(op == MP_BINARY_OP_MULTIPLY) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;\n",
" } else if(op == MP_BINARY_OP_TRUE_DIVIDE) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;\n",
" }\n",
" } \n",
" } else if(typecode == NDARRAY_FLOAT) {\n",
" float *outdata = (float *)out->data->items;\n",
" for(size_t i=0; i < ol->data->len; i++) {\n",
" value = ndarray_get_float_value(or->data->items, or->data->typecode, i);\n",
" if(op == MP_BINARY_OP_ADD) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;\n",
" } else if(op == MP_BINARY_OP_SUBTRACT) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value; \n",
" } else if(op == MP_BINARY_OP_MULTIPLY) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;\n",
" } else if(op == MP_BINARY_OP_TRUE_DIVIDE) {\n",
" outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;\n",
" }\n", " }\n",
" return mp_const_true;\n",
" }\n", " }\n",
" }\n", " }\n",
" break;\n", " return MP_OBJ_FROM_PTR(out);\n",
" case MP_BINARY_OP_ADD:\n", " } else {\n",
" case MP_BINARY_OP_MULTIPLY: \n", " return MP_OBJ_NULL; // op not supported \n",
" return MP_OBJ_FROM_PTR(array);\n", " }\n",
" break;\n", " } else {\n",
"\n", " mp_raise_TypeError(\"wrong operand type on the right hand side\");\n",
" default:\n",
" return MP_OBJ_NULL; // op not supported\n",
" }\n", " }\n",
" }\n", "}"
"*/"
] ]
}, },
{ {
@ -1010,11 +1218,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 786, "execution_count": 173,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2019-08-11T17:47:20.913255Z", "end_time": "2019-09-09T15:14:14.211206Z",
"start_time": "2019-08-11T17:47:20.853670Z" "start_time": "2019-09-09T15:14:14.090399Z"
} }
}, },
"outputs": [ "outputs": [
@ -1022,7 +1230,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"written 492 bytes to linalg.h\n" "written 487 bytes to linalg.h\n"
] ]
} }
], ],
@ -1040,7 +1248,7 @@
"mp_obj_t linalg_transpose(mp_obj_t );\n", "mp_obj_t linalg_transpose(mp_obj_t );\n",
"mp_obj_t linalg_reshape(mp_obj_t , mp_obj_t );\n", "mp_obj_t linalg_reshape(mp_obj_t , mp_obj_t );\n",
"mp_obj_t linalg_inv(mp_obj_t );\n", "mp_obj_t linalg_inv(mp_obj_t );\n",
"mp_obj_t linalg_multiply(mp_obj_t , mp_obj_t );\n", "mp_obj_t linalg_dot(mp_obj_t , mp_obj_t );\n",
"\n", "\n",
"#endif" "#endif"
] ]
@ -1197,23 +1405,27 @@
" return MP_OBJ_FROM_PTR(inverted);\n", " return MP_OBJ_FROM_PTR(inverted);\n",
"}\n", "}\n",
"\n", "\n",
"mp_obj_t linalg_multiply(mp_obj_t _m1, mp_obj_t _m2) {\n", "mp_obj_t linalg_dot(mp_obj_t _m1, mp_obj_t _m2) {\n",
" // TODO: should the results be upcast?\n",
" ndarray_obj_t *m1 = MP_OBJ_TO_PTR(_m1);\n", " ndarray_obj_t *m1 = MP_OBJ_TO_PTR(_m1);\n",
" ndarray_obj_t *m2 = MP_OBJ_TO_PTR(_m2); \n", " ndarray_obj_t *m2 = MP_OBJ_TO_PTR(_m2); \n",
" if(m1->n != m2->m) {\n", " if(m1->n != m2->m) {\n",
" mp_raise_ValueError(\"matrix dimensions do not match\");\n", " mp_raise_ValueError(\"matrix dimensions do not match\");\n",
" }\n", " }\n",
" ndarray_obj_t *out = create_new_ndarray(m1->m, m2->n, NDARRAY_FLOAT);\n", " ndarray_obj_t *out = create_new_ndarray(m1->m, m2->n, NDARRAY_FLOAT);\n",
" float *data = (float *)out->data;\n",
" float sum;\n", " float sum;\n",
" for(size_t i=0; i < m1->n; i++) {\n", " for(size_t i=0; i < m1->n; i++) {\n",
" for(size_t j=0; j < m2->m; j++) {\n", " for(size_t j=0; j < m2->m; j++) {\n",
" sum = 0.0;\n",
" for(size_t k=0; k < m1->m; k++) {\n", " for(size_t k=0; k < m1->m; k++) {\n",
" // (j, k) * (k, j)\n",
" sum += m1->data->items[i*m1->n+k]*m2->data->items[k*m2->n+j];\n", " sum += m1->data->items[i*m1->n+k]*m2->data->items[k*m2->n+j];\n",
" }\n", " }\n",
" mp_set_out->data->items\n", " data[i*m1->m+j] = sum;\n",
" }\n", " }\n",
" }\n", " }\n",
" return MP_OBJ_FROM_PTR(out); \n", " return MP_OBJ_FROM_PTR(out);\n",
"}" "}"
] ]
}, },
@ -2270,11 +2482,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 825, "execution_count": 251,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2019-08-13T16:06:32.160834Z", "end_time": "2019-09-11T17:35:31.541539Z",
"start_time": "2019-08-13T16:06:32.154580Z" "start_time": "2019-09-11T17:35:31.535557Z"
} }
}, },
"outputs": [ "outputs": [
@ -2282,7 +2494,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"written 7323 bytes to ulab.c\n" "written 7311 bytes to ulab.c\n"
] ]
} }
], ],
@ -2377,7 +2589,7 @@
" .subscr = ndarray_subscr,\n", " .subscr = ndarray_subscr,\n",
" .getiter = ndarray_getiter,\n", " .getiter = ndarray_getiter,\n",
"// .unary_op = ndarray_unary_op,\n", "// .unary_op = ndarray_unary_op,\n",
"// .binary_op = ndarray_binary_op,\n", " .binary_op = ndarray_binary_op,\n",
" .locals_dict = (mp_obj_dict_t*)&ulab_ndarray_locals_dict,\n", " .locals_dict = (mp_obj_dict_t*)&ulab_ndarray_locals_dict,\n",
"};\n", "};\n",
"\n", "\n",
@ -2487,11 +2699,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 164, "execution_count": 267,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2019-09-06T18:22:29.908059Z", "end_time": "2019-09-11T17:48:47.464054Z",
"start_time": "2019-09-06T18:22:20.573945Z" "start_time": "2019-09-11T17:48:46.334799Z"
}, },
"scrolled": false "scrolled": false
}, },
@ -2513,7 +2725,7 @@
" text\t data\t bss\t dec\t hex\tfilename\n", " text\t data\t bss\t dec\t hex\tfilename\n",
" 2085\t 6862\t 0\t 8947\t 22f3\tbuild/build/frozen_mpy.o\n", " 2085\t 6862\t 0\t 8947\t 22f3\tbuild/build/frozen_mpy.o\n",
" 2\t 0\t 0\t 2\t 2\tbuild/build/frozen.o\n", " 2\t 0\t 0\t 2\t 2\tbuild/build/frozen.o\n",
" 455198\t 56992\t 2104\t 514294\t 7d8f6\tmicropython\n" " 457502\t 57088\t 2104\t 516694\t 7e256\tmicropython\n"
] ]
} }
], ],
@ -2523,11 +2735,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 66, "execution_count": 270,
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2019-09-05T19:54:32.829236Z", "end_time": "2019-09-11T17:49:18.773261Z",
"start_time": "2019-09-05T19:54:32.806484Z" "start_time": "2019-09-11T17:49:18.750945Z"
} }
}, },
"outputs": [ "outputs": [
@ -2535,9 +2747,10 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"4.0\n", "ndarray([2.0, 4.0, 6.0, 8.0], dtype='float')\n",
"ndarray([2.0, 3.0], dtype='float')\n", "ndarray([5.400000095367432, 6.400000095367432, 7.400000095367432, 8.399999618530273], dtype='float')\n",
"\n", "ndarray([3.0, 6.0, 9.0, 12.0], dtype='float')\n",
"ndarray([0.3333333432674408, 0.6666666865348816, 1.0, 1.333333373069763], dtype='float')\n",
"\n", "\n",
"\n" "\n"
] ]
@ -2549,44 +2762,11 @@
"import ulab\n", "import ulab\n",
"\n", "\n",
"a = ulab.ndarray([1, 2, 3, 4], dtype='float')\n", "a = ulab.ndarray([1, 2, 3, 4], dtype='float')\n",
"print(ulab.max(a))\n", "b = ulab.ndarray([1, 2, 3, 4], dtype='float')\n",
"print(a[1:3])" "print(a+b)\n",
] "print(a+4.4)\n",
}, "print(a*3.0)\n",
{ "print(a/3.0)"
"cell_type": "code",
"execution_count": 166,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-06T18:22:47.413074Z",
"start_time": "2019-09-06T18:22:47.395440Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ndarray([1, 2, 3, 4], dtype='uint8')\n",
"\n",
"ndarray([5, 6, 7, 8], dtype='uint8')\n",
"\n",
"ndarray([9, 10, 11, 12], dtype='uint8')\n",
"\n",
"\n",
"\n"
]
}
],
"source": [
"%%micropython\n",
"\n",
"import ulab\n",
"\n",
"a = ulab.ndarray([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype='uint8')\n",
"\n",
"for i in a:\n",
" print(i)"
] ]
}, },
{ {