152 lines
6.4 KiB
C
152 lines
6.4 KiB
C
|
|
/*
|
|
* This file is part of the micropython-ulab project,
|
|
*
|
|
* https://github.com/v923z/micropython-ulab
|
|
*
|
|
* The MIT License (MIT)
|
|
*
|
|
* Copyright (c) 2020-2025 Zoltán Vörös
|
|
*/
|
|
|
|
/* Include guard. Identifiers that begin with an underscore followed by an
 * uppercase letter are reserved for the C implementation, so a
 * project-prefixed name is used instead. */
#ifndef ULAB_COMPARE_H
#define ULAB_COMPARE_H
|
|
|
|
#include <stddef.h>

#include "../ulab.h"
#include "../ndarray.h"
|
|
|
|
/*
 * Operation selectors for the comparison helpers of this module.
 * Only COMPARE_MINIMUM and COMPARE_MAXIMUM are acted upon by
 * RUN_COMPARE_LOOP in this header; the other values are presumably
 * consumed by the implementation file (not visible here).
 */
enum COMPARE_FUNCTION_TYPE {
    COMPARE_EQUAL = 0,
    COMPARE_NOT_EQUAL = 1,
    COMPARE_MINIMUM = 2,  /* RUN_COMPARE_LOOP keeps the operand that compares `<` */
    COMPARE_MAXIMUM = 3,  /* RUN_COMPARE_LOOP keeps the operand that compares `>` */
    COMPARE_CLIP = 4
};
|
|
|
|
/*
 * Python-callable function objects of the comparison module; their
 * definitions presumably live in the matching .c file (not visible here).
 * Per MicroPython's py/obj.h convention, the _N suffix states the number of
 * positional arguments and _KW marks a function that accepts keywords.
 */
MP_DECLARE_CONST_FUN_OBJ_KW(compare_bincount_obj);
MP_DECLARE_CONST_FUN_OBJ_3(compare_clip_obj);
MP_DECLARE_CONST_FUN_OBJ_2(compare_equal_obj);
/* isfinite(x) and isinf(x) are unary, hence _1 (the _N variants all declare
 * the same extern type, so this only corrects the documented arity). */
MP_DECLARE_CONST_FUN_OBJ_1(compare_isfinite_obj);
MP_DECLARE_CONST_FUN_OBJ_1(compare_isinf_obj);
MP_DECLARE_CONST_FUN_OBJ_2(compare_minimum_obj);
MP_DECLARE_CONST_FUN_OBJ_2(compare_maximum_obj);
MP_DECLARE_CONST_FUN_OBJ_1(compare_nonzero_obj);
MP_DECLARE_CONST_FUN_OBJ_2(compare_not_equal_obj);
MP_DECLARE_CONST_FUN_OBJ_3(compare_where_obj);
|
|
|
|
#if ULAB_MAX_DIMS == 1
/*
 * COMPARE_LOOP (ULAB_MAX_DIMS == 1): element-wise min/max kernel.
 *
 * For every position of (results)->shape it stores (type_out)left when
 * `left OPERATOR right` holds and (type_out)right otherwise, then returns
 * MP_OBJ_FROM_PTR(results) from the enclosing function.
 *
 *   results              output ndarray; `array` only ever advances by its
 *                        last-axis stride, so it must be dense (as produced by
 *                        ndarray_new_dense_ndarray in RUN_COMPARE_LOOP)
 *   array                byte cursor into the output data (advanced in place)
 *   type_out/left/right  C element types of output and operands
 *   larray, rarray       operand cursors (advanced in place); presumably
 *                        uint8_t * — TODO confirm at call sites
 *   lstrides, rstrides   per-axis byte strides of the operands
 *   OPERATOR             `<` selects the minimum, `>` the maximum
 *
 * The body is a do { } while(0) so its locals are scoped to the expansion
 * and the macro behaves as a single statement.
 */
#define COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, OPERATOR) do {\
    size_t l = 0;\
    do {\
        type_left left_ = *((type_left *)(larray));\
        type_right right_ = *((type_right *)(rarray));\
        *((type_out *)(array)) = (left_ OPERATOR right_) ? (type_out)left_ : (type_out)right_;\
        (array) += (results)->strides[ULAB_MAX_DIMS - 1];\
        (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
        (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
        l++;\
    } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
    return MP_OBJ_FROM_PTR(results);\
} while(0)
#endif // ULAB_MAX_DIMS == 1
|
|
|
|
#if ULAB_MAX_DIMS == 2
/*
 * COMPARE_LOOP (ULAB_MAX_DIMS == 2): element-wise min/max kernel.
 *
 * For every position of (results)->shape it stores (type_out)left when
 * `left OPERATOR right` holds and (type_out)right otherwise, then returns
 * MP_OBJ_FROM_PTR(results) from the enclosing function.
 *
 *   results              output ndarray; `array` only ever advances by its
 *                        last-axis stride and is never rewound, so it must be
 *                        dense (as produced by ndarray_new_dense_ndarray in
 *                        RUN_COMPARE_LOOP)
 *   array                byte cursor into the output data (advanced in place)
 *   type_out/left/right  C element types of output and operands
 *   larray, rarray       operand cursors (advanced in place); presumably
 *                        uint8_t * — TODO confirm at call sites
 *   lstrides, rstrides   per-axis byte strides of the operands; presumably
 *                        signed (negative for reversed views) — TODO confirm
 *   OPERATOR             `<` selects the minimum, `>` the maximum
 *
 * Operand cursors are rewound after each inner axis. The shape is cast to
 * ptrdiff_t so that a negative stride yields a signed offset instead of a
 * wrapped unsigned one in the pointer arithmetic. The body is a
 * do { } while(0) so its locals are scoped to the expansion.
 */
#define COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, OPERATOR) do {\
    size_t k = 0;\
    do {\
        size_t l = 0;\
        do {\
            type_left left_ = *((type_left *)(larray));\
            type_right right_ = *((type_right *)(rarray));\
            *((type_out *)(array)) = (left_ OPERATOR right_) ? (type_out)left_ : (type_out)right_;\
            (array) += (results)->strides[ULAB_MAX_DIMS - 1];\
            (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
            (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
            l++;\
        } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
        /* rewind the last axis, then step once along the next one */\
        (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 1];\
        (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
        (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 1];\
        (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
        k++;\
    } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
    return MP_OBJ_FROM_PTR(results);\
} while(0)
#endif // ULAB_MAX_DIMS == 2
|
|
|
|
#if ULAB_MAX_DIMS == 3
/*
 * COMPARE_LOOP (ULAB_MAX_DIMS == 3): element-wise min/max kernel.
 *
 * For every position of (results)->shape it stores (type_out)left when
 * `left OPERATOR right` holds and (type_out)right otherwise, then returns
 * MP_OBJ_FROM_PTR(results) from the enclosing function.
 *
 *   results              output ndarray; `array` only ever advances by its
 *                        last-axis stride and is never rewound, so it must be
 *                        dense (as produced by ndarray_new_dense_ndarray in
 *                        RUN_COMPARE_LOOP)
 *   array                byte cursor into the output data (advanced in place)
 *   type_out/left/right  C element types of output and operands
 *   larray, rarray       operand cursors (advanced in place); presumably
 *                        uint8_t * — TODO confirm at call sites
 *   lstrides, rstrides   per-axis byte strides of the operands; presumably
 *                        signed (negative for reversed views) — TODO confirm
 *   OPERATOR             `<` selects the minimum, `>` the maximum
 *
 * Operand cursors are rewound after each inner axis. The shape is cast to
 * ptrdiff_t so that a negative stride yields a signed offset instead of a
 * wrapped unsigned one in the pointer arithmetic. The body is a
 * do { } while(0) so its locals are scoped to the expansion.
 */
#define COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, OPERATOR) do {\
    size_t j = 0;\
    do {\
        size_t k = 0;\
        do {\
            size_t l = 0;\
            do {\
                type_left left_ = *((type_left *)(larray));\
                type_right right_ = *((type_right *)(rarray));\
                *((type_out *)(array)) = (left_ OPERATOR right_) ? (type_out)left_ : (type_out)right_;\
                (array) += (results)->strides[ULAB_MAX_DIMS - 1];\
                (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
                (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
                l++;\
            } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
            /* rewind the last axis, then step once along axis -2 */\
            (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 1];\
            (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
            (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 1];\
            (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
            k++;\
        } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
        /* rewind axis -2, then step once along axis -3 */\
        (larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 2];\
        (larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
        (rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 2];\
        (rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
        j++;\
    } while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
    return MP_OBJ_FROM_PTR(results);\
} while(0)
#endif // ULAB_MAX_DIMS == 3
|
|
|
|
#if ULAB_MAX_DIMS == 4
/*
 * COMPARE_LOOP (ULAB_MAX_DIMS == 4): element-wise min/max kernel.
 *
 * For every position of (results)->shape it stores (type_out)left when
 * `left OPERATOR right` holds and (type_out)right otherwise, then returns
 * MP_OBJ_FROM_PTR(results) from the enclosing function.
 *
 *   results              output ndarray; `array` only ever advances by its
 *                        last-axis stride and is never rewound, so it must be
 *                        dense (as produced by ndarray_new_dense_ndarray in
 *                        RUN_COMPARE_LOOP)
 *   array                byte cursor into the output data (advanced in place)
 *   type_out/left/right  C element types of output and operands
 *   larray, rarray       operand cursors (advanced in place); presumably
 *                        uint8_t * — TODO confirm at call sites
 *   lstrides, rstrides   per-axis byte strides of the operands; presumably
 *                        signed (negative for reversed views) — TODO confirm
 *   OPERATOR             `<` selects the minimum, `>` the maximum
 *
 * Operand cursors are rewound after each inner axis. The shape is cast to
 * ptrdiff_t so that a negative stride yields a signed offset instead of a
 * wrapped unsigned one in the pointer arithmetic. The body is a
 * do { } while(0) so its locals are scoped to the expansion.
 */
#define COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, OPERATOR) do {\
    size_t i = 0;\
    do {\
        size_t j = 0;\
        do {\
            size_t k = 0;\
            do {\
                size_t l = 0;\
                do {\
                    type_left left_ = *((type_left *)(larray));\
                    type_right right_ = *((type_right *)(rarray));\
                    *((type_out *)(array)) = (left_ OPERATOR right_) ? (type_out)left_ : (type_out)right_;\
                    (array) += (results)->strides[ULAB_MAX_DIMS - 1];\
                    (larray) += (lstrides)[ULAB_MAX_DIMS - 1];\
                    (rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
                    l++;\
                } while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
                /* rewind the last axis, then step once along axis -2 */\
                (larray) -= (lstrides)[ULAB_MAX_DIMS - 1] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 1];\
                (larray) += (lstrides)[ULAB_MAX_DIMS - 2];\
                (rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 1];\
                (rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
                k++;\
            } while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
            /* rewind axis -2, then step once along axis -3 */\
            (larray) -= (lstrides)[ULAB_MAX_DIMS - 2] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 2];\
            (larray) += (lstrides)[ULAB_MAX_DIMS - 3];\
            (rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 2];\
            (rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
            j++;\
        } while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
        /* rewind axis -3, then step once along axis -4 */\
        (larray) -= (lstrides)[ULAB_MAX_DIMS - 3] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 3];\
        (larray) += (lstrides)[ULAB_MAX_DIMS - 4];\
        (rarray) -= (rstrides)[ULAB_MAX_DIMS - 3] * (ptrdiff_t)(results)->shape[ULAB_MAX_DIMS - 3];\
        (rarray) += (rstrides)[ULAB_MAX_DIMS - 4];\
        i++;\
    } while(i < (results)->shape[ULAB_MAX_DIMS - 4]);\
    return MP_OBJ_FROM_PTR(results);\
} while(0)
#endif // ULAB_MAX_DIMS == 4
|
|
|
|
/*
 * RUN_COMPARE_LOOP: element-wise minimum/maximum of two operands.
 *
 * When op is COMPARE_MINIMUM or COMPARE_MAXIMUM, allocates a dense result via
 * ndarray_new_dense_ndarray(ndim, shape, dtype), fills it with COMPARE_LOOP
 * (`<` for the minimum, `>` for the maximum) and returns it from the
 * enclosing function; COMPARE_LOOP performs that return.
 *
 * For any other op nothing is allocated and control falls through to the
 * statement after the macro, so the caller must handle that case.
 *
 *   dtype                      typecode of the result array
 *   type_out/left/right        C element types of output and operands
 *   larray, lstrides,
 *   rarray, rstrides           operand cursors and their per-axis strides
 *   ndim, shape                dimensions of the result
 *   op                         a COMPARE_FUNCTION_TYPE value
 */
#define RUN_COMPARE_LOOP(dtype, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, ndim, shape, op) do {\
    if(((op) == COMPARE_MINIMUM) || ((op) == COMPARE_MAXIMUM)) {\
        /* allocate only once it is certain that a loop will fill the result */\
        ndarray_obj_t *results = ndarray_new_dense_ndarray((ndim), (shape), (dtype));\
        uint8_t *array = (uint8_t *)results->array;\
        if((op) == COMPARE_MINIMUM) {\
            COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, <);\
        } else {\
            COMPARE_LOOP(results, array, type_out, type_left, type_right, larray, lstrides, rarray, rstrides, >);\
        }\
    }\
} while(0)
|
|
|
|
#endif
|