Skip to content

fix: update kwargs target attribute #1941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/sagemaker/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,24 @@ def renamed_warning(phrase):
_warn(f"{phrase} has been renamed")


def renamed_kwargs(name, default, kwargs):
def renamed_kwargs(old_name, new_name, value, kwargs):
"""Checks if the deprecated argument is in kwargs

Raises warning, if present.

Args:
name: name of deprecated argument
default: default value to use, if not present
old_name: name of deprecated argument
new_name: name of the new argument
value: value associated with new name, if supplied
kwargs: keyword arguments dict

Returns:
value of the keyword argument, if present
"""
value = kwargs.get(name, default)
if value != default:
renamed_warning(name)
if old_name in kwargs:
value = kwargs.get(old_name, value)
kwargs[new_name] = value
renamed_warning(old_name)
return value


Expand Down
31 changes: 21 additions & 10 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ def __init__(
:class:`~sagemaker.debugger.Rule` objects used to define
rules for continuous analysis with SageMaker Debugger
(default: ``None``). For more, see
https://sagemaker.readthedocs.io/en/stable/amazon_sagemaker_debugger.html#continuous-analyses-through-rules
https://sagemaker.readthedocs.io/en/stable/amazon_sagemaker_debugger.html#
continuous-analyses-through-rules
debugger_hook_config (:class:`~sagemaker.debugger.DebuggerHookConfig` or bool):
Configuration for how debugging information is emitted with
SageMaker Debugger. If not specified, a default one is created using
Expand All @@ -218,24 +219,34 @@ def __init__(
tensorboard_output_config (:class:`~sagemaker.debugger.TensorBoardOutputConfig`):
Configuration for customizing debugging visualization using TensorBoard
(default: ``None``). For more, see
https://sagemaker.readthedocs.io/en/stable/amazon_sagemaker_debugger.html#capture-real-time-tensorboard-data-from-the-debugging-hook
https://sagemaker.readthedocs.io/en/stable/amazon_sagemaker_debugger.html#
capture-real-time-tensorboard-data-from-the-debugging-hook
enable_sagemaker_metrics (bool): Enables SageMaker Metrics Time
Series. For more information see:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#
SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
(default: ``None``).
enable_network_isolation (bool): Specifies whether container will
run in network isolation mode (default: ``False``). Network
isolation mode restricts the container access to outside networks
(such as the Internet). The container does not make any inbound or
outbound network calls. Also known as Internet-free mode.
"""
instance_count = renamed_kwargs("train_instance_count", instance_count, kwargs)
instance_type = renamed_kwargs("train_instance_type", instance_type, kwargs)
max_run = renamed_kwargs("train_max_run", max_run, kwargs)
use_spot_instances = renamed_kwargs("train_use_spot_instances", use_spot_instances, kwargs)
max_wait = renamed_kwargs("train_max_run_wait", max_wait, kwargs)
volume_size = renamed_kwargs("train_volume_size", volume_size, kwargs)
volume_kms_key = renamed_kwargs("train_volume_kms_key", volume_kms_key, kwargs)
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", instance_count, kwargs
)
instance_type = renamed_kwargs(
"train_instance_type", "instance_type", instance_type, kwargs
)
max_run = renamed_kwargs("train_max_run", "max_run", max_run, kwargs)
use_spot_instances = renamed_kwargs(
"train_use_spot_instances", "use_spot_instances", use_spot_instances, kwargs
)
max_wait = renamed_kwargs("train_max_run_wait", "max_wait", max_wait, kwargs)
volume_size = renamed_kwargs("train_volume_size", "volume_size", volume_size, kwargs)
volume_kms_key = renamed_kwargs(
"train_volume_kms_key", "volume_kms_key", volume_kms_key, kwargs
)

if instance_count is None or instance_type is None:
raise ValueError("Both instance_count and instance_type are required.")
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
distribution = renamed_kwargs("distributions", distribution, kwargs)
distribution = renamed_kwargs("distributions", "distribution", distribution, kwargs)
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
"""
removed_kwargs("content_type", kwargs)
removed_kwargs("accept", kwargs)
endpoint_name = renamed_kwargs("endpoint", endpoint_name, kwargs)
endpoint_name = renamed_kwargs("endpoint", "endpoint_name", endpoint_name, kwargs)
self.endpoint_name = endpoint_name
self.sagemaker_session = sagemaker_session or Session()
self.serializer = serializer
Expand Down
13 changes: 9 additions & 4 deletions tests/unit/test_deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@


def test_renamed_kwargs():
kwargs, b = {"a": 1}, 2
val = renamed_kwargs("b", default=b, kwargs=kwargs)
kwargs, c = {"a": 1}, 2
val = renamed_kwargs("b", new_name="c", value=c, kwargs=kwargs)
assert val == 2

kwargs, c = {"a": 1, "c": 2}, 2
val = renamed_kwargs("b", new_name="c", value=c, kwargs=kwargs)
assert val == 2

with pytest.warns(DeprecationWarning):
kwargs, b = {"a": 1, "b": 3}, 2
val = renamed_kwargs("b", default=b, kwargs=kwargs)
kwargs, c = {"a": 1, "b": 3}, 2
val = renamed_kwargs("b", new_name="c", value=c, kwargs=kwargs)
assert val == 3
assert kwargs == {"a": 1, "b": 3, "c": 3}


def test_removed_arg():
Expand Down