Skip to content

Commit 63ea434

Browse files
author
Namrata Madan
committed
fix: remote function argument validation
1 parent 5cd4c8e commit 63ea434

File tree

2 files changed

+66
-41
lines changed

2 files changed

+66
-41
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -457,30 +457,28 @@ def _validate_submit_args(func, *args, **kwargs):
457457
)
458458

459459
if len(args) < minimum_num_expected_positional_args:
460-
missing_args_count = minimum_num_expected_positional_args - len(args)
461-
if missing_args_count == 1:
462-
missing_args = f"'{full_arg_spec.args[minimum_num_expected_positional_args - 1]}'"
463-
else:
464-
missing_args = (
465-
", ".join(
466-
map(
467-
lambda x: f"'{x}'",
468-
full_arg_spec.args[
469-
len(args) : minimum_num_expected_positional_args - 1
470-
],
471-
)
472-
)
473-
+ f", and '{full_arg_spec.args[minimum_num_expected_positional_args - 1]}'"
460+
missing_positional_args = full_arg_spec.args[
461+
len(args) : minimum_num_expected_positional_args
462+
]
463+
missing_args = list(filter(lambda arg: arg not in kwargs, missing_positional_args))
464+
if missing_args:
465+
missing_args_str = (
466+
", ".join(map(lambda x: f"'{x}'", missing_args[:-1]))
467+
+ f", and '{missing_args[-1]}'"
468+
if len(missing_args) > 1
469+
else f"'{missing_args[0]}'"
470+
)
471+
raise TypeError(
472+
f"{func.__name__}() missing {len(missing_args)} required positional "
473+
+ f"{'arguments' if len(missing_args) > 1 else 'argument'}: {missing_args_str}"
474474
)
475-
raise TypeError(
476-
f"{func.__name__}() missing {missing_args_count} required positional "
477-
+ f"{'arguments' if missing_args_count > 1 else 'argument'}: {missing_args}"
478-
)
479475

480476
# kwargs related validations
481477

482478
for k in kwargs:
483-
if k not in full_arg_spec.kwonlyargs:
479+
if k in full_arg_spec.args and len(args) > full_arg_spec.args.index(k):
480+
raise TypeError(f"{func.__name__}() got multiple values for argument '{k}'")
481+
if k not in full_arg_spec.kwonlyargs and k not in full_arg_spec.args:
484482
raise TypeError(f"{func.__name__}() got an unexpected keyword argument '{k}'")
485483

486484
missing_kwargs = [
@@ -489,13 +487,12 @@ def _validate_submit_args(func, *args, **kwargs):
489487
if k not in full_arg_spec.kwonlydefaults and k not in kwargs
490488
]
491489
if missing_kwargs:
492-
if len(missing_kwargs) == 1:
493-
missing_kwargs_string = f"'{missing_kwargs[0]}'"
494-
else:
495-
missing_kwargs_string = (
496-
", ".join(map(lambda x: f"'{x}'", missing_kwargs[:-1]))
497-
+ f", and '{missing_kwargs[-1]}'"
498-
)
490+
missing_kwargs_string = (
491+
", ".join(map(lambda x: f"'{x}'", missing_kwargs[:-1]))
492+
+ f", and '{missing_kwargs[-1]}'"
493+
if len(missing_kwargs) > 1
494+
else f"'{missing_kwargs[0]}'"
495+
)
499496

500497
raise TypeError(
501498
f"{func.__name__}() missing {len(missing_kwargs)} required keyword-only "

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -297,29 +297,57 @@ def square(x):
297297
@pytest.mark.parametrize(
298298
"args, kwargs, error_message",
299299
[
300-
((1, 2), {}, "decorated_function() missing 1 required keyword-only argument: 'c'"),
301300
(
302-
(1, 2),
303-
{"c": 3, "d": 4, "e": "extra_arg"},
304-
"decorated_function() got an unexpected keyword argument 'e'",
301+
[1, 2, 3],
302+
{},
303+
"decorated_function() missing 2 required keyword-only arguments: 'd', and 'e'",
305304
),
306-
((), {"c": 3, "d": 4}, "decorated_function() missing 1 required positional argument: 'a'"),
305+
([1, 2, 3], {"d": 4}, "decorated_function() missing 1 required keyword-only argument: 'e'"),
307306
(
308-
(1, 2, "extra_Arg"),
307+
[1, 2, 3],
308+
{"d": 3, "e": 4, "g": "extra_arg"},
309+
"decorated_function() got an unexpected keyword argument 'g'",
310+
),
311+
(
312+
[],
309313
{"c": 3, "d": 4},
310-
"decorated_function() takes 2 positional arguments but 3 were given.",
314+
"decorated_function() missing 2 required positional arguments: 'a', and 'b'",
315+
),
316+
([1], {"c": 3, "d": 4}, "decorated_function() missing 1 required positional argument: 'b'"),
317+
(
318+
[1, 2, 3, "extra_arg"],
319+
{"d": 3, "e": 4},
320+
"decorated_function() takes 3 positional arguments but 4 were given.",
321+
),
322+
([], {"a": 1, "b": 2, "d": 3, "e": 2}, None),
323+
(
324+
(1, 2),
325+
{"a": 1, "c": 3, "d": 2},
326+
"decorated_function() got multiple values for argument 'a'",
327+
),
328+
(
329+
(1, 2),
330+
{"b": 1, "c": 3, "d": 2},
331+
"decorated_function() got multiple values for argument 'b'",
311332
),
312333
],
313334
)
314-
def test_decorator_invalid_function_args(args, kwargs, error_message):
335+
@patch("sagemaker.remote_function.client._JobSettings")
336+
@patch("sagemaker.remote_function.client._Job.start")
337+
def test_decorator_invalid_function_args(job_start, job_settings, args, kwargs, error_message):
315338
@remote(image_uri=IMAGE, s3_root_uri=S3_URI)
316-
def decorated_function(a, b=1, *, c, d=3):
317-
return a * b * c * d
318-
319-
with pytest.raises(TypeError) as e:
320-
decorated_function(*args, **kwargs)
321-
322-
assert error_message in str(e.value)
339+
def decorated_function(a, b, c=1, *, d, e, f=3):
340+
return a * b * c * d * e * f
341+
342+
if error_message:
343+
with pytest.raises(TypeError) as e:
344+
decorated_function(*args, **kwargs)
345+
assert error_message in str(e.value)
346+
else:
347+
try:
348+
decorated_function(*args, **kwargs)
349+
except Exception as ex:
350+
pytest.fail("Unexpected Exception: " + str(ex))
323351

324352

325353
def test_executor_invalid_arguments():

0 commit comments

Comments
 (0)