Skip to content

Commit 8c78725

Browse files
committed
MAINT: simplify arg/param handing in normalize
1 parent a6eb581 commit 8c78725

File tree

2 files changed

+18
-29
lines changed

2 files changed

+18
-29
lines changed

torch_np/_normalizations.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -115,35 +115,20 @@ def wrapped(*args, **kwds):
115115
raise NotImplementedError
116116
break
117117

118-
# loop over positional parameters and actual arguments
119-
lst, dct = [], {}
120-
for arg, (name, parm) in zip(args, sig.parameters.items()):
121-
lst.append(normalize_this(arg, parm))
122-
123-
# normalize keyword arguments
124-
for name, arg in kwds.items():
125-
if not name in sig.parameters:
126-
# unknown kwarg, bail out
127-
raise TypeError(
128-
f"{func.__name__}() got an unexpected keyword argument '{name}'."
129-
)
130-
131-
parm = sig.parameters[name]
132-
dct[name] = normalize_this(arg, parm)
133-
134-
ba = sig.bind(*lst, **dct)
135-
ba.apply_defaults()
136-
137-
# Now that all parameters have been consumed, check:
138-
# Anything that has not been bound is unexpected positional arg => raise.
139-
# If there are too few actual arguments, this fill fail in func(*ba.args) below
140-
if len(args) > len(ba.args):
141-
raise TypeError(
142-
f"{func.__name__}() takes {len(ba.args)} positional argument but {len(args)} were given."
143-
)
144-
145-
# finally, pass normalized arguments through
146-
result = func(*ba.args, **ba.kwargs)
118+
# normalize positional and keyword arguments
119+
# NB: extra unknown arguments: pass through, will raise in func(*lst) below
120+
sp = sig.parameters
121+
122+
lst = [normalize_this(arg, parm) for arg, parm in zip(args, sp.values())]
123+
lst += args[len(lst) :]
124+
125+
dct = {
126+
name: normalize_this(arg, sp[name]) if name in sp else arg
127+
for name, arg in kwds.items()
128+
}
129+
130+
result = func(*lst, **dct)
131+
147132
return result
148133

149134
return wrapped

torch_np/tests/test_basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,10 @@ def test_unknown_args(self):
401401
with assert_raises(TypeError):
402402
w.nonzero(a, oops="ouch")
403403

404+
def test_too_few_args_positional(self):
405+
with assert_raises(TypeError):
406+
w.nonzero()
407+
404408
def test_unknown_args_with_defaults(self):
405409
# check a function 5 arguments and 4 defaults: this should work
406410
w.eye(3)

0 commit comments

Comments
 (0)