-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Support MXNet 1.3 with its training script format changes #446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
b7db5a0
fe2e9bf
2008468
f34b128
ca21338
ffd20f6
51fe929
9e9a3be
fe7bab3
1fddc2e
072f19f
7175165
623c722
3e2ab64
82257f2
eaac852
faec731
9d69562
e018f42
01dcb9d
1ff9836
6a44d88
c9ef5dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,10 +27,13 @@ | |
class MXNet(Framework): | ||
"""Handle end-to-end training and deployment of custom MXNet code.""" | ||
|
||
__framework_name__ = "mxnet" | ||
__framework_name__ = 'mxnet' | ||
|
||
LOWEST_SCRIPT_MODE_VERSION = ['1', '3'] | ||
LAUNCH_PS_ENV_NAME = 'sagemaker_mxnet_launch_parameter_server' | ||
|
||
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2', | ||
framework_version=None, image_name=None, **kwargs): | ||
framework_version=None, image_name=None, launch_parameter_server=False, **kwargs): | ||
""" | ||
This ``Estimator`` executes an MXNet script in a managed MXNet execution environment, within a SageMaker | ||
Training Job. The managed MXNet environment is an Amazon-built Docker container that executes functions | ||
|
@@ -64,15 +67,28 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio | |
Examples: | ||
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0 | ||
custom-image:latest. | ||
launch_parameter_server (bool): Whether or not to launch the default parameter server | ||
implementation for use with distributed training (default: False). Valid for only | ||
versions 1.3 and higher of MXNet. | ||
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor. | ||
""" | ||
if framework_version is None: | ||
logger.warning(empty_framework_version_warning(MXNET_VERSION)) | ||
self.framework_version = framework_version or MXNET_VERSION | ||
|
||
if self._script_mode_version(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we still launch the parameter server with single host training? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, one can still use the kvstore needed even with only one host |
||
hyperparameters = hyperparameters or {} | ||
hyperparameters[self.LAUNCH_PS_ENV_NAME] = launch_parameter_server | ||
else: | ||
if launch_parameter_server: | ||
raise ValueError('launch_parameter_server is used for only versions 1.3 and higher') | ||
|
||
super(MXNet, self).__init__(entry_point, source_dir, hyperparameters, | ||
image_name=image_name, **kwargs) | ||
self.py_version = py_version | ||
|
||
if framework_version is None: | ||
logger.warning(empty_framework_version_warning(MXNET_VERSION)) | ||
self.framework_version = framework_version or MXNET_VERSION | ||
def _script_mode_version(self): | ||
return self.framework_version.split('.') >= self.LOWEST_SCRIPT_MODE_VERSION | ||
laurenyu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT): | ||
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we bump it to 2.x since this is a breaking change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
well, this PR doesn't technically have breaking changes because we're not bumping the default version of MXNet. I was going to wait until the PR that makes
framework_version
required.