-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Allow Framework Estimators to use custom image #223
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2fef9cd
4926e10
e3f6ab5
17d14d1
74775e4
1ecf93d
e6108b7
471908d
f2cab5f
31643bd
df57063
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,8 @@ | |
from six import with_metaclass | ||
|
||
from sagemaker.analytics import TrainingJobAnalytics | ||
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir | ||
from sagemaker.fw_utils import (create_image_uri, tar_and_upload_dir, parse_s3_url, UploadedCode, | ||
validate_source_dir) | ||
from sagemaker.job import _Job | ||
from sagemaker.local import LocalSession | ||
from sagemaker.model import Model | ||
|
@@ -493,7 +494,7 @@ class Framework(EstimatorBase): | |
""" | ||
|
||
def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False, | ||
container_log_level=logging.INFO, code_location=None, **kwargs): | ||
container_log_level=logging.INFO, code_location=None, image_name=None, **kwargs): | ||
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()`` | ||
|
||
Args: | ||
|
@@ -513,6 +514,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl | |
code_location (str): Name of the S3 bucket where custom code is uploaded (default: None). | ||
If not specified, default bucket created by ``sagemaker.session.Session`` is used. | ||
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor. | ||
image_name (str): An alternate image name to use instead of the official Sagemaker image | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd go with "image" instead of "image_name" here since I think the "_name" part of it doesn't add any descriptiveness (it might even be a little confusing). In the docstring you should explain the valid formats (e.g. ecr url, dockerhub name + tag) - include examples. Also state that this is used for both training and deployment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I went with image_name because that is what the Estimator class uses, so I wanted to at least be consistent with that. I think just image is better but I dont know if we should have a different parameter name across different estimators. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh okay, let's keep image_name then. |
||
for the framework. This is useful to run one of the Sagemaker supported frameworks | ||
with an image containing custom dependencies. | ||
""" | ||
super(Framework, self).__init__(**kwargs) | ||
self.source_dir = source_dir | ||
|
@@ -521,6 +525,7 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl | |
self.container_log_level = container_log_level | ||
self._hyperparameters = hyperparameters or {} | ||
self.code_location = code_location | ||
self.image_name = image_name | ||
|
||
def _prepare_for_training(self, job_name=None): | ||
"""Set hyperparameters needed for training. This method will also validate ``source_dir``. | ||
|
@@ -632,6 +637,21 @@ def _prepare_init_params_from_job_description(cls, job_details): | |
|
||
return init_params | ||
|
||
def train_image(self): | ||
"""Return the Docker image to use for training. | ||
|
||
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, | ||
calls this method to find the image to use for model training. | ||
|
||
Returns: | ||
str: The URI of the Docker image. | ||
""" | ||
if self.image_name: | ||
return self.image_name | ||
else: | ||
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__, | ||
self.train_instance_type, self.framework_version, py_version=self.py_version) | ||
|
||
@classmethod | ||
def attach(cls, training_job_name, sagemaker_session=None): | ||
"""Attach to an existing training job. | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Formatting looks broken here - take a look in the rich view?
Applies to all the readmes changed.