Skip to content

Commit 868f81b

Browse files
authored
Support MXNet 1.3 with its training script format changes (#446)
1 parent bc45bbd commit 868f81b

File tree

17 files changed

+342
-156
lines changed

17 files changed

+342
-156
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
CHANGELOG
33
=========
44

5-
========
65
1.13.0
7-
========
6+
======
87

98
* feature: Estimator: add input mode to training channels
109
* feature: Estimator: add model_uri and model_channel_name parameters
1110
* enhancement: Local Mode: support output_path. Can be either file:// or s3://
1211
* enhancement: Added image uris for SageMaker built-in algorithms for SIN/LHR/BOM/SFO/YUL
12+
* feature: Estimators: add support for MXNet 1.3.0, which introduces a new training script format
13+
* feature: Documentation: add explanation for the new training script format used with MXNet
14+
* feature: Estimators: add ``distributions`` for customizing distributed training with the new training script format
1315

1416
1.12.0
1517
======

src/sagemaker/chainer/estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_
8181
custom-image:latest.
8282
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
8383
"""
84+
if framework_version is None:
85+
logger.warning(empty_framework_version_warning(CHAINER_VERSION))
86+
self.framework_version = framework_version or CHAINER_VERSION
87+
8488
super(Chainer, self).__init__(entry_point, source_dir, hyperparameters,
8589
image_name=image_name, **kwargs)
8690
self.py_version = py_version
@@ -89,10 +93,6 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_
8993
self.process_slots_per_host = process_slots_per_host
9094
self.additional_mpi_options = additional_mpi_options
9195

92-
if framework_version is None:
93-
logger.warning(empty_framework_version_warning(CHAINER_VERSION))
94-
self.framework_version = framework_version or CHAINER_VERSION
95-
9696
def hyperparameters(self):
9797
"""Return hyperparameters used by your custom Chainer code during training."""
9898
hyperparameters = super(Chainer, self).hyperparameters()

src/sagemaker/estimator.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -627,8 +627,12 @@ class Framework(EstimatorBase):
627627
such as training/deployment images and predictor instances.
628628
"""
629629

630+
_DISTRIBUTION_SUPPORTED_FRAMEWORKS = ('mxnet',)
631+
LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
632+
630633
def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
631-
container_log_level=logging.INFO, code_location=None, image_name=None, **kwargs):
634+
container_log_level=logging.INFO, code_location=None, image_name=None,
635+
distributions=None, **kwargs):
632636
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
633637
634638
Args:
@@ -650,6 +654,8 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
650654
image_name (str): An alternate image name to use instead of the official SageMaker image
651655
for the framework. This is useful to run one of the SageMaker-supported frameworks
652656
with an image containing custom dependencies.
657+
distributions (dict): A dictionary with information on how to run distributed training, e.g. ``{'parameter_server': {'enabled': True}}``
658+
(default: None).
653659
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
654660
"""
655661
super(Framework, self).__init__(**kwargs)
@@ -660,10 +666,27 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
660666
DeprecationWarning)
661667
self.enable_cloudwatch_metrics = False
662668
self.container_log_level = container_log_level
663-
self._hyperparameters = hyperparameters or {}
664669
self.code_location = code_location
665670
self.image_name = image_name
666671

672+
self._hyperparameters = hyperparameters or {}
673+
self._configure_distributions(distributions)
674+
675+
def _configure_distributions(self, distributions):
676+
if distributions is None:
677+
return
678+
679+
if self.__framework_name__ not in self._DISTRIBUTION_SUPPORTED_FRAMEWORKS:
680+
raise ValueError('This framework does not support the distributions option.')
681+
682+
if self.framework_version.split('.') < self._LOWEST_SCRIPT_MODE_VERSION:
683+
raise ValueError('The distributions option is valid for only versions {} and higher'
684+
.format('.'.join(self._LOWEST_SCRIPT_MODE_VERSION)))
685+
686+
if 'parameter_server' in distributions:
687+
enabled = distributions['parameter_server'].get('enabled', False)
688+
self._hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
689+
667690
def _prepare_for_training(self, job_name=None):
668691
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
669692

0 commit comments

Comments
 (0)