fixed half of the binary operator code (implemented macros)

This commit is contained in:
Zoltán Vörös 2019-09-27 20:23:25 +02:00
parent 093bc39b7d
commit 6a65d33445
5 changed files with 2565 additions and 387 deletions

View file

@ -400,6 +400,7 @@ STATIC uint8_t upcasting(uint8_t type_left, uint8_t type_right) {
}
mp_obj_t ndarray_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
// TODO: swap left and right for scalar + ndarray
ndarray_obj_t *ol = MP_OBJ_TO_PTR(lhs);
uint8_t typecode;
float value;
@ -470,81 +471,82 @@ mp_obj_t ndarray_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
}
} else if((op == MP_BINARY_OP_ADD) || (op == MP_BINARY_OP_SUBTRACT) ||
(op == MP_BINARY_OP_TRUE_DIVIDE) || (op == MP_BINARY_OP_MULTIPLY)) {
// for in-place operations, we won't need this!!!
typecode = upcasting(or->data->typecode, ol->data->typecode);
ndarray_obj_t *out = create_new_ndarray(ol->m, ol->n, typecode);
if(typecode == NDARRAY_UINT8) {
uint8_t *outdata = (uint8_t *)out->data->items;
for(size_t i=0; i < ol->data->len; i++) {
value = ndarray_get_float_value(or->data->items, or->data->typecode, i);
if(op == MP_BINARY_OP_ADD) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
} else if(op == MP_BINARY_OP_SUBTRACT) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value;
} else if(op == MP_BINARY_OP_MULTIPLY) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
} else if(op == MP_BINARY_OP_TRUE_DIVIDE) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;
}
// These are the upcasting rules
// float always becomes float
// operation on identical types preserves type
// uint8 + int8 => int16
// uint8 + int16 => int16
// uint8 + uint16 => uint16
// int8 + int16 => int16
// int8 + uint16 => uint16
// uint16 + int16 => float
// The parameters of RUN_BINARY_LOOP are
// typecode of result, type_out, type_left, type_right, lhs operand, rhs operand, operator
if(ol->data->typecode == NDARRAY_UINT8) {
if(or->data->typecode == NDARRAY_UINT8) {
RUN_BINARY_LOOP(NDARRAY_UINT8, uint8_t, uint8_t, uint8_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_INT8) {
RUN_BINARY_LOOP(NDARRAY_INT16, int16_t, uint8_t, int8_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_UINT16) {
RUN_BINARY_LOOP(NDARRAY_UINT16, uint16_t, uint8_t, uint16_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_INT16) {
RUN_BINARY_LOOP(NDARRAY_INT16, int16_t, uint8_t, int16_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_FLOAT) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, uint8_t, float, ol, or, op);
}
} else if(typecode == NDARRAY_INT8) {
int8_t *outdata = (int8_t *)out->data->items;
for(size_t i=0; i < ol->data->len; i++) {
value = ndarray_get_float_value(or->data->items, or->data->typecode, i);
if(op == MP_BINARY_OP_ADD) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
} else if(op == MP_BINARY_OP_SUBTRACT) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value;
} else if(op == MP_BINARY_OP_MULTIPLY) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
} else if(op == MP_BINARY_OP_TRUE_DIVIDE) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;
}
}
} else if(typecode == NDARRAY_UINT16) {
uint16_t *outdata = (uint16_t *)out->data->items;
for(size_t i=0; i < ol->data->len; i++) {
value = ndarray_get_float_value(or->data->items, or->data->typecode, i);
if(op == MP_BINARY_OP_ADD) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
} else if(op == MP_BINARY_OP_SUBTRACT) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value;
} else if(op == MP_BINARY_OP_MULTIPLY) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
} else if(op == MP_BINARY_OP_TRUE_DIVIDE) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;
}
} else if(ol->data->typecode == NDARRAY_INT8) {
if(or->data->typecode == NDARRAY_UINT8) {
RUN_BINARY_LOOP(NDARRAY_INT16, int16_t, int8_t, uint8_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_INT8) {
RUN_BINARY_LOOP(NDARRAY_INT8, int8_t, int8_t, int8_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_UINT16) {
RUN_BINARY_LOOP(NDARRAY_INT16, int16_t, int8_t, uint16_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_INT16) {
RUN_BINARY_LOOP(NDARRAY_INT16, int16_t, int8_t, int16_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_FLOAT) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, int8_t, float, ol, or, op);
}
} else if(ol->data->typecode == NDARRAY_UINT16) {
if(or->data->typecode == NDARRAY_UINT8) {
RUN_BINARY_LOOP(NDARRAY_UINT16, uint16_t, uint16_t, uint8_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_INT8) {
RUN_BINARY_LOOP(NDARRAY_UINT16, uint16_t, uint16_t, int8_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_UINT16) {
RUN_BINARY_LOOP(NDARRAY_UINT16, uint16_t, uint16_t, uint16_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_INT16) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, uint16_t, int16_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_FLOAT) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, uint8_t, float, ol, or, op);
}
} else if(typecode == NDARRAY_INT16) {
int16_t *outdata = (int16_t *)out->data->items;
for(size_t i=0; i < ol->data->len; i++) {
value = ndarray_get_float_value(or->data->items, or->data->typecode, i);
if(op == MP_BINARY_OP_ADD) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
} else if(op == MP_BINARY_OP_SUBTRACT) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value;
} else if(op == MP_BINARY_OP_MULTIPLY) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
} else if(op == MP_BINARY_OP_TRUE_DIVIDE) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;
}
}
} else if(typecode == NDARRAY_FLOAT) {
float *outdata = (float *)out->data->items;
for(size_t i=0; i < ol->data->len; i++) {
value = ndarray_get_float_value(or->data->items, or->data->typecode, i);
if(op == MP_BINARY_OP_ADD) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) + value;
} else if(op == MP_BINARY_OP_SUBTRACT) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) - value;
} else if(op == MP_BINARY_OP_MULTIPLY) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) * value;
} else if(op == MP_BINARY_OP_TRUE_DIVIDE) {
outdata[i] = ndarray_get_float_value(ol->data->items, ol->data->typecode, i) / value;
}
} else if(ol->data->typecode == NDARRAY_INT16) {
if(or->data->typecode == NDARRAY_UINT8) {
RUN_BINARY_LOOP(NDARRAY_INT16, int16_t, int16_t, uint8_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_INT8) {
RUN_BINARY_LOOP(NDARRAY_INT16, int16_t, int16_t, int8_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_UINT16) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, int16_t, uint16_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_INT16) {
RUN_BINARY_LOOP(NDARRAY_INT16, int16_t, int16_t, int16_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_FLOAT) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, uint16_t, float, ol, or, op);
}
} else if(ol->data->typecode == NDARRAY_FLOAT) {
if(or->data->typecode == NDARRAY_UINT8) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, float, uint8_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_INT8) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, float, int8_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_UINT16) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, float, uint16_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_INT16) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, float, int16_t, ol, or, op);
} else if(or->data->typecode == NDARRAY_FLOAT) {
RUN_BINARY_LOOP(NDARRAY_FLOAT, float, float, float, ol, or, op);
}
} else {
mp_raise_TypeError("wrong input type");
}
return MP_OBJ_FROM_PTR(out);
// this instruction should never be reached, but we have to make the compiler happy
return MP_OBJ_NULL;
} else {
return MP_OBJ_NULL; // op not supported
}

View file

@ -20,12 +20,23 @@
const mp_obj_type_t ulab_ndarray_type;
#define RUN_BINARY_LOOP(typecode, type_out, type_left, type_right, ol, or, op) do {\
ndarray_obj_t *out = create_new_ndarray(ol->m, ol->n, typecode);\
type_out *(odata) = (type_out *)out->data->items;\
type_left *left = (type_left *)(ol)->data->items;\
type_right *right = (type_right *)(or)->data->items;\
if((op) == MP_BINARY_OP_ADD) { for(size_t i=0; i < (ol)->data->len; i++) odata[i] = left[i] + right[i];}\
if((op) == MP_BINARY_OP_SUBTRACT) { for(size_t i=0; i < (ol)->data->len; i++) odata[i] = left[i] - right[i];}\
if((op) == MP_BINARY_OP_MULTIPLY) { for(size_t i=0; i < (ol)->data->len; i++) odata[i] = left[i] * right[i];}\
if((op) == MP_BINARY_OP_TRUE_DIVIDE) { for(size_t i=0; i < (ol)->data->len; i++) odata[i] = left[i] / right[i];}\
return MP_OBJ_FROM_PTR(out);\
} while(0)
enum NDARRAY_TYPE {
NDARRAY_UINT8 = 'b',
NDARRAY_INT8 = 'B',
NDARRAY_UINT16 = 'i',
NDARRAY_INT16 = 'I',
NDARRAY_UINT8 = 'B',
NDARRAY_INT8 = 'b',
NDARRAY_UINT16 = 'H',
NDARRAY_INT16 = 'h',
NDARRAY_FLOAT = 'f',
};

File diff suppressed because it is too large Load diff

View file

@ -11,14 +11,14 @@
{% block input scoped%}
{%- if '%%ccode' in cell.source.strip().split('\n')[0] -%}
{{ 'https://github.com/v923z/micropython-usermod/tree/master/snippets' + cell.source.strip().split('\n')[0].split()[-1] }}
{{ 'https://github.com/v923z/micropython-ulab/tree/master/code/' + cell.source.strip().split('\n')[0].split()[-1] }}
.. code:: cpp
{{ '\n'.join( cell.source.strip().split('\n')[1:] ) | indent }}
{%- elif '%%makefile' in cell.source.strip().split('\n')[0] -%}
{{ 'https://github.com/v923z/micropython-usermod/tree/master/snippets/' + cell.source.strip().split('\n')[0].split()[-1].split('/')[1] + '/micropython.mk' }}
{{ 'https://github.com/v923z/micropython-ulab/tree/master/cpde/' + cell.source.strip().split('\n')[0].split()[-1].split('/')[1] + '/micropython.mk' }}
.. code:: make

File diff suppressed because it is too large Load diff