# File listing metadata: 62 lines, 1.6 KiB, Python
from ulab import numpy as np
|
|
|
|
# Adapted from https://docs.python.org/3.8/library/itertools.html#itertools.permutations
|
|
def permutations(iterable, r=None):
    """Yield successive r-length permutations of the elements of ``iterable``.

    Pure-Python stand-in for ``itertools.permutations`` (adapted from the
    equivalent code in the Python 3.8 itertools documentation).

        permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
        permutations(range(3)) --> 012 021 102 120 201 210

    Args:
        iterable: Any iterable; it is consumed once and stored as a tuple.
        r: Length of each permutation. ``None`` (the default) means the
            full length of the input.

    Yields:
        Tuples of ``r`` elements, ordered by the positions of the elements
        in the input. Elements are treated as distinct by position, not by
        value. Nothing is yielded when ``r`` exceeds the input length.

    Raises:
        ValueError: If ``r`` is negative. Because this is a generator, the
            error surfaces when iteration starts, not at call time.
    """
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    # A negative r would make indices[:r] slice from the end and emit a
    # bogus tuple; reject it the same way itertools.permutations does.
    if r < 0:
        raise ValueError("r must be non-negative")
    if r > n:
        return
    indices = list(range(n))
    # cycles[i] counts the swaps left at position i before it wraps around.
    cycles = list(range(n, n - r, -1))
    yield tuple(pool[k] for k in indices[:r])
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                # Position i is exhausted: rotate its element to the end
                # and reset its counter, then carry on to position i - 1.
                indices[i:] = indices[i + 1:] + indices[i:i + 1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[k] for k in indices[:r])
                break
        else:
            # Every position wrapped around: all permutations were emitted.
            return
|
|
|
|
# Empty input: both argmin and argmax are expected to raise ValueError,
# in which case the literal text "ValueError" is printed instead.
for reducer in (np.argmin, np.argmax):
    try:
        print(reducer([]))
    except ValueError:
        print("ValueError")
|
|
|
|
# Single-element input is expected to succeed, both as a plain list and
# as an ndarray built from that list.
one_item_list = [1]
print(np.argmin(one_item_list))
print(np.argmax(one_item_list))
one_item_array = np.array(one_item_list)
print(np.argmin(one_item_array))
print(np.argmax(one_item_array))
|
|
|
|
print()
print("max tests")
# For every ordering of three distinct values, argmax on the raw tuple and
# on the ndarray built from it must agree and must point at the maximum.
for perm in permutations((100, 200, 300)):
    idx_from_seq = np.argmax(perm)
    idx_from_arr = np.argmax(np.array(perm))
    print(perm, idx_from_seq, idx_from_arr)
    if idx_from_seq == idx_from_arr and perm[idx_from_seq] == max(perm):
        continue
    print("FAIL", perm, idx_from_seq, idx_from_arr, max(perm))
|
|
|
|
print()
print("min tests")
# For every ordering of three distinct values, argmin on the raw tuple and
# on the ndarray built from it must agree and must point at the minimum.
for perm in permutations((100, 200, 300)):
    idx_from_seq = np.argmin(perm)
    idx_from_arr = np.argmin(np.array(perm))
    print(perm, idx_from_seq, idx_from_arr)
    if idx_from_seq == idx_from_arr and perm[idx_from_seq] == min(perm):
        continue
    print("FAIL", perm, idx_from_seq, idx_from_arr, min(perm))
|