Skip to content

Commit 57cc12b

Browse files
committed
MAINT: remove return_on_failure at normalizer
1 parent 0c74aca commit 57cc12b

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

torch_np/_funcs.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
DTypeLike,
1919
NDArray,
2020
SubokLike,
21+
normalize_array_like,
2122
normalizer,
2223
)
2324

@@ -1813,10 +1814,13 @@ def i0(x: ArrayLike):
18131814
return torch.special.i0(x)
18141815

18151816

1816-
@normalizer(return_on_failure=False)
1817-
def isscalar(a: ArrayLike):
1817+
def isscalar(a):
18181818
# XXX: this is a stub
1819-
return a.numel() == 1
1819+
try:
1820+
t = normalize_array_like(a)
1821+
return t.numel() == 1
1822+
except Exception:
1823+
return False
18201824

18211825

18221826
"""

torch_np/_normalizations.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
22
"""
3+
import functools
34
import operator
45
import typing
56
from typing import Optional, Sequence
@@ -79,21 +80,11 @@ def normalize_ndarray(arg, name=None):
7980
AxisLike: normalize_axis_like,
8081
}
8182

82-
import functools
83-
84-
_sentinel = object()
85-
8683

87-
def maybe_normalize(arg, parm, return_on_failure=_sentinel):
84+
def maybe_normalize(arg, parm):
8885
"""Normalize arg if a normalizer is registred."""
8986
normalizer = normalizers.get(parm.annotation, None)
90-
try:
91-
return normalizer(arg, parm.name) if normalizer else arg
92-
except Exception as exc:
93-
if return_on_failure is not _sentinel:
94-
return return_on_failure
95-
else:
96-
raise exc from None
87+
return normalizer(arg, parm.name) if normalizer else arg
9788

9889

9990
# ### Return value helpers ###
@@ -145,7 +136,7 @@ def array_or_scalar(values, py_type=float, return_scalar=False):
145136
# ### The main decorator to normalize arguments / postprocess the output ###
146137

147138

148-
def normalizer(_func=None, *, return_on_failure=_sentinel, promote_scalar_result=False):
139+
def normalizer(_func=None, *, promote_scalar_result=False):
149140
def normalizer_inner(func):
150141
@functools.wraps(func)
151142
def wrapped(*args, **kwds):
@@ -154,14 +145,12 @@ def wrapped(*args, **kwds):
154145
first_param = next(iter(params.values()))
155146
# NumPy's API does not have positional args before variadic positional args
156147
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
157-
args = [
158-
maybe_normalize(arg, first_param, return_on_failure) for arg in args
159-
]
148+
args = [maybe_normalize(arg, first_param) for arg in args]
160149
else:
161150
# NB: extra unknown arguments: pass through, will raise in func(*args) below
162151
args = (
163152
tuple(
164-
maybe_normalize(arg, parm, return_on_failure)
153+
maybe_normalize(arg, parm)
165154
for arg, parm in zip(args, params.values())
166155
)
167156
+ args[len(params.values()) :]

0 commit comments

Comments
 (0)