Commit f55bc9d

breaking: rename distributions to distribution in TF/MXNet estimators (#1662)
This was a typo that wasn't caught until after a couple of releases, so the consensus was to wait and fix it alongside other breaking changes.
1 parent 75198c3 commit f55bc9d
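For anyone upgrading across this change, the rename amounts to one keyword argument in the TensorFlow and MXNet estimator constructors (a minimal before/after sketch; the entry point and role are placeholders):

# Before (SDK v1.x):
TensorFlow(entry_point='train.py', role='SageMakerRole',
           distributions={'parameter_server': {'enabled': True}})

# After (SDK v2.x):
TensorFlow(entry_point='train.py', role='SageMakerRole',
           distribution={'parameter_server': {'enabled': True}})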

16 files changed: +185 -66 lines

doc/frameworks/mxnet/using_mxnet.rst

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ If you want to use parameter servers for distributed training, set the following

 .. code:: python

-    distributions={'parameter_server': {'enabled': True}}
+    distribution={'parameter_server': {'enabled': True}}

 Then, when writing a distributed training script, use an MXNet kvstore to store and share model parameters.
 During training, Amazon SageMaker automatically starts an MXNet kvstore server and scheduler processes on hosts in your training job cluster.
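In context, a full estimator call using the renamed parameter might look like the following (a sketch only; the entry point, role, instance settings, and framework version are placeholders, not part of this diff):

from sagemaker.mxnet import MXNet

mx_estimator = MXNet(entry_point='mxnet-train.py', role='SageMakerRole',
                     train_instance_count=2, train_instance_type='ml.p2.xlarge',
                     framework_version='1.6.0', py_version='py3',
                     distribution={'parameter_server': {'enabled': True}})
mx_estimator.fit('s3://bucket/path/to/training/data')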

doc/frameworks/tensorflow/using_tf.rst

Lines changed: 5 additions & 5 deletions
@@ -212,12 +212,12 @@ Distributed Training

 To run your training job with multiple instances in a distributed fashion, set ``train_instance_count``
 to a number larger than 1. We support two different types of distributed training, parameter server and Horovod.
-The ``distributions`` parameter is used to configure which distributed training strategy to use.
+The ``distribution`` parameter is used to configure which distributed training strategy to use.

 Training with parameter servers
 -------------------------------

-If you specify parameter_server as the value of the distributions parameter, the container launches a parameter server
+If you specify parameter_server as the value of the distribution parameter, the container launches a parameter server
 thread on each instance in the training cluster, and then executes your training code. You can find more information on
 TensorFlow distributed training at `TensorFlow docs <https://www.tensorflow.org/deploy/distributed>`__.
 To enable parameter server training:

@@ -229,7 +229,7 @@ To enable parameter server training:
     tf_estimator = TensorFlow(entry_point='tf-train.py', role='SageMakerRole',
                               train_instance_count=2, train_instance_type='ml.p2.xlarge',
                               framework_version='1.11', py_version='py3',
-                              distributions={'parameter_server': {'enabled': True}})
+                              distribution={'parameter_server': {'enabled': True}})
     tf_estimator.fit('s3://bucket/path/to/training/data')

 Training with Horovod

@@ -241,7 +241,7 @@ You can find more details at `Horovod README <https://github.com/uber/horovod>`__
 The container sets up the MPI environment and executes the ``mpirun`` command, enabling you to run any Horovod
 training script.

-Training with ``MPI`` is configured by specifying the following fields in ``distributions``:
+Training with ``MPI`` is configured by specifying the following fields in ``distribution``:

 - ``enabled (bool)``: If set to ``True``, the MPI setup is performed and ``mpirun`` command is executed.
 - ``processes_per_host (int)``: Number of processes MPI should launch on each host. Note, this should not be

@@ -260,7 +260,7 @@ In the below example we create an estimator to launch Horovod distributed training
     tf_estimator = TensorFlow(entry_point='tf-train.py', role='SageMakerRole',
                               train_instance_count=1, train_instance_type='ml.p3.8xlarge',
                               framework_version='2.1.0', py_version='py3',
-                              distributions={
+                              distribution={
                                   'mpi': {
                                       'enabled': True,
                                       'processes_per_host': 4,
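The last hunk is truncated by the page at line 266. For reference, a complete version of that Horovod example might read as follows (the closing braces and the fit call are added here for completeness and are not shown in this diff; the rendered docs may also pass further MPI fields after the trailing comma):

from sagemaker.tensorflow import TensorFlow

tf_estimator = TensorFlow(entry_point='tf-train.py', role='SageMakerRole',
                          train_instance_count=1, train_instance_type='ml.p3.8xlarge',
                          framework_version='2.1.0', py_version='py3',
                          distribution={
                              'mpi': {
                                  'enabled': True,
                                  'processes_per_host': 4
                              }
                          })
tf_estimator.fit('s3://bucket/path/to/training/data')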

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
     modifiers.tfs.TensorFlowServingConstructorRenamer(),
     modifiers.predictors.PredictorConstructorRefactor(),
     modifiers.airflow.ModelConfigArgModifier(),
+    modifiers.estimators.DistributionParameterRenamer(),
 ]

 IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 from sagemaker.cli.compatibility.v2.modifiers import (  # noqa: F401 (imported but unused)
     airflow,
     deprecated_params,
+    estimators,
     framework_version,
     predictors,
     tf_legacy_mode,

src/sagemaker/cli/compatibility/v2/modifiers/estimators.py

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Classes to modify estimator code to be compatible
+with version 2.0 and later of the SageMaker Python SDK.
+"""
+from __future__ import absolute_import
+
+from sagemaker.cli.compatibility.v2.modifiers import matching
+from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
+
+ESTIMATORS_WITH_DISTRIBUTION_PARAM = {
+    "TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
+    "MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
+}
+
+
+class DistributionParameterRenamer(Modifier):
+    """A class to rename the ``distributions`` attribute in MXNet and TensorFlow estimators."""
+
+    def node_should_be_modified(self, node):
+        """Checks if the ``ast.Call`` node instantiates an MXNet or TensorFlow estimator and
+        contains the ``distributions`` parameter.
+
+        This looks for the following calls:
+
+        - ``<Framework>``
+        - ``sagemaker.<framework>.<Framework>``
+        - ``sagemaker.<framework>.estimator.<Framework>``
+
+        where ``<Framework>`` is either ``TensorFlow`` or ``MXNet``.
+
+        Args:
+            node (ast.Call): a node that represents a function call. For more,
+                see https://docs.python.org/3/library/ast.html#abstract-grammar.
+
+        Returns:
+            bool: If the ``ast.Call`` instantiates an MXNet or TensorFlow estimator with
+                the ``distributions`` parameter.
+        """
+        return matching.matches_any(
+            node, ESTIMATORS_WITH_DISTRIBUTION_PARAM
+        ) and self._has_distribution_arg(node)
+
+    def _has_distribution_arg(self, node):
+        """Checks if the node has the ``distributions`` parameter in its keywords."""
+        for kw in node.keywords:
+            if kw.arg == "distributions":
+                return True
+
+        return False
+
+    def modify_node(self, node):
+        """Modifies the ``ast.Call`` node to rename the ``distributions`` attribute to
+        ``distribution``.
+
+        Args:
+            node (ast.Call): a node that represents an MXNet or TensorFlow constructor.
+        """
+        for kw in node.keywords:
+            if kw.arg == "distributions":
+                kw.arg = "distribution"
+                break
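To see the new modifier end to end, here is a minimal sketch of applying it to a single constructor call by hand, outside the upgrade tool's normal driver (assumptions: Python 3.9+ for ast.unparse; this snippet is not part of the commit):

import ast

from sagemaker.cli.compatibility.v2.modifiers.estimators import DistributionParameterRenamer

source = "TensorFlow(entry_point='train.py', distributions={'mpi': {'enabled': True}})"
tree = ast.parse(source)
renamer = DistributionParameterRenamer()

# Visit every node; the modifier only matches TensorFlow/MXNet constructor
# calls that pass the old ``distributions`` keyword.
for node in ast.walk(tree):
    if isinstance(node, ast.Call) and renamer.node_should_be_modified(node):
        renamer.modify_node(node)

print(ast.unparse(tree))
# TensorFlow(entry_point='train.py', distribution={'mpi': {'enabled': True}})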

src/sagemaker/fw_utils.py

Lines changed: 4 additions & 4 deletions
@@ -599,15 +599,15 @@ def later_framework_version_warning(latest_version):
     return LATER_FRAMEWORK_VERSION_WARNING.format(latest=latest_version)


-def warn_if_parameter_server_with_multi_gpu(training_instance_type, distributions):
+def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution):
     """Warn the user that training will not fully leverage all the GPU
     cores if parameter server is enabled and a multi-GPU instance is selected.
     Distributed training with the default parameter server setup doesn't
     support multi-GPU instances.

     Args:
         training_instance_type (str): A string representing the type of training instance selected.
-        distributions (dict): A dictionary with information to enable distributed training.
+        distribution (dict): A dictionary with information to enable distributed training.
             (Defaults to None if distributed training is not enabled.) For example:

             .. code:: python

@@ -621,15 +621,15 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distributions):


     """
-    if training_instance_type == "local" or distributions is None:
+    if training_instance_type == "local" or distribution is None:
         return

     is_multi_gpu_instance = (
         training_instance_type == "local_gpu"
         or training_instance_type.split(".")[1].startswith("p")
     ) and training_instance_type not in SINGLE_GPU_INSTANCE_TYPES

-    ps_enabled = "parameter_server" in distributions and distributions["parameter_server"].get(
+    ps_enabled = "parameter_server" in distribution and distribution["parameter_server"].get(
         "enabled", False
     )

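As a quick illustration of when the renamed helper warns (a hypothetical call, not from the diff; it assumes 'ml.p3.8xlarge' is absent from SINGLE_GPU_INSTANCE_TYPES, as expected for a multi-GPU size):

from sagemaker.fw_utils import warn_if_parameter_server_with_multi_gpu

# Emits the warning: parameter server enabled on a multi-GPU P-family instance.
warn_if_parameter_server_with_multi_gpu(
    training_instance_type="ml.p3.8xlarge",
    distribution={"parameter_server": {"enabled": True}},
)

# No warning: local mode and a missing distribution dict both return early.
warn_if_parameter_server_with_multi_gpu(training_instance_type="local", distribution=None)
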
src/sagemaker/mxnet/estimator.py

Lines changed: 11 additions & 13 deletions
@@ -21,7 +21,6 @@
     framework_version_from_tag,
     is_version_equal_or_higher,
     python_deprecation_warning,
-    parameter_v2_rename_warning,
     validate_version_or_image_args,
     warn_if_parameter_server_with_multi_gpu,
 )

@@ -46,7 +45,7 @@ def __init__(
         source_dir=None,
         hyperparameters=None,
         image_name=None,
-        distributions=None,
+        distribution=None,
         **kwargs
     ):
         """This ``Estimator`` executes an MXNet script in a managed MXNet

@@ -100,7 +99,7 @@ def __init__(
                 If ``framework_version`` or ``py_version`` are ``None``, then
                 ``image_name`` is required. If also ``None``, then a ``ValueError``
                 will be raised.
-            distributions (dict): A dictionary with information on how to run distributed
+            distribution (dict): A dictionary with information on how to run distributed
                 training (default: None). To have parameter servers launched for training,
                 set this value to be ``{'parameter_server': {'enabled': True}}``.
             **kwargs: Additional kwargs passed to the

@@ -131,35 +130,34 @@ def __init__(
             entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
         )

-        if distributions is not None:
-            logger.warning(parameter_v2_rename_warning("distributions", "distribution"))
+        if distribution is not None:
             train_instance_type = kwargs.get("train_instance_type")
             warn_if_parameter_server_with_multi_gpu(
-                training_instance_type=train_instance_type, distributions=distributions
+                training_instance_type=train_instance_type, distribution=distribution
             )

-        self._configure_distribution(distributions)
+        self._configure_distribution(distribution)

-    def _configure_distribution(self, distributions):
+    def _configure_distribution(self, distribution):
         """
         Args:
-            distributions:
+            distribution:
         """
-        if distributions is None:
+        if distribution is None:
             return

         if (
             self.framework_version
             and self.framework_version.split(".") < self._LOWEST_SCRIPT_MODE_VERSION
         ):
             raise ValueError(
-                "The distributions option is valid for only versions {} and higher".format(
+                "The distribution option is valid for only versions {} and higher".format(
                     ".".join(self._LOWEST_SCRIPT_MODE_VERSION)
                 )
             )

-        if "parameter_server" in distributions:
-            enabled = distributions["parameter_server"].get("enabled", False)
+        if "parameter_server" in distribution:
+            enabled = distribution["parameter_server"].get("enabled", False)
             self._hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled

     def create_model(
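Since the old keyword is removed outright rather than aliased, callers still passing it now fail at construction time. A sketch of the failure mode (placeholder arguments):

from sagemaker.mxnet import MXNet

# Raises TypeError: __init__() got an unexpected keyword argument 'distributions'
MXNet(entry_point='train.py', role='SageMakerRole',
      train_instance_count=1, train_instance_type='ml.c4.xlarge',
      framework_version='1.6.0', py_version='py3',
      distributions={'parameter_server': {'enabled': True}})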

src/sagemaker/tensorflow/estimator.py

Lines changed: 10 additions & 11 deletions
@@ -45,7 +45,7 @@ def __init__(
         framework_version=None,
         model_dir=None,
         image_name=None,
-        distributions=None,
+        distribution=None,
         **kwargs
     ):
         """Initialize a ``TensorFlow`` estimator.

@@ -81,7 +81,7 @@ def __init__(
                 If ``framework_version`` or ``py_version`` are ``None``, then
                 ``image_name`` is required. If also ``None``, then a ``ValueError``
                 will be raised.
-            distributions (dict): A dictionary with information on how to run distributed training
+            distribution (dict): A dictionary with information on how to run distributed training
                 (default: None). Currently we support distributed training with parameter servers
                 and MPI.
                 To enable parameter server use the following setup:

@@ -122,11 +122,10 @@ def __init__(
         self.framework_version = framework_version
         self.py_version = py_version

-        if distributions is not None:
-            logger.warning(fw.parameter_v2_rename_warning("distribution", distributions))
+        if distribution is not None:
             train_instance_type = kwargs.get("train_instance_type")
             fw.warn_if_parameter_server_with_multi_gpu(
-                training_instance_type=train_instance_type, distributions=distributions
+                training_instance_type=train_instance_type, distribution=distribution
             )

         if "enable_sagemaker_metrics" not in kwargs:

@@ -137,7 +136,7 @@ def __init__(
         super(TensorFlow, self).__init__(image_name=image_name, **kwargs)

         self.model_dir = model_dir
-        self.distributions = distributions or {}
+        self.distribution = distribution or {}

         self._validate_args(py_version=py_version)

@@ -295,13 +294,13 @@ def hyperparameters(self):
         hyperparameters = super(TensorFlow, self).hyperparameters()
         additional_hyperparameters = {}

-        if "parameter_server" in self.distributions:
-            ps_enabled = self.distributions["parameter_server"].get("enabled", False)
+        if "parameter_server" in self.distribution:
+            ps_enabled = self.distribution["parameter_server"].get("enabled", False)
             additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled

         mpi_enabled = False
-        if "mpi" in self.distributions:
-            mpi_dict = self.distributions["mpi"]
+        if "mpi" in self.distribution:
+            mpi_dict = self.distribution["mpi"]
             mpi_enabled = mpi_dict.get("enabled", False)
             additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled

@@ -338,7 +337,7 @@ def _validate_and_set_debugger_configs(self):

         Else, set default HookConfig
         """
-        ps_enabled = "parameter_server" in self.distributions and self.distributions[
+        ps_enabled = "parameter_server" in self.distribution and self.distribution[
             "parameter_server"
         ].get("enabled", False)
         if ps_enabled:
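The renamed self.distribution attribute is what hyperparameters() and the debugger-config check above inspect. A rough usage sketch (placeholder arguments; assumes the constructor can run without an active AWS session):

from sagemaker.tensorflow import TensorFlow

tf_estimator = TensorFlow(entry_point='tf-train.py', role='SageMakerRole',
                          train_instance_count=2, train_instance_type='ml.p2.xlarge',
                          framework_version='2.1.0', py_version='py3',
                          distribution={'parameter_server': {'enabled': True}})

# Defaults to {} when not supplied, so the membership checks never hit None.
assert tf_estimator.distribution == {'parameter_server': {'enabled': True}}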

tests/integ/test_horovod.py

Lines changed: 2 additions & 2 deletions
@@ -82,7 +82,7 @@ def test_horovod_local_mode(
         output_path=output_path,
         framework_version=tf_training_latest_version,
         py_version=tf_training_latest_py_version,
-        distributions={"mpi": {"enabled": True, "processes_per_host": processes}},
+        distribution={"mpi": {"enabled": True, "processes_per_host": processes}},
     )

     with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):

@@ -128,7 +128,7 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
         sagemaker_session=sagemaker_session,
         py_version=py_version,
         framework_version=tf_version,
-        distributions={"mpi": {"enabled": True}},
+        distribution={"mpi": {"enabled": True}},
     )

     with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):

tests/integ/test_local_mode.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def test_mxnet_distributed_local_mode(
         train_instance_type="local",
         sagemaker_session=sagemaker_local_session,
         framework_version=mxnet_full_version,
-        distributions={"parameter_server": {"enabled": True}},
+        distribution={"parameter_server": {"enabled": True}},
     )

     train_input = mx.sagemaker_session.upload_data(

tests/integ/test_mxnet_train.py

Lines changed: 1 addition & 1 deletion
@@ -300,7 +300,7 @@ def test_async_fit(sagemaker_session, mxnet_full_version, mxnet_full_py_version,
         train_instance_type=cpu_instance_type,
         sagemaker_session=sagemaker_session,
         framework_version=mxnet_full_version,
-        distributions={"parameter_server": {"enabled": True}},
+        distribution={"parameter_server": {"enabled": True}},
     )

     train_input = mx.sagemaker_session.upload_data(

tests/integ/test_tf.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ def test_mnist_distributed(
         sagemaker_session=sagemaker_session,
         framework_version=tf_training_latest_version,
         py_version=tf_training_latest_py_version,
-        distributions=PARAMETER_SERVER_DISTRIBUTION,
+        distribution=PARAMETER_SERVER_DISTRIBUTION,
     )
     inputs = estimator.sagemaker_session.upload_data(
         path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/distributed_mnist"
