Skip to content

Commit 235e5c5

Browse files
authored
Allow Framework Estimators to use custom image (#223)
* Allow Framework Estimators to use custom image All Estimators can now pass an image_name argument to the constructor to use that image instead of the default sagemaker ones.
1 parent b458d3d commit 235e5c5

File tree

14 files changed

+346
-78
lines changed

14 files changed

+346
-78
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CHANGELOG
88
* enhancement: Let Framework models reuse code uploaded by Framework estimators
99
* enhancement: Unify generation of model uploaded code location
1010
* feature: Change minimum required scipy from 1.0.0 to 0.19.0
11+
* feature: Allow all Framework Estimators to use a custom docker image.
1112

1213
1.5.0
1314
=====

src/sagemaker/chainer/README.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,12 @@ The following are optional arguments. When you create a ``Chainer`` object, you
175175
- ``job_name`` Name to assign for the training job that the fit()
176176
method launches. If not specified, the estimator generates a default
177177
job name, based on the training image name and current timestamp
178+
- ``image_name`` An alternative docker image to use for training and
179+
serving. If specified, the estimator will use this image for training and
180+
hosting, instead of selecting the appropriate SageMaker official image based on
181+
framework_version and py_version. Refer to: `SageMaker Chainer Docker Containers
182+
<#sagemaker-chainer-docker-containers>`_ for details on what the Official images support
183+
and where to find the source code to build your custom image.
178184

179185

180186
Distributed Chainer Training
@@ -657,5 +663,8 @@ Currently supported versions are listed in the above table. You can also set fra
657663
minor version, which will cause your training script to be run on the latest supported patch version of that minor
658664
version.
659665

666+
Alternatively, you can build your own image by following the instructions in the SageMaker Chainer containers
667+
repository, and passing ``image_name`` to the Chainer Estimator constructor.
668+
660669
You can visit the SageMaker Chainer containers repository here: https://github.com/aws/sagemaker-chainer-containers/
661670

src/sagemaker/chainer/estimator.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
from sagemaker.estimator import Framework
16-
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
16+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
1717
from sagemaker.chainer.defaults import CHAINER_VERSION
1818
from sagemaker.chainer.model import ChainerModel
1919

@@ -31,7 +31,7 @@ class Chainer(Framework):
3131

3232
def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_per_host=None,
3333
additional_mpi_options=None, source_dir=None, hyperparameters=None, py_version='py3',
34-
framework_version=CHAINER_VERSION, **kwargs):
34+
framework_version=CHAINER_VERSION, image_name=None, **kwargs):
3535
"""
3636
This ``Estimator`` executes an Chainer script in a managed Chainer execution environment, within a SageMaker
3737
Training Job. The managed Chainer environment is an Amazon-built Docker container that executes functions
@@ -67,9 +67,16 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_
6767
One of 'py2' or 'py3'.
6868
framework_version (str): Chainer version you want to use for executing your model training code.
6969
List of supported versions https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators
70+
image_name (str): If specified, the estimator will use this image for training and hosting, instead of
71+
selecting the appropriate SageMaker official image based on framework_version and py_version. It can
72+
be an ECR url or dockerhub image and tag.
73+
Examples:
74+
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
75+
custom-image:latest.
7076
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
7177
"""
72-
super(Chainer, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
78+
super(Chainer, self).__init__(entry_point, source_dir, hyperparameters,
79+
image_name=image_name, **kwargs)
7380
self.py_version = py_version
7481
self.framework_version = framework_version
7582
self.use_mpi = use_mpi
@@ -91,20 +98,6 @@ def hyperparameters(self):
9198
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
9299
return hyperparameters
93100

94-
def train_image(self):
95-
"""Return the Docker image to use for training.
96-
97-
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
98-
find the image to use for model training.
99-
100-
Returns:
101-
str: The URI of the Docker image.
102-
"""
103-
104-
return create_image_uri(self.sagemaker_session.boto_session.region_name, self.__framework_name__,
105-
self.train_instance_type, framework_version=self.framework_version,
106-
py_version=self.py_version)
107-
108101
def create_model(self, model_server_workers=None):
109102
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an ``Endpoint``.
110103
@@ -120,7 +113,8 @@ def create_model(self, model_server_workers=None):
120113
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
121114
container_log_level=self.container_log_level, code_location=self.code_location,
122115
py_version=self.py_version, framework_version=self.framework_version,
123-
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
116+
model_server_workers=model_server_workers, image=self.image_name,
117+
sagemaker_session=self.sagemaker_session)
124118

125119
@classmethod
126120
def _prepare_init_params_from_job_description(cls, job_details):
@@ -142,7 +136,14 @@ def _prepare_init_params_from_job_description(cls, job_details):
142136
if value:
143137
init_params[argument[len('sagemaker_'):]] = value
144138

145-
framework, py_version, tag = framework_name_from_image(init_params.pop('image'))
139+
image_name = init_params.pop('image')
140+
framework, py_version, tag = framework_name_from_image(image_name)
141+
142+
if not framework:
143+
# If we were unable to parse the framework name from the image it is not one of our
144+
# officially supported images, in this case just add the image to the init params.
145+
init_params['image_name'] = image_name
146+
return init_params
146147

147148
init_params['py_version'] = py_version
148149
init_params['framework_version'] = framework_version_from_tag(tag)

src/sagemaker/estimator.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from six import with_metaclass
2121

2222
from sagemaker.analytics import TrainingJobAnalytics
23-
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
23+
from sagemaker.fw_utils import (create_image_uri, tar_and_upload_dir, parse_s3_url, UploadedCode,
24+
validate_source_dir)
2425
from sagemaker.job import _Job
2526
from sagemaker.local import LocalSession
2627
from sagemaker.model import Model
@@ -493,7 +494,7 @@ class Framework(EstimatorBase):
493494
"""
494495

495496
def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
496-
container_log_level=logging.INFO, code_location=None, **kwargs):
497+
container_log_level=logging.INFO, code_location=None, image_name=None, **kwargs):
497498
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
498499
499500
Args:
@@ -513,6 +514,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
513514
code_location (str): Name of the S3 bucket where custom code is uploaded (default: None).
514515
If not specified, default bucket created by ``sagemaker.session.Session`` is used.
515516
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
517+
image_name (str): An alternate image name to use instead of the official Sagemaker image
518+
for the framework. This is useful to run one of the Sagemaker supported frameworks
519+
with an image containing custom dependencies.
516520
"""
517521
super(Framework, self).__init__(**kwargs)
518522
self.source_dir = source_dir
@@ -521,6 +525,7 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
521525
self.container_log_level = container_log_level
522526
self._hyperparameters = hyperparameters or {}
523527
self.code_location = code_location
528+
self.image_name = image_name
524529

525530
def _prepare_for_training(self, job_name=None):
526531
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
@@ -632,6 +637,21 @@ def _prepare_init_params_from_job_description(cls, job_details):
632637

633638
return init_params
634639

640+
def train_image(self):
641+
"""Return the Docker image to use for training.
642+
643+
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training,
644+
calls this method to find the image to use for model training.
645+
646+
Returns:
647+
str: The URI of the Docker image.
648+
"""
649+
if self.image_name:
650+
return self.image_name
651+
else:
652+
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__,
653+
self.train_instance_type, self.framework_version, py_version=self.py_version)
654+
635655
@classmethod
636656
def attach(cls, training_job_name, sagemaker_session=None):
637657
"""Attach to an existing training job.

src/sagemaker/mxnet/README.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ The following are optional arguments. When you create an ``MXNet`` object, you c
153153
- ``job_name`` Name to assign for the training job that the fit()
154154
method launches. If not specified, the estimator generates a default
155155
job name, based on the training image name and current timestamp
156+
- ``image_name`` An alternative docker image to use for training and
157+
serving. If specified, the estimator will use this image for training and
158+
hosting, instead of selecting the appropriate SageMaker official image based on
159+
framework_version and py_version. Refer to: `SageMaker MXNet Docker Containers
160+
<#sagemaker-mxnet-docker-containers>`_ for details on what the Official images support
161+
and where to find the source code to build your custom image.
156162

157163
Calling fit
158164
^^^^^^^^^^^
@@ -595,5 +601,6 @@ The Docker images have the following dependencies installed:
595601
The Docker images extend Ubuntu 16.04.
596602

597603
You can select version of MXNet by passing a ``framework_version`` keyword arg to the MXNet Estimator constructor. Currently supported versions are listed in the above table. You can also set ``framework_version`` to only specify major and minor version, e.g ``1.1``, which will cause your training script to be run on the latest supported patch version of that minor version, which in this example would be 1.1.0.
604+
Alternatively, you can build your own image by following the instructions in the SageMaker MXNet containers repository, and passing ``image_name`` to the MXNet Estimator constructor.
598605

599606
You can visit the SageMaker MXNet containers repository here: https://github.com/aws/sagemaker-mxnet-containers/

src/sagemaker/mxnet/estimator.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
from sagemaker.estimator import Framework
16-
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
16+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
1717
from sagemaker.mxnet.defaults import MXNET_VERSION
1818
from sagemaker.mxnet.model import MXNetModel
1919

@@ -24,7 +24,7 @@ class MXNet(Framework):
2424
__framework_name__ = "mxnet"
2525

2626
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2',
27-
framework_version=MXNET_VERSION, **kwargs):
27+
framework_version=MXNET_VERSION, image_name=None, **kwargs):
2828
"""
2929
This ``Estimator`` executes an MXNet script in a managed MXNet execution environment, within a SageMaker
3030
Training Job. The managed MXNet environment is an Amazon-built Docker container that executes functions
@@ -52,25 +52,19 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
5252
One of 'py2' or 'py3'.
5353
framework_version (str): MXNet version you want to use for executing your model training code.
5454
List of supported versions https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators
55+
image_name (str): If specified, the estimator will use this image for training and hosting, instead of
56+
selecting the appropriate SageMaker official image based on framework_version and py_version. It can
57+
be an ECR url or dockerhub image and tag.
58+
Examples:
59+
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
60+
custom-image:latest.
5561
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
5662
"""
57-
super(MXNet, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
63+
super(MXNet, self).__init__(entry_point, source_dir, hyperparameters,
64+
image_name=image_name, **kwargs)
5865
self.py_version = py_version
5966
self.framework_version = framework_version
6067

61-
def train_image(self):
62-
"""Return the Docker image to use for training.
63-
64-
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
65-
find the image to use for model training.
66-
67-
Returns:
68-
str: The URI of the Docker image.
69-
"""
70-
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__,
71-
self.train_instance_type, framework_version=self.framework_version,
72-
py_version=self.py_version)
73-
7468
def create_model(self, model_server_workers=None):
7569
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``.
7670
@@ -85,7 +79,7 @@ def create_model(self, model_server_workers=None):
8579
return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
8680
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
8781
container_log_level=self.container_log_level, code_location=self.code_location,
88-
py_version=self.py_version, framework_version=self.framework_version,
82+
py_version=self.py_version, framework_version=self.framework_version, image=self.image_name,
8983
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
9084

9185
@classmethod
@@ -100,7 +94,14 @@ def _prepare_init_params_from_job_description(cls, job_details):
10094
10195
"""
10296
init_params = super(MXNet, cls)._prepare_init_params_from_job_description(job_details)
103-
framework, py_version, tag = framework_name_from_image(init_params.pop('image'))
97+
image_name = init_params.pop('image')
98+
framework, py_version, tag = framework_name_from_image(image_name)
99+
100+
if not framework:
101+
# If we were unable to parse the framework name from the image it is not one of our
102+
# officially supported images, in this case just add the image to the init params.
103+
init_params['image_name'] = image_name
104+
return init_params
104105

105106
init_params['py_version'] = py_version
106107

src/sagemaker/pytorch/README.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,12 @@ The following are optional arguments. When you create a ``PyTorch`` object, you
204204
- ``job_name`` Name to assign for the training job that the ``fit```
205205
method launches. If not specified, the estimator generates a default
206206
job name, based on the training image name and current timestamp
207-
207+
- ``image_name`` An alternative docker image to use for training and
208+
serving. If specified, the estimator will use this image for training and
209+
hosting, instead of selecting the appropriate SageMaker official image based on
210+
framework_version and py_version. Refer to: `SageMaker PyTorch Docker Containers
211+
<#sagemaker-pytorch-docker-containers>`_ for details on what the Official images support
212+
and where to find the source code to build your custom image.
208213

209214
Calling fit
210215
~~~~~~~~~~~
@@ -705,4 +710,7 @@ Currently supported versions are listed in the above table. You can also set ``f
705710
minor version, which will cause your training script to be run on the latest supported patch version of that minor
706711
version.
707712

713+
Alternatively, you can build your own image by following the instructions in the SageMaker Chainer containers
714+
repository, and passing ``image_name`` to the Chainer Estimator constructor.
715+
708716
You can visit `the SageMaker PyTorch containers repository <https://github.com/aws/sagemaker-pytorch-containers>`_.

0 commit comments

Comments
 (0)