Update snippets/rclass.py
Co-authored-by: Cal Hays <callumjhays@gmail.com>
This commit is contained in:
parent
4701338c28
commit
7e6a5f9fc9
1 changed files with 39 additions and 37 deletions
|
|
@ -1,20 +1,19 @@
|
|||
from ulab import numpy as np
|
||||
from typing import List, Tuple, Union # upip.install("pycopy-typing")
|
||||
from ulab import numpy as np
|
||||
|
||||
ndarray = np.array
|
||||
_DType = int
|
||||
RClassKeyType = Union[slice, int, float]
|
||||
_RClassKeyType = Union[slice, int, float, list, tuple, np.ndarray]
|
||||
|
||||
# this is a stripped down version of RClass (used by np.r_[...etc])
|
||||
# it doesn't include support for string arguments as the first index element
|
||||
class RClass:
|
||||
|
||||
def __getitem__(self, key: Union[RClassKeyType, Tuple[RClassKeyType, ...]]):
|
||||
def __getitem__(self, key: Union[_RClassKeyType, Tuple[_RClassKeyType, ...]]):
|
||||
|
||||
if not isinstance(key, tuple):
|
||||
key = (key,)
|
||||
|
||||
objs: List[ndarray] = []
|
||||
objs: List[np.ndarray] = []
|
||||
scalars: List[int] = []
|
||||
arraytypes: List[_DType] = []
|
||||
scalartypes: List[_DType] = []
|
||||
|
|
@ -24,50 +23,53 @@ class RClass:
|
|||
|
||||
for idx, item in enumerate(key):
|
||||
scalar = False
|
||||
if isinstance(item, slice):
|
||||
step = item.step
|
||||
start = item.start
|
||||
stop = item.stop
|
||||
if start is None:
|
||||
start = 0
|
||||
if step is None:
|
||||
nstep = 1
|
||||
if isinstance(step, complex):
|
||||
size = int(abs(step))
|
||||
newobj = cast(ndarray, linspace(start, stop, num=size))
|
||||
|
||||
try:
|
||||
if isinstance(item, np.ndarray):
|
||||
newobj = item
|
||||
|
||||
elif isinstance(item, slice):
|
||||
step = item.step
|
||||
start = item.start
|
||||
stop = item.stop
|
||||
if start is None:
|
||||
start = 0
|
||||
if step is None:
|
||||
step = 1
|
||||
if isinstance(step, complex):
|
||||
size = int(abs(step))
|
||||
newobj: np.ndarray = np.linspace(start, stop, num=size)
|
||||
else:
|
||||
newobj = np.arange(start, stop, step)
|
||||
|
||||
# if is number
|
||||
elif isinstance(item, (int, float, bool)):
|
||||
newobj = np.array([item])
|
||||
scalars.append(len(objs))
|
||||
scalar = True
|
||||
scalartypes.append(newobj.dtype())
|
||||
|
||||
else:
|
||||
newobj = np.arange(start, stop, step)
|
||||
|
||||
# if is number
|
||||
elif isinstance(item, (int, float)):
|
||||
newobj = np.array([item])
|
||||
scalars.append(len(objs))
|
||||
scalar = True
|
||||
scalartypes.append(newobj.dtype())
|
||||
|
||||
else:
|
||||
newobj = np.array(item)
|
||||
|
||||
except TypeError:
|
||||
raise Exception("index object %s of type %s is not supported by r_[]" % (
|
||||
str(item), type(item)))
|
||||
|
||||
objs.append(newobj)
|
||||
if not scalar and isinstance(newobj, ndarray):
|
||||
if not scalar and isinstance(newobj, np.ndarray):
|
||||
arraytypes.append(newobj.dtype())
|
||||
|
||||
# Ensure that scalars won't up-cast unless warranted
|
||||
# TODO: ensure that this actually works for dtype coercion
|
||||
# likelihood is we're going to have to do some funky logic for this
|
||||
final_dtype = max(arraytypes + scalartypes)
|
||||
if final_dtype is not None:
|
||||
for idx in scalars:
|
||||
final_dtype = min(arraytypes + scalartypes)
|
||||
for idx, obj in enumerate(objs):
|
||||
if obj.dtype != final_dtype:
|
||||
objs[idx] = np.array(objs[idx], dtype=final_dtype)
|
||||
|
||||
res = np.concatenate(tuple(objs), axis=axis)
|
||||
|
||||
return res
|
||||
return np.concatenate(tuple(objs), axis=axis)
|
||||
|
||||
# this seems weird - not sure what it's for
|
||||
def __len__(self):
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
r_ = RClass()
|
||||
|
|
|
|||
Loading…
Reference in a new issue