fix the np.delete bug (#653)

* fix the `np.delete` bug

* fix the `np.delete` bug, add unittest code

* increment the version number and update the change log

* update the expected file `delete.py.exp`
This commit is contained in:
yyyz 2023-12-25 17:56:16 +08:00 committed by GitHub
parent e32920645c
commit 7a9370612f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 64 additions and 1 deletions

View file

@ -204,6 +204,11 @@ static mp_obj_t transform_delete(size_t n_args, const mp_obj_t *pos_args, mp_map
mp_raise_TypeError(MP_ERROR_TEXT("wrong index type"));
}
index_len = MP_OBJ_SMALL_INT_VALUE(mp_obj_len_maybe(indices));
if (index_len == 0){
// if the second positional argument is empty
// return the original array
return MP_OBJ_FROM_PTR(ndarray);
}
}
if(index_len > axis_len) {

View file

@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"
#define ULAB_VERSION 6.4.2
#define ULAB_VERSION 6.4.3
#define xstr(s) str(s)
#define str(s) #s

View file

@ -1,3 +1,9 @@
Mon, 25 Dec 2023
version 6.4.3
fix the 'np.delete' error that occurs when passing an empty iterable object as the second positional argument (#653)
Thu, 11 Dec 2023
version 6.4.2

View file

@ -11,7 +11,9 @@ for dtype in dtypes:
a = np.array(range(25), dtype=dtype).reshape((5,5))
print(np.delete(a, [1, 2], axis=0))
print(np.delete(a, [1, 2], axis=1))
print(np.delete(a, [], axis=1))
print(np.delete(a, [1, 5, 10]))
print(np.delete(a, []))
for dtype in dtypes:
a = np.array(range(25), dtype=dtype).reshape((5,5))

View file

@ -6,7 +6,17 @@ array([[0, 3, 4],
[10, 13, 14],
[15, 18, 19],
[20, 23, 24]], dtype=uint8)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=uint8)
array([0, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], dtype=uint8)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=uint8)
array([[0, 1, 2, 3, 4],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int8)
@ -15,7 +25,17 @@ array([[0, 3, 4],
[10, 13, 14],
[15, 18, 19],
[20, 23, 24]], dtype=int8)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int8)
array([0, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], dtype=int8)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int8)
array([[0, 1, 2, 3, 4],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=uint16)
@ -24,7 +44,17 @@ array([[0, 3, 4],
[10, 13, 14],
[15, 18, 19],
[20, 23, 24]], dtype=uint16)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=uint16)
array([0, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], dtype=uint16)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=uint16)
array([[0, 1, 2, 3, 4],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int16)
@ -33,7 +63,17 @@ array([[0, 3, 4],
[10, 13, 14],
[15, 18, 19],
[20, 23, 24]], dtype=int16)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int16)
array([0, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], dtype=int16)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]], dtype=int16)
array([[0.0, 1.0, 2.0, 3.0, 4.0],
[15.0, 16.0, 17.0, 18.0, 19.0],
[20.0, 21.0, 22.0, 23.0, 24.0]], dtype=float64)
@ -42,7 +82,17 @@ array([[0.0, 3.0, 4.0],
[10.0, 13.0, 14.0],
[15.0, 18.0, 19.0],
[20.0, 23.0, 24.0]], dtype=float64)
array([[0.0, 1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0, 9.0],
[10.0, 11.0, 12.0, 13.0, 14.0],
[15.0, 16.0, 17.0, 18.0, 19.0],
[20.0, 21.0, 22.0, 23.0, 24.0]], dtype=float64)
array([0.0, 2.0, 3.0, 4.0, 6.0, 7.0, 8.0, 9.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0], dtype=float64)
array([[0.0, 1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0, 9.0],
[10.0, 11.0, 12.0, 13.0, 14.0],
[15.0, 16.0, 17.0, 18.0, 19.0],
[20.0, 21.0, 22.0, 23.0, 24.0]], dtype=float64)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[15, 16, 17, 18, 19],