Skip to content

Allow Framework Estimators to use custom image #223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 25, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ CHANGELOG

* bug-fix: Unit Tests: Improve unit test runtime
* bug-fix: Estimators: Fix attach for LDA
* feature: Allow Chainer, Tensorflow and MXNet estimators to use a custom docker image.

1.4.1
=====
Expand Down
35 changes: 16 additions & 19 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from __future__ import absolute_import

from sagemaker.estimator import Framework
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
from sagemaker.chainer.defaults import CHAINER_VERSION
from sagemaker.chainer.model import ChainerModel

Expand All @@ -31,7 +31,7 @@ class Chainer(Framework):

def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_per_host=None,
additional_mpi_options=None, source_dir=None, hyperparameters=None, py_version='py3',
framework_version=CHAINER_VERSION, **kwargs):
framework_version=CHAINER_VERSION, image_name=None, **kwargs):
"""
This ``Estimator`` executes an Chainer script in a managed Chainer execution environment, within a SageMaker
Training Job. The managed Chainer environment is an Amazon-built Docker container that executes functions
Expand Down Expand Up @@ -67,9 +67,12 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_
One of 'py2' or 'py3'.
framework_version (str): Chainer version you want to use for executing your model training code.
List of supported versions https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators
image_name (str): The container image to use for training. This will override py_version and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of "this will override", how about something like - "if specified, the estimator will use this image for training and hosting, instead of selecting the appropriate SageMaker official image based on framework_version and py_version.

Applies to everywhere this wording is used.

framework_version. The image is expected to be a modification of the SageMaker Chainer image.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
"""
super(Chainer, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
super(Chainer, self).__init__(entry_point, source_dir, hyperparameters,
image_name=image_name, **kwargs)
self.py_version = py_version
self.framework_version = framework_version
self.use_mpi = use_mpi
Expand All @@ -91,20 +94,6 @@ def hyperparameters(self):
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
return hyperparameters

def train_image(self):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
find the image to use for model training.

Returns:
str: The URI of the Docker image.
"""

return create_image_uri(self.sagemaker_session.boto_session.region_name, self.__framework_name__,
self.train_instance_type, framework_version=self.framework_version,
py_version=self.py_version)

def create_model(self, model_server_workers=None):
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an ``Endpoint``.

Expand All @@ -120,7 +109,8 @@ def create_model(self, model_server_workers=None):
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
container_log_level=self.container_log_level, code_location=self.code_location,
py_version=self.py_version, framework_version=self.framework_version,
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
model_server_workers=model_server_workers, image=self.image_name,
sagemaker_session=self.sagemaker_session)

@classmethod
def _prepare_init_params_from_job_description(cls, job_details):
Expand All @@ -142,7 +132,14 @@ def _prepare_init_params_from_job_description(cls, job_details):
if value:
init_params[argument[len('sagemaker_'):]] = value

framework, py_version, tag = framework_name_from_image(init_params.pop('image'))
image_name = init_params.pop('image')
framework, py_version, tag = framework_name_from_image(image_name)

if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params['image_name'] = image_name
return init_params

init_params['py_version'] = py_version
init_params['framework_version'] = framework_version_from_tag(tag)
Expand Down
27 changes: 25 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from six import with_metaclass

from sagemaker.analytics import TrainingJobAnalytics
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
from sagemaker.fw_utils import (create_image_uri, tar_and_upload_dir, parse_s3_url, UploadedCode,
validate_source_dir)
from sagemaker.job import _Job
from sagemaker.local import LocalSession
from sagemaker.model import Model
Expand Down Expand Up @@ -226,6 +227,7 @@ def attach(cls, training_job_name, sagemaker_session=None):
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
init_params = cls._prepare_init_params_from_job_description(job_details)

print(init_params)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove debugging print?

estimator = cls(sagemaker_session=sagemaker_session, **init_params)
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
training_job_name=init_params['base_job_name'])
Expand Down Expand Up @@ -493,7 +495,7 @@ class Framework(EstimatorBase):
"""

def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False,
container_log_level=logging.INFO, code_location=None, **kwargs):
container_log_level=logging.INFO, code_location=None, image_name=None, **kwargs):
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``

Args:
Expand All @@ -513,6 +515,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
code_location (str): Name of the S3 bucket where custom code is uploaded (default: None).
If not specified, default bucket created by ``sagemaker.session.Session`` is used.
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
image_name (str): An alternate image name to use instead of the official Sagemaker image
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd go with "image" instead of "image_name" here since I think the "_name" part of it doesn't add any descriptiveness (it might even be a little confusing).

In the docstring you should explain the valid formats (e.g. ecr url, dockerhub name + tag) - include examples.

Also state that this is used for both training and deployment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with image_name because that is what the Estimator class uses, so I wanted to at least be consistent with that. I think just image is better but I dont know if we should have a different parameter name across different estimators.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh okay, let's keep image_name then.

for the framework. This is useful to run one of the Sagemaker supported frameworks
with an image containing custom dependencies.
"""
super(Framework, self).__init__(**kwargs)
self.source_dir = source_dir
Expand All @@ -521,6 +526,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
self.container_log_level = container_log_level
self._hyperparameters = hyperparameters or {}
self.code_location = code_location
self.image_name = image_name
print(self.image_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove debugging prints?

print(kwargs)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

d'oh


def _prepare_for_training(self, job_name=None):
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
Expand Down Expand Up @@ -624,6 +632,21 @@ def _prepare_init_params_from_job_description(cls, job_details):

return init_params

def train_image(self):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training,
calls this method to find the image to use for model training.

Returns:
str: The URI of the Docker image.
"""
if self.image_name:
return self.image_name
else:
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__,
self.train_instance_type, self.framework_version, py_version=self.py_version)

@classmethod
def attach(cls, training_job_name, sagemaker_session=None):
"""Attach to an existing training job.
Expand Down
38 changes: 20 additions & 18 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from __future__ import absolute_import

from sagemaker.estimator import Framework
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
from sagemaker.mxnet.defaults import MXNET_VERSION
from sagemaker.mxnet.model import MXNetModel

Expand All @@ -24,7 +24,7 @@ class MXNet(Framework):
__framework_name__ = "mxnet"

def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2',
framework_version=MXNET_VERSION, **kwargs):
framework_version=MXNET_VERSION, image_name=None, **kwargs):
"""
This ``Estimator`` executes an MXNet script in a managed MXNet execution environment, within a SageMaker
Training Job. The managed MXNet environment is an Amazon-built Docker container that executes functions
Expand Down Expand Up @@ -52,25 +52,15 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
One of 'py2' or 'py3'.
framework_version (str): MXNet version you want to use for executing your model training code.
List of supported versions https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators
image_name (str): The container image to use for training. This will override py_version and
framework_version. The image is expected to be a modification of the SageMaker MXNet image.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
"""
super(MXNet, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
super(MXNet, self).__init__(entry_point, source_dir, hyperparameters,
image_name=image_name, **kwargs)
self.py_version = py_version
self.framework_version = framework_version

def train_image(self):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
find the image to use for model training.

Returns:
str: The URI of the Docker image.
"""
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__,
self.train_instance_type, framework_version=self.framework_version,
py_version=self.py_version)

def create_model(self, model_server_workers=None):
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``.

Expand All @@ -82,11 +72,16 @@ def create_model(self, model_server_workers=None):
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
"""
kwargs = {}
# pass our custom image if there is one.
if self.image_name:
kwargs['image'] = self.image_name

return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
container_log_level=self.container_log_level, code_location=self.code_location,
py_version=self.py_version, framework_version=self.framework_version,
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session, **kwargs)

@classmethod
def _prepare_init_params_from_job_description(cls, job_details):
Expand All @@ -100,7 +95,14 @@ def _prepare_init_params_from_job_description(cls, job_details):

"""
init_params = super(MXNet, cls)._prepare_init_params_from_job_description(job_details)
framework, py_version, tag = framework_name_from_image(init_params.pop('image'))
image_name = init_params.pop('image')
framework, py_version, tag = framework_name_from_image(image_name)

if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params['image_name'] = image_name
return init_params

init_params['py_version'] = py_version

Expand Down
35 changes: 16 additions & 19 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import threading

from sagemaker.estimator import Framework
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
from sagemaker.utils import get_config_value

from sagemaker.tensorflow.defaults import TF_VERSION
Expand Down Expand Up @@ -157,7 +157,7 @@ class TensorFlow(Framework):
__framework_name__ = 'tensorflow'

def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
framework_version=TF_VERSION, requirements_file='', **kwargs):
framework_version=TF_VERSION, requirements_file='', image_name=None, **kwargs):
"""Initialize an ``TensorFlow`` estimator.
Args:
training_steps (int): Perform this many steps of training. `None`, the default means train forever.
Expand All @@ -171,9 +171,11 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
relative to ``source_dir``. Details on the format can be found in the
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
image_name (str): The container image to use for training. This will override py_version and
framework_version. The image is expected to be a modification of the SageMaker TensorFlow image.
**kwargs: Additional kwargs passed to the Framework constructor.
"""
super(TensorFlow, self).__init__(**kwargs)
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
self.checkpoint_path = checkpoint_path
self.py_version = py_version
self.framework_version = framework_version
Expand Down Expand Up @@ -257,7 +259,14 @@ def _prepare_init_params_from_job_description(cls, job_details):
if value is not None:
init_params[argument] = value

framework, py_version, tag = framework_name_from_image(init_params.pop('image'))
image_name = init_params.pop('image')
framework, py_version, tag = framework_name_from_image(image_name)
if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params['image_name'] = image_name
return init_params

init_params['py_version'] = py_version

# We switched image tagging scheme from regular image version (e.g. '1.0') to more expressive
Expand All @@ -272,18 +281,6 @@ def _prepare_init_params_from_job_description(cls, job_details):

return init_params

def train_image(self):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
find the image to use for model training.

Returns:
str: The URI of the Docker image.
"""
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__,
self.train_instance_type, self.framework_version, py_version=self.py_version)

def create_model(self, model_server_workers=None):
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.

Expand All @@ -296,9 +293,9 @@ def create_model(self, model_server_workers=None):
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
"""
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env,
name=self._current_job_name, container_log_level=self.container_log_level,
return TensorFlowModel(self.model_data, self.role, self.entry_point, image=self.image_name,
source_dir=self.source_dir, enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
env=env, name=self._current_job_name, container_log_level=self.container_log_level,
code_location=self.code_location, py_version=self.py_version,
framework_version=self.framework_version, model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session)
Expand Down
Loading