Skip to content

Commit 9dbf028

Browse files
fix: use local updated args; use train_max_wait (#1945)
Co-authored-by: Ajay Karpur <[email protected]>
1 parent 18f95af commit 9dbf028

File tree

8 files changed

+24
-10
lines changed

8 files changed

+24
-10
lines changed

doc/v2.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ The following estimator parameters have been renamed:
231231
+------------------------------+------------------------+
232232
| ``train_use_spot_instances`` | ``use_spot_instances`` |
233233
+------------------------------+------------------------+
234-
| ``train_max_run_wait`` | ``max_wait`` |
234+
| ``train_max_wait`` | ``max_wait`` |
235235
+------------------------------+------------------------+
236236
| ``train_volume_size`` | ``volume_size`` |
237237
+------------------------------+------------------------+

src/sagemaker/cli/compatibility/v2/modifiers/training_params.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
"train_instance_count",
4646
"train_instance_type",
4747
"train_max_run",
48-
"train_max_run_wait",
48+
"train_max_wait",
4949
"train_use_spot_instances",
5050
"train_volume_size",
5151
"train_volume_kms_key",
@@ -63,7 +63,7 @@ def node_should_be_modified(self, node):
6363
- ``train_instance_count``
6464
- ``train_instance_type``
6565
- ``train_max_run``
66-
- ``train_max_run_wait``
66+
- ``train_max_wait``
6767
- ``train_use_spot_instances``
6868
- ``train_volume_kms_key``
6969
- ``train_volume_size``

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def __init__(
242242
use_spot_instances = renamed_kwargs(
243243
"train_use_spot_instances", "use_spot_instances", use_spot_instances, kwargs
244244
)
245-
max_wait = renamed_kwargs("train_max_run_wait", "max_wait", max_wait, kwargs)
245+
max_wait = renamed_kwargs("train_max_wait", "max_wait", max_wait, kwargs)
246246
volume_size = renamed_kwargs("train_volume_size", "volume_size", volume_size, kwargs)
247247
volume_kms_key = renamed_kwargs(
248248
"train_volume_kms_key", "volume_kms_key", volume_kms_key, kwargs

src/sagemaker/mxnet/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ def __init__(
150150
:class:`~sagemaker.estimator.EstimatorBase`.
151151
"""
152152
distribution = renamed_kwargs("distributions", "distribution", distribution, kwargs)
153+
instance_type = renamed_kwargs(
154+
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
155+
)
153156
validate_version_or_image_args(framework_version, py_version, image_uri)
154157
if py_version == "py2":
155158
logger.warning(
@@ -168,7 +171,6 @@ def __init__(
168171
)
169172

170173
if distribution is not None:
171-
instance_type = kwargs.get("instance_type")
172174
warn_if_parameter_server_with_multi_gpu(
173175
training_instance_type=instance_type, distribution=distribution
174176
)

src/sagemaker/sklearn/estimator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717

1818
from sagemaker import image_uris
19+
from sagemaker.deprecations import renamed_kwargs
1920
from sagemaker.estimator import Framework
2021
from sagemaker.fw_utils import (
2122
framework_name_from_image,
@@ -107,6 +108,12 @@ def __init__(
107108
:class:`~sagemaker.estimator.Framework` and
108109
:class:`~sagemaker.estimator.EstimatorBase`.
109110
"""
111+
instance_type = renamed_kwargs(
112+
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
113+
)
114+
instance_count = renamed_kwargs(
115+
"train_instance_count", "instance_count", kwargs.get("instance_count"), kwargs
116+
)
110117
validate_version_or_image_args(framework_version, py_version, image_uri)
111118
if py_version and py_version != "py3":
112119
raise AttributeError(
@@ -117,10 +124,8 @@ def __init__(
117124

118125
# SciKit-Learn does not support distributed training or training on GPU instance types.
119126
# Fail fast.
120-
instance_type = kwargs.get("instance_type")
121127
_validate_not_gpu_instance_type(instance_type)
122128

123-
instance_count = kwargs.get("instance_count")
124129
if instance_count:
125130
if instance_count != 1:
126131
raise AttributeError(

src/sagemaker/tensorflow/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from sagemaker import image_uris, s3, utils
2121
from sagemaker.debugger import DebuggerHookConfig
22+
from sagemaker.deprecations import renamed_kwargs
2223
from sagemaker.estimator import Framework
2324
import sagemaker.fw_utils as fw
2425
from sagemaker.tensorflow import defaults
@@ -112,6 +113,9 @@ def __init__(
112113
:class:`~sagemaker.estimator.Framework` and
113114
:class:`~sagemaker.estimator.EstimatorBase`.
114115
"""
116+
instance_type = renamed_kwargs(
117+
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
118+
)
115119
fw.validate_version_or_image_args(framework_version, py_version, image_uri)
116120
if py_version == "py2":
117121
logger.warning(
@@ -121,7 +125,6 @@ def __init__(
121125
self.py_version = py_version
122126

123127
if distribution is not None:
124-
instance_type = kwargs.get("instance_type")
125128
fw.warn_if_parameter_server_with_multi_gpu(
126129
training_instance_type=instance_type, distribution=distribution
127130
)

src/sagemaker/xgboost/estimator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717

1818
from sagemaker import image_uris
19+
from sagemaker.deprecations import renamed_kwargs
1920
from sagemaker.estimator import Framework, _TrainingJob
2021
from sagemaker.fw_utils import (
2122
framework_name_from_image,
@@ -95,6 +96,9 @@ def __init__(
9596
:class:`~sagemaker.estimator.Framework` and
9697
:class:`~sagemaker.estimator.EstimatorBase`.
9798
"""
99+
instance_type = renamed_kwargs(
100+
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
101+
)
98102
super(XGBoost, self).__init__(
99103
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
100104
)
@@ -111,7 +115,7 @@ def __init__(
111115
self.sagemaker_session.boto_region_name,
112116
version=framework_version,
113117
py_version=self.py_version,
114-
instance_type=kwargs.get("instance_type"),
118+
instance_type=instance_type,
115119
image_scope="training",
116120
)
117121

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_training_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"train_instance_count=1",
4949
"train_instance_type='ml.c4.xlarge'",
5050
"train_max_run=8 * 60 * 60",
51-
"train_max_run_wait=1 * 60 * 60",
51+
"train_max_wait=1 * 60 * 60",
5252
"train_use_spot_instances=True",
5353
"train_volume_size=30",
5454
"train_volume_kms_key='key'",

0 commit comments

Comments
 (0)