Skip to content

Commit cdebccb

Browse files
author
Ignacio Quintero
committed
Revert the distributions change
1 parent beb4ddb commit cdebccb

File tree

4 files changed

+13
-13
lines changed

4 files changed

+13
-13
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ CHANGELOG
55
1.14.2dev
66
=========
77

8-
* breaking-change: rename MXNet argument from distributions -> distribution
8+
* build: added pylint
99

1010
1.14.1
1111
======

src/sagemaker/mxnet/estimator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class MXNet(Framework):
3333
LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
3434

3535
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2',
36-
framework_version=None, image_name=None, distribution=None, **kwargs):
36+
framework_version=None, image_name=None, distributions=None, **kwargs):
3737
"""
3838
This ``Estimator`` executes an MXNet script in a managed MXNet execution environment, within a SageMaker
3939
Training Job. The managed MXNet environment is an Amazon-built Docker container that executes functions
@@ -78,18 +78,18 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
7878
super(MXNet, self).__init__(entry_point, source_dir, hyperparameters,
7979
image_name=image_name, **kwargs)
8080
self.py_version = py_version
81-
self._configure_distribution(distribution)
81+
self._configure_distribution(distributions)
8282

83-
def _configure_distribution(self, distribution):
84-
if distribution is None:
83+
def _configure_distribution(self, distributions):
84+
if distributions is None:
8585
return
8686

8787
if self.framework_version.split('.') < self._LOWEST_SCRIPT_MODE_VERSION:
88-
raise ValueError('The distribution option is valid for only versions {} and higher'
88+
raise ValueError('The distributions option is valid for only versions {} and higher'
8989
.format('.'.join(self._LOWEST_SCRIPT_MODE_VERSION)))
9090

91-
if 'parameter_server' in distribution:
92-
enabled = distribution['parameter_server'].get('enabled', False)
91+
if 'parameter_server' in distributions:
92+
enabled = distributions['parameter_server'].get('enabled', False)
9393
self._hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
9494

9595
def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):

tests/integ/test_mxnet_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_async_fit(sagemaker_session, mxnet_full_version):
8181
mx = MXNet(entry_point=script_path, role='SageMakerRole', py_version=PYTHON_VERSION,
8282
train_instance_count=1, train_instance_type='ml.c4.xlarge',
8383
sagemaker_session=sagemaker_session, framework_version=mxnet_full_version,
84-
distribution={'parameter_server': {'enabled': True}})
84+
distributions={'parameter_server': {'enabled': True}})
8585

8686
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
8787
key_prefix='integ-test-data/mxnet_mnist/train')

tests/unit/test_mxnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -378,20 +378,20 @@ def test_attach_custom_image(sagemaker_session):
378378
def test_estimator_script_mode_launch_parameter_server(sagemaker_session):
379379
mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
380380
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
381-
distribution=LAUNCH_PS_DISTRIBUTIONS_DICT, framework_version='1.3.0')
381+
distributions=LAUNCH_PS_DISTRIBUTIONS_DICT, framework_version='1.3.0')
382382
assert mx.hyperparameters().get(MXNet.LAUNCH_PS_ENV_NAME) == 'true'
383383

384384

385385
def test_estimator_script_mode_dont_launch_parameter_server(sagemaker_session):
386386
mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
387387
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
388-
distribution={'parameter_server': {'enabled': False}}, framework_version='1.3.0')
388+
distributions={'parameter_server': {'enabled': False}}, framework_version='1.3.0')
389389
assert mx.hyperparameters().get(MXNet.LAUNCH_PS_ENV_NAME) == 'false'
390390

391391

392392
def test_estimator_wrong_version_launch_parameter_server(sagemaker_session):
393393
with pytest.raises(ValueError) as e:
394394
MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
395395
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
396-
distribution=LAUNCH_PS_DISTRIBUTIONS_DICT, framework_version='1.2.1')
397-
assert 'The distribution option is valid for only versions 1.3 and higher' in str(e)
396+
distributions=LAUNCH_PS_DISTRIBUTIONS_DICT, framework_version='1.2.1')
397+
assert 'The distributions option is valid for only versions 1.3 and higher' in str(e)

0 commit comments

Comments
 (0)