Support MXNet 1.3 with its training script format changes #446

Merged
23 commits merged on Nov 5, 2018
6 changes: 4 additions & 2 deletions CHANGELOG.rst
@@ -2,14 +2,16 @@
CHANGELOG
=========

1.13.0
======

* feature: Estimator: add input mode to training channels
* feature: Estimator: add model_uri and model_channel_name parameters
* enhancement: Local Mode: support output_path. Can be either file:// or s3://
* enhancement: Added image uris for SageMaker built-in algorithms for SIN/LHR/BOM/SFO/YUL
* feature: Estimators: add support for MXNet 1.3.0, which introduces a new training script format
* feature: Documentation: add explanation for the new training script format used with MXNet
* feature: Estimators: add ``distributions`` for customizing distributed training with the new training script format

1.12.0
======
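To make the new ``distributions`` entry listed above concrete, the sketch below shows how the option might be passed to an MXNet estimator. It is not part of the diff: the entry point train.py, the IAM role SageMakerRole, the instance types, and the S3 paths are placeholders, and the option is accepted only with framework_version 1.3.0 or higher (enforced by the Framework changes further down).

from sagemaker.mxnet import MXNet

# Placeholder script, role, instance types, and S3 locations -- substitute your own.
estimator = MXNet(entry_point='train.py',
                  role='SageMakerRole',
                  train_instance_count=2,
                  train_instance_type='ml.p3.2xlarge',
                  framework_version='1.3.0',
                  py_version='py3',
                  # New in 1.13.0: enable parameter-server-based distributed
                  # training with the new training script format.
                  distributions={'parameter_server': {'enabled': True}})

estimator.fit({'train': 's3://my-bucket/my-training-data'})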
8 changes: 4 additions & 4 deletions src/sagemaker/chainer/estimator.py
@@ -81,6 +81,10 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_
custom-image:latest.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
"""
if framework_version is None:
logger.warning(empty_framework_version_warning(CHAINER_VERSION))
self.framework_version = framework_version or CHAINER_VERSION

super(Chainer, self).__init__(entry_point, source_dir, hyperparameters,
image_name=image_name, **kwargs)
self.py_version = py_version
@@ -89,10 +93,6 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_
self.process_slots_per_host = process_slots_per_host
self.additional_mpi_options = additional_mpi_options

if framework_version is None:
logger.warning(empty_framework_version_warning(CHAINER_VERSION))
self.framework_version = framework_version or CHAINER_VERSION

def hyperparameters(self):
"""Return hyperparameters used by your custom Chainer code during training."""
hyperparameters = super(Chainer, self).hyperparameters()
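The Chainer change above does not alter behavior; it only moves the framework_version fallback ahead of the call to super().__init__, presumably so the attribute is already set when the base Framework constructor runs (the new _configure_distributions hook in the next file reads self.framework_version). A minimal, self-contained sketch of that ordering pattern, with hypothetical class names and a stand-in default version:

import logging

logger = logging.getLogger(__name__)

DEFAULT_VERSION = '4.1.0'  # stand-in for CHAINER_VERSION; the real value may differ


class BaseFramework(object):
    # Stand-in for sagemaker.estimator.Framework: the parent constructor may
    # inspect self.framework_version (e.g. while handling ``distributions``).
    def __init__(self, entry_point, **kwargs):
        assert hasattr(self, 'framework_version'), \
            'framework_version must be set before calling super().__init__'
        self.entry_point = entry_point


class MyEstimator(BaseFramework):
    def __init__(self, entry_point, framework_version=None, **kwargs):
        # Resolve the version before the parent constructor runs, mirroring
        # the reordering in the Chainer diff above.
        if framework_version is None:
            logger.warning('framework_version was not specified; defaulting to %s',
                           DEFAULT_VERSION)
        self.framework_version = framework_version or DEFAULT_VERSION
        super(MyEstimator, self).__init__(entry_point, **kwargs)


MyEstimator('train.py')  # warns and falls back to DEFAULT_VERSION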
27 changes: 25 additions & 2 deletions src/sagemaker/estimator.py
@@ -627,8 +627,12 @@ class Framework(EstimatorBase):
such as training/deployment images and predictor instances.
"""

_DISTRIBUTION_SUPPORTED_FRAMEWORKS = ('mxnet',)
LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'

def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
container_log_level=logging.INFO, code_location=None, image_name=None, **kwargs):
container_log_level=logging.INFO, code_location=None, image_name=None,
distributions=None, **kwargs):
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``

Args:
@@ -650,6 +654,8 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
image_name (str): An alternate image name to use instead of the official Sagemaker image
for the framework. This is useful to run one of the Sagemaker supported frameworks
with an image containing custom dependencies.
distributions (dict): A dictionary with information on how to run distributed training
(default: None).
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
"""
super(Framework, self).__init__(**kwargs)
@@ -660,10 +666,27 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
DeprecationWarning)
self.enable_cloudwatch_metrics = False
self.container_log_level = container_log_level
self._hyperparameters = hyperparameters or {}
self.code_location = code_location
self.image_name = image_name

self._hyperparameters = hyperparameters or {}
self._configure_distributions(distributions)

def _configure_distributions(self, distributions):
if distributions is None:
return

if self.__framework_name__ not in self._DISTRIBUTION_SUPPORTED_FRAMEWORKS:
raise ValueError('This framework does not support the distributions option.')

if self.framework_version.split('.') < self._LOWEST_SCRIPT_MODE_VERSION:
raise ValueError('The distributions option is valid for only versions {} and higher'
.format('.'.join(self._LOWEST_SCRIPT_MODE_VERSION)))

if 'parameter_server' in distributions:
enabled = distributions['parameter_server'].get('enabled', False)
self._hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled

def _prepare_for_training(self, job_name=None):
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.

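For reference, here is a standalone re-statement of the new _configure_distributions logic shown above, so its effect on the hyperparameters can be run and inspected outside the class. The minimum version ['1', '3'] and the free-function form are illustrative; in the SDK the method lives on Framework and reads self.__framework_name__, self.framework_version, and self._LOWEST_SCRIPT_MODE_VERSION.

LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
DISTRIBUTION_SUPPORTED_FRAMEWORKS = ('mxnet',)
LOWEST_SCRIPT_MODE_VERSION = ['1', '3']  # illustrative: MXNet 1.3 introduced the new script format


def configure_distributions(framework_name, framework_version, distributions, hyperparameters):
    # Mirrors Framework._configure_distributions from the diff above.
    if distributions is None:
        return hyperparameters

    if framework_name not in DISTRIBUTION_SUPPORTED_FRAMEWORKS:
        raise ValueError('This framework does not support the distributions option.')

    # Note: like the diff, this compares lists of version-string components.
    if framework_version.split('.') < LOWEST_SCRIPT_MODE_VERSION:
        raise ValueError('The distributions option is valid for only versions {} and higher'
                         .format('.'.join(LOWEST_SCRIPT_MODE_VERSION)))

    if 'parameter_server' in distributions:
        hyperparameters[LAUNCH_PS_ENV_NAME] = distributions['parameter_server'].get('enabled', False)
    return hyperparameters


# The new option collapses into a single framework hyperparameter:
print(configure_distributions('mxnet', '1.3.0',
                              {'parameter_server': {'enabled': True}}, {}))
# {'sagemaker_parameter_server_enabled': True}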