Update snippets/rclass.py

Co-authored-by: Cal Hays <callumjhays@gmail.com>
This commit is contained in:
Zoltán Vörös 2021-04-01 07:53:04 +02:00 committed by GitHub
parent 4701338c28
commit 7e6a5f9fc9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,20 +1,19 @@
from ulab import numpy as np
from typing import List, Tuple, Union # upip.install("pycopy-typing")
from ulab import numpy as np
ndarray = np.array
_DType = int
RClassKeyType = Union[slice, int, float]
_RClassKeyType = Union[slice, int, float, list, tuple, np.ndarray]
# this is a stripped down version of RClass (used by np.r_[...etc])
# it doesn't include support for string arguments as the first index element
class RClass:
def __getitem__(self, key: Union[RClassKeyType, Tuple[RClassKeyType, ...]]):
def __getitem__(self, key: Union[_RClassKeyType, Tuple[_RClassKeyType, ...]]):
if not isinstance(key, tuple):
key = (key,)
objs: List[ndarray] = []
objs: List[np.ndarray] = []
scalars: List[int] = []
arraytypes: List[_DType] = []
scalartypes: List[_DType] = []
@ -24,50 +23,53 @@ class RClass:
for idx, item in enumerate(key):
scalar = False
if isinstance(item, slice):
step = item.step
start = item.start
stop = item.stop
if start is None:
start = 0
if step is None:
nstep = 1
if isinstance(step, complex):
size = int(abs(step))
newobj = cast(ndarray, linspace(start, stop, num=size))
try:
if isinstance(item, np.ndarray):
newobj = item
elif isinstance(item, slice):
step = item.step
start = item.start
stop = item.stop
if start is None:
start = 0
if step is None:
step = 1
if isinstance(step, complex):
size = int(abs(step))
newobj: np.ndarray = np.linspace(start, stop, num=size)
else:
newobj = np.arange(start, stop, step)
# if is number
elif isinstance(item, (int, float, bool)):
newobj = np.array([item])
scalars.append(len(objs))
scalar = True
scalartypes.append(newobj.dtype())
else:
newobj = np.arange(start, stop, step)
# if is number
elif isinstance(item, (int, float)):
newobj = np.array([item])
scalars.append(len(objs))
scalar = True
scalartypes.append(newobj.dtype())
else:
newobj = np.array(item)
except TypeError:
raise Exception("index object %s of type %s is not supported by r_[]" % (
str(item), type(item)))
objs.append(newobj)
if not scalar and isinstance(newobj, ndarray):
if not scalar and isinstance(newobj, np.ndarray):
arraytypes.append(newobj.dtype())
# Ensure that scalars won't up-cast unless warranted
# TODO: ensure that this actually works for dtype coercion
# likelihood is we're going to have to do some funky logic for this
final_dtype = max(arraytypes + scalartypes)
if final_dtype is not None:
for idx in scalars:
final_dtype = min(arraytypes + scalartypes)
for idx, obj in enumerate(objs):
if obj.dtype != final_dtype:
objs[idx] = np.array(objs[idx], dtype=final_dtype)
res = np.concatenate(tuple(objs), axis=axis)
return res
return np.concatenate(tuple(objs), axis=axis)
# this seems weird - not sure what it's for
def __len__(self):
return 0
r_ = RClass()