1
1
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
2
2
"""
3
+ import functools
3
4
import operator
4
5
import typing
5
6
from typing import Optional , Sequence
@@ -79,21 +80,11 @@ def normalize_ndarray(arg, name=None):
79
80
AxisLike : normalize_axis_like ,
80
81
}
81
82
82
- import functools
83
-
84
- _sentinel = object ()
85
-
86
83
87
- def maybe_normalize (arg , parm , return_on_failure = _sentinel ):
84
+ def maybe_normalize (arg , parm ):
88
85
"""Normalize arg if a normalizer is registred."""
89
86
normalizer = normalizers .get (parm .annotation , None )
90
- try :
91
- return normalizer (arg , parm .name ) if normalizer else arg
92
- except Exception as exc :
93
- if return_on_failure is not _sentinel :
94
- return return_on_failure
95
- else :
96
- raise exc from None
87
+ return normalizer (arg , parm .name ) if normalizer else arg
97
88
98
89
99
90
# ### Return value helpers ###
@@ -145,7 +136,7 @@ def array_or_scalar(values, py_type=float, return_scalar=False):
145
136
# ### The main decorator to normalize arguments / postprocess the output ###
146
137
147
138
148
- def normalizer (_func = None , * , return_on_failure = _sentinel , promote_scalar_result = False ):
139
+ def normalizer (_func = None , * , promote_scalar_result = False ):
149
140
def normalizer_inner (func ):
150
141
@functools .wraps (func )
151
142
def wrapped (* args , ** kwds ):
@@ -154,14 +145,12 @@ def wrapped(*args, **kwds):
154
145
first_param = next (iter (params .values ()))
155
146
# NumPy's API does not have positional args before variadic positional args
156
147
if first_param .kind == inspect .Parameter .VAR_POSITIONAL :
157
- args = [
158
- maybe_normalize (arg , first_param , return_on_failure ) for arg in args
159
- ]
148
+ args = [maybe_normalize (arg , first_param ) for arg in args ]
160
149
else :
161
150
# NB: extra unknown arguments: pass through, will raise in func(*args) below
162
151
args = (
163
152
tuple (
164
- maybe_normalize (arg , parm , return_on_failure )
153
+ maybe_normalize (arg , parm )
165
154
for arg , parm in zip (args , params .values ())
166
155
)
167
156
+ args [len (params .values ()) :]
0 commit comments