circuitpython-ulab/tests/numpy/argminmax.py
Jeff Epler 196104084b tests: try to fit mp, cp, and py3 into one mold
Mostly by providing a "numpy shim" for CircuitPython,
try to make the numpy tests run on all three systems
(MicroPython, CircuitPython, and CPython 3).
(A "scipy shim" might also be useful?)

However, there are test failures.  Is it worth working
through them and getting this to a point where it could
be included?
2020-11-23 15:17:00 -06:00

65 lines
1.7 KiB
Python

# Prefer ``numpy`` (real numpy on CPython, or a compatible shim module if
# one is present); otherwise fall back to ``ulab``.
try:
    import numpy as np
except ImportError:
    # Catch only ImportError: a bare ``except`` would also hide
    # KeyboardInterrupt, SystemExit, and genuine errors raised while
    # numpy itself is being imported.
    import ulab as np
# Adapted from https://docs.python.org/3.8/library/itertools.html#itertools.permutations
def permutations(iterable, r=None):
    """Yield successive r-length permutations of the items in *iterable*.

    Local stand-in for ``itertools.permutations`` (presumably because the
    MicroPython/CircuitPython ports do not ship it), following the
    "roughly equivalent" recipe from the itertools documentation.
    Permutations come out in lexicographic order of element position:

        permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
        permutations(range(3)) --> 012 021 102 120 201 210

    :param iterable: source of the elements; it is consumed into a tuple
        the first time the generator is advanced.
    :param r: length of each permutation; ``None`` means the full length.
    :returns: a generator of tuples.  Nothing is yielded when ``r`` exceeds
        the number of elements; ``r == 0`` yields a single empty tuple.
    """
    items = tuple(iterable)
    size = len(items)
    length = size if r is None else r
    if length > size:
        return
    # order: current arrangement of element positions.
    # countdown[pos]: how many more swaps remain at slot ``pos``.
    order = list(range(size))
    countdown = list(range(size, size - length, -1))
    yield tuple(items[k] for k in order[:length])
    while size:
        pos = length - 1
        while pos >= 0:
            countdown[pos] -= 1
            if countdown[pos]:
                # Swap in the next candidate for this slot and emit.
                step = countdown[pos]
                order[pos], order[-step] = order[-step], order[pos]
                yield tuple(items[k] for k in order[:length])
                break
            # Slot exhausted: rotate its element to the end, reset the
            # counter, and carry over to the slot on the left.
            order[pos:] = order[pos + 1:] + order[pos:pos + 1]
            countdown[pos] = size - pos
            pos -= 1
        else:
            # Every slot carried over: all permutations have been produced.
            return
# Combinations expected to throw: argmin, then argmax, of an empty list;
# each one prints "ValueError" when that exception is raised.
for name in ("argmin", "argmax"):
    try:
        print(getattr(np, name)([]))
    except ValueError:
        print("ValueError")
# Combinations expected to succeed: a one-element list first, then a
# one-element array built from it; each goes through argmin and then argmax.
for as_array in (False, True):
    for name in ("argmin", "argmax"):
        arg = np.array([1]) if as_array else [1]
        print(getattr(np, name)(arg))
# For every ordering of three distinct values, argmax/argmin must give the
# same index for a plain tuple and for an array built from it, and that
# index must point at the builtin max()/min() of the tuple.  Each case is
# printed; a mismatch additionally prints a FAIL line.
for heading, name, reference in (("max tests", "argmax", max),
                                 ("min tests", "argmin", min)):
    print()
    print(heading)
    for perm in permutations((100, 200, 300)):
        from_seq = getattr(np, name)(perm)
        from_arr = getattr(np, name)(np.array(perm))
        print(perm, from_seq, from_arr)
        if from_seq != from_arr or perm[from_seq] != reference(perm):
            print("FAIL", perm, from_seq, from_arr, reference(perm))