-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Allow Framework Estimators to use custom image #223
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## master #223 +/- ##
==========================================
+ Coverage 92.27% 92.33% +0.06%
==========================================
Files 49 49
Lines 3261 3274 +13
==========================================
+ Hits 3009 3023 +14
+ Misses 252 251 -1
Continue to review full report at Codecov.
|
Chainer, Tensorflow and MXNet estimators can now pass an image_name argument to the constructor to use that image instead of the default sagemaker ones.
6965ca1
to
2fef9cd
Compare
src/sagemaker/estimator.py
Outdated
@@ -521,6 +526,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl | |||
self.container_log_level = container_log_level | |||
self._hyperparameters = hyperparameters or {} | |||
self.code_location = code_location | |||
self.image_name = image_name | |||
print(self.image_name) | |||
print(kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
d'oh
src/sagemaker/estimator.py
Outdated
@@ -226,6 +227,7 @@ def attach(cls, training_job_name, sagemaker_session=None): | |||
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name) | |||
init_params = cls._prepare_init_params_from_job_description(job_details) | |||
|
|||
print(init_params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
- Please add documentation to the README about this.
- Add a test which verifies that the custom image is used for the training job when the estimator is fit.
src/sagemaker/estimator.py
Outdated
@@ -226,6 +227,7 @@ def attach(cls, training_job_name, sagemaker_session=None): | |||
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name) | |||
init_params = cls._prepare_init_params_from_job_description(job_details) | |||
|
|||
print(init_params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove debugging print?
src/sagemaker/estimator.py
Outdated
@@ -521,6 +526,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl | |||
self.container_log_level = container_log_level | |||
self._hyperparameters = hyperparameters or {} | |||
self.code_location = code_location | |||
self.image_name = image_name | |||
print(self.image_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove debugging prints?
@@ -513,6 +515,9 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl | |||
code_location (str): Name of the S3 bucket where custom code is uploaded (default: None). | |||
If not specified, default bucket created by ``sagemaker.session.Session`` is used. | |||
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor. | |||
image_name (str): An alternate image name to use instead of the official Sagemaker image |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd go with "image" instead of "image_name" here since I think the "_name" part of it doesn't add any descriptiveness (it might even be a little confusing).
In the docstring you should explain the valid formats (e.g. ecr url, dockerhub name + tag) - include examples.
Also state that this is used for both training and deployment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I went with image_name because that is what the Estimator class uses, so I wanted to at least be consistent with that. I think just image is better but I dont know if we should have a different parameter name across different estimators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh okay, let's keep image_name then.
src/sagemaker/chainer/estimator.py
Outdated
@@ -67,9 +67,12 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_ | |||
One of 'py2' or 'py3'. | |||
framework_version (str): Chainer version you want to use for executing your model training code. | |||
List of supported versions https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators | |||
image_name (str): The container image to use for training. This will override py_version and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of "this will override", how about something like - "if specified, the estimator will use this image for training and hosting, instead of selecting the appropriate SageMaker official image based on framework_version and py_version.
Applies to everywhere this wording is used.
tests/unit/test_chainer.py
Outdated
job_name = 'new_name' | ||
chainer.fit(inputs='s3://mybucket/train', job_name='new_name') | ||
model = chainer.create_model() | ||
chainer.container_log_level |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's this line for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this line was carried over from another test that I used as a template, I just realized this was present in other tests as well but its basically useless. I will get rid of it.
tests/unit/test_chainer.py
Outdated
chainer.container_log_level | ||
|
||
assert model.sagemaker_session == sagemaker_session | ||
assert model.image == custom_image |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I would keep the asserts minimal for this test, and just assert the image, assuming that other unit tests cover the other parameters already.
tests/unit/test_chainer.py
Outdated
assert estimator.hyperparameters()['training_steps'] == '100' | ||
assert estimator.source_dir == 's3://some/sourcedir.tar.gz' | ||
assert estimator.entry_point == 'iris-dnn-classifier.py' | ||
assert estimator.train_image() == training_image |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment here about keeping minimal asserts (unless there's a specific reason that these interact with overriding the image.)
@@ -175,6 +175,12 @@ The following are optional arguments. When you create a ``Chainer`` object, you | |||
- ``job_name`` Name to assign for the training job that the fit() | |||
method launches. If not specified, the estimator generates a default | |||
job name, based on the training image name and current timestamp | |||
- ``image_name`` An alternative docker image to use for training and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Formatting looks broken here - take a look in the rich view?
Applies to all the readmes changed.
src/sagemaker/chainer/estimator.py
Outdated
framework_version. The image is expected to be a modification of the SageMaker Chainer image. | ||
image_name (str): If specified, the estimator will use this image for training and hosting, instead of | ||
selecting the appropriate SageMaker official image based on framework_version and py_version. It can | ||
be an ECR url or dockerhub image and tag: 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets make it more explicit that those are two separate examples. You could do something like:
"It can be an ECR url or dockerhub image and tag. Examples: 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
custom-image:latest"
Fixed: Linear Learner's kernel to align with SageMaker Notebooks
Chainer, Tensorflow and MXNet estimators can now pass
an image_name argument to the constructor to use that image
instead of the default sagemaker ones.
Issue #, if available:
Description of changes:
Merge Checklist
Put an
x
in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your pull request.By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.