Skip to content

Commit 6965ca1

Browse files
author
Ignacio Quintero
committed
Allow Framework Estimators to use custom image
Chainer, Tensorflow and MXNet estimators can now pass an image_name argument to the constructor to use that image instead of the default sagemaker ones.
1 parent 3fb5516 commit 6965ca1

File tree

8 files changed

+242
-59
lines changed

8 files changed

+242
-59
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ CHANGELOG
77

88
* bug-fix: Unit Tests: Improve unit test runtime
99
* bug-fix: Estimators: Fix attach for LDA
10+
* feature: Allow Chainer, Tensorflow and MXNet estimators to use a custom docker image.
1011

1112
1.4.1
1213
=====

src/sagemaker/chainer/estimator.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
from sagemaker.estimator import Framework
16-
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
16+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
1717
from sagemaker.chainer.defaults import CHAINER_VERSION
1818
from sagemaker.chainer.model import ChainerModel
1919

@@ -31,7 +31,7 @@ class Chainer(Framework):
3131

3232
def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_per_host=None,
3333
additional_mpi_options=None, source_dir=None, hyperparameters=None, py_version='py3',
34-
framework_version=CHAINER_VERSION, **kwargs):
34+
framework_version=CHAINER_VERSION, image_name=None, **kwargs):
3535
"""
3636
This ``Estimator`` executes an Chainer script in a managed Chainer execution environment, within a SageMaker
3737
Training Job. The managed Chainer environment is an Amazon-built Docker container that executes functions
@@ -67,9 +67,12 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_
6767
One of 'py2' or 'py3'.
6868
framework_version (str): Chainer version you want to use for executing your model training code.
6969
List of supported versions https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators
70+
image_name (str): The container image to use for training. This will override py_version and
71+
framework_version. The image is expected to be a modification of the SageMaker Chainer image.
7072
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
7173
"""
72-
super(Chainer, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
74+
super(Chainer, self).__init__(entry_point, source_dir, hyperparameters,
75+
image_name=image_name, **kwargs)
7376
self.py_version = py_version
7477
self.framework_version = framework_version
7578
self.use_mpi = use_mpi
@@ -91,20 +94,6 @@ def hyperparameters(self):
9194
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
9295
return hyperparameters
9396

94-
def train_image(self):
95-
"""Return the Docker image to use for training.
96-
97-
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
98-
find the image to use for model training.
99-
100-
Returns:
101-
str: The URI of the Docker image.
102-
"""
103-
104-
return create_image_uri(self.sagemaker_session.boto_session.region_name, self.__framework_name__,
105-
self.train_instance_type, framework_version=self.framework_version,
106-
py_version=self.py_version)
107-
10897
def create_model(self, model_server_workers=None):
10998
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an ``Endpoint``.
11099
@@ -120,7 +109,8 @@ def create_model(self, model_server_workers=None):
120109
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
121110
container_log_level=self.container_log_level, code_location=self.code_location,
122111
py_version=self.py_version, framework_version=self.framework_version,
123-
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
112+
model_server_workers=model_server_workers, image=self.image_name,
113+
sagemaker_session=self.sagemaker_session)
124114

125115
@classmethod
126116
def _prepare_init_params_from_job_description(cls, job_details):
@@ -142,7 +132,14 @@ def _prepare_init_params_from_job_description(cls, job_details):
142132
if value:
143133
init_params[argument[len('sagemaker_'):]] = value
144134

145-
framework, py_version, tag = framework_name_from_image(init_params.pop('image'))
135+
image_name = init_params.pop('image')
136+
framework, py_version, tag = framework_name_from_image(image_name)
137+
138+
if not framework:
139+
# If we were unable to parse the framework name from the image it is not one of our
140+
# officially supported images, in this case just add the image to the init params.
141+
init_params['image_name'] = image_name
142+
return init_params
146143

147144
init_params['py_version'] = py_version
148145
init_params['framework_version'] = framework_version_from_tag(tag)

src/sagemaker/estimator.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from six import with_metaclass
2121

2222
from sagemaker.analytics import TrainingJobAnalytics
23-
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
23+
from sagemaker.fw_utils import (create_image_uri, tar_and_upload_dir, parse_s3_url, UploadedCode,
24+
validate_source_dir)
2425
from sagemaker.job import _Job
2526
from sagemaker.local import LocalSession
2627
from sagemaker.model import Model
@@ -226,6 +227,7 @@ def attach(cls, training_job_name, sagemaker_session=None):
226227
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
227228
init_params = cls._prepare_init_params_from_job_description(job_details)
228229

230+
print(init_params)
229231
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
230232
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
231233
training_job_name=init_params['base_job_name'])
@@ -493,7 +495,7 @@ class Framework(EstimatorBase):
493495
"""
494496

495497
def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
496-
container_log_level=logging.INFO, code_location=None, **kwargs):
498+
container_log_level=logging.INFO, code_location=None, image_name=None, **kwargs):
497499
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
498500
499501
Args:
@@ -513,6 +515,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
513515
code_location (str): Name of the S3 bucket where custom code is uploaded (default: None).
514516
If not specified, default bucket created by ``sagemaker.session.Session`` is used.
515517
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
518+
image_name (str): An alternate image name to use instead of the official Sagemaker image
519+
for the framework. This is useful to run one of the Sagemaker supported frameworks
520+
with an image containing custom dependencies.
516521
"""
517522
super(Framework, self).__init__(**kwargs)
518523
self.source_dir = source_dir
@@ -521,6 +526,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
521526
self.container_log_level = container_log_level
522527
self._hyperparameters = hyperparameters or {}
523528
self.code_location = code_location
529+
self.image_name = image_name
530+
print(self.image_name)
531+
print(kwargs)
524532

525533
def _prepare_for_training(self, job_name=None):
526534
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
@@ -624,6 +632,21 @@ def _prepare_init_params_from_job_description(cls, job_details):
624632

625633
return init_params
626634

635+
def train_image(self):
636+
"""Return the Docker image to use for training.
637+
638+
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training,
639+
calls this method to find the image to use for model training.
640+
641+
Returns:
642+
str: The URI of the Docker image.
643+
"""
644+
if self.image_name:
645+
return self.image_name
646+
else:
647+
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__,
648+
self.train_instance_type, self.framework_version, py_version=self.py_version)
649+
627650
@classmethod
628651
def attach(cls, training_job_name, sagemaker_session=None):
629652
"""Attach to an existing training job.

src/sagemaker/mxnet/estimator.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
from sagemaker.estimator import Framework
16-
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
16+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
1717
from sagemaker.mxnet.defaults import MXNET_VERSION
1818
from sagemaker.mxnet.model import MXNetModel
1919

@@ -24,7 +24,7 @@ class MXNet(Framework):
2424
__framework_name__ = "mxnet"
2525

2626
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2',
27-
framework_version=MXNET_VERSION, **kwargs):
27+
framework_version=MXNET_VERSION, image_name=None, **kwargs):
2828
"""
2929
This ``Estimator`` executes an MXNet script in a managed MXNet execution environment, within a SageMaker
3030
Training Job. The managed MXNet environment is an Amazon-built Docker container that executes functions
@@ -52,25 +52,15 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
5252
One of 'py2' or 'py3'.
5353
framework_version (str): MXNet version you want to use for executing your model training code.
5454
List of supported versions https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators
55+
image_name (str): The container image to use for training. This will override py_version and
56+
framework_version. The image is expected to be a modification of the SageMaker MXNet image.
5557
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
5658
"""
57-
super(MXNet, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
59+
super(MXNet, self).__init__(entry_point, source_dir, hyperparameters,
60+
image_name=image_name, **kwargs)
5861
self.py_version = py_version
5962
self.framework_version = framework_version
6063

61-
def train_image(self):
62-
"""Return the Docker image to use for training.
63-
64-
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
65-
find the image to use for model training.
66-
67-
Returns:
68-
str: The URI of the Docker image.
69-
"""
70-
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__,
71-
self.train_instance_type, framework_version=self.framework_version,
72-
py_version=self.py_version)
73-
7464
def create_model(self, model_server_workers=None):
7565
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``.
7666
@@ -82,11 +72,16 @@ def create_model(self, model_server_workers=None):
8272
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
8373
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
8474
"""
75+
kwargs = {}
76+
# pass our custom image if there is one.
77+
if self.image_name:
78+
kwargs['image'] = self.image_name
79+
8580
return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
8681
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
8782
container_log_level=self.container_log_level, code_location=self.code_location,
8883
py_version=self.py_version, framework_version=self.framework_version,
89-
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
84+
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session, **kwargs)
9085

9186
@classmethod
9287
def _prepare_init_params_from_job_description(cls, job_details):
@@ -100,7 +95,14 @@ def _prepare_init_params_from_job_description(cls, job_details):
10095
10196
"""
10297
init_params = super(MXNet, cls)._prepare_init_params_from_job_description(job_details)
103-
framework, py_version, tag = framework_name_from_image(init_params.pop('image'))
98+
image_name = init_params.pop('image')
99+
framework, py_version, tag = framework_name_from_image(image_name)
100+
101+
if not framework:
102+
# If we were unable to parse the framework name from the image it is not one of our
103+
# officially supported images, in this case just add the image to the init params.
104+
init_params['image_name'] = image_name
105+
return init_params
104106

105107
init_params['py_version'] = py_version
106108

src/sagemaker/tensorflow/estimator.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import threading
2222

2323
from sagemaker.estimator import Framework
24-
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
24+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
2525
from sagemaker.utils import get_config_value
2626

2727
from sagemaker.tensorflow.defaults import TF_VERSION
@@ -157,7 +157,7 @@ class TensorFlow(Framework):
157157
__framework_name__ = 'tensorflow'
158158

159159
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
160-
framework_version=TF_VERSION, requirements_file='', **kwargs):
160+
framework_version=TF_VERSION, requirements_file='', image_name=None, **kwargs):
161161
"""Initialize an ``TensorFlow`` estimator.
162162
Args:
163163
training_steps (int): Perform this many steps of training. `None`, the default means train forever.
@@ -171,9 +171,11 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
171171
requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
172172
relative to ``source_dir``. Details on the format can be found in the
173173
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
174+
image_name (str): The container image to use for training. This will override py_version and
175+
framework_version. The image is expected to be a modification of the SageMaker TensorFlow image.
174176
**kwargs: Additional kwargs passed to the Framework constructor.
175177
"""
176-
super(TensorFlow, self).__init__(**kwargs)
178+
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
177179
self.checkpoint_path = checkpoint_path
178180
self.py_version = py_version
179181
self.framework_version = framework_version
@@ -257,7 +259,14 @@ def _prepare_init_params_from_job_description(cls, job_details):
257259
if value is not None:
258260
init_params[argument] = value
259261

260-
framework, py_version, tag = framework_name_from_image(init_params.pop('image'))
262+
image_name = init_params.pop('image')
263+
framework, py_version, tag = framework_name_from_image(image_name)
264+
if not framework:
265+
# If we were unable to parse the framework name from the image it is not one of our
266+
# officially supported images, in this case just add the image to the init params.
267+
init_params['image_name'] = image_name
268+
return init_params
269+
261270
init_params['py_version'] = py_version
262271

263272
# We switched image tagging scheme from regular image version (e.g. '1.0') to more expressive
@@ -272,18 +281,6 @@ def _prepare_init_params_from_job_description(cls, job_details):
272281

273282
return init_params
274283

275-
def train_image(self):
276-
"""Return the Docker image to use for training.
277-
278-
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
279-
find the image to use for model training.
280-
281-
Returns:
282-
str: The URI of the Docker image.
283-
"""
284-
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__,
285-
self.train_instance_type, self.framework_version, py_version=self.py_version)
286-
287284
def create_model(self, model_server_workers=None):
288285
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.
289286
@@ -296,9 +293,9 @@ def create_model(self, model_server_workers=None):
296293
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
297294
"""
298295
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}
299-
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
300-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env,
301-
name=self._current_job_name, container_log_level=self.container_log_level,
296+
return TensorFlowModel(self.model_data, self.role, self.entry_point, image=self.image_name,
297+
source_dir=self.source_dir, enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
298+
env=env, name=self._current_job_name, container_log_level=self.container_log_level,
302299
code_location=self.code_location, py_version=self.py_version,
303300
framework_version=self.framework_version, model_server_workers=model_server_workers,
304301
sagemaker_session=self.sagemaker_session)

0 commit comments

Comments
 (0)