Skip to content

Commit d63a8f1

Browse files
committed
Let Framework models reuse code uploaded by Framework estimators
- This addresses problem 1 in #226 - This relies on the tar_and_upload_dir's behaviour when given S3 paths for directory - This updates MXNet, TensorFlow and Chainer frameworks
1 parent 42f9de8 commit d63a8f1

File tree

7 files changed

+7
-5
lines changed

7 files changed

+7
-5
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ CHANGELOG
66
========
77
* feature: Allow Local Serving of Models in S3
88
* enhancement: Allow option for ``HyperparameterTuner`` to not include estimator metadata in job
9+
* enhancement: Let Framework models reuse code uploaded by Framework estimators
910

1011

1112
1.4.2

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def create_model(self, model_server_workers=None):
116116
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel`` object.
117117
See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
118118
"""
119-
return ChainerModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
119+
return ChainerModel(self.model_data, self.role, self.entry_point, source_dir=self.uploaded_code.s3_prefix,
120120
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
121121
container_log_level=self.container_log_level, code_location=self.code_location,
122122
py_version=self.py_version, framework_version=self.framework_version,

src/sagemaker/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic
132132
source_dir (str): Path (absolute or relative) to a directory with any other training
133133
source code dependencies aside from tne entry point file (default: None). Structure within this
134134
directory will be preserved when training on SageMaker.
135+
If the directory points to S3, no code will be uploaded and the S3 location will be used instead.
135136
predictor_cls (callable[string, sagemaker.session.Session]): A function to call to create
136137
a predictor (default: None). If not None, ``deploy`` will return the result of invoking
137138
this function on the created endpoint name.

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def create_model(self, model_server_workers=None):
8282
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
8383
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
8484
"""
85-
return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
85+
return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self.uploaded_code.s3_prefix,
8686
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
8787
container_log_level=self.container_log_level, code_location=self.code_location,
8888
py_version=self.py_version, framework_version=self.framework_version,

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def create_model(self, model_server_workers=None):
296296
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
297297
"""
298298
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}
299-
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
299+
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self.uploaded_code.s3_prefix,
300300
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env,
301301
name=self._current_job_name, container_log_level=self.container_log_level,
302302
code_location=self.code_location, py_version=self.py_version,

tests/unit/test_chainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
274274
expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-gpu-{}'
275275
assert {'Environment':
276276
{'SAGEMAKER_SUBMIT_DIRECTORY':
277-
's3://mybucket/sagemaker-chainer-{}/sourcedir.tar.gz'.format(TIMESTAMP),
277+
's3://mybucket/sagemaker-chainer-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
278278
'SAGEMAKER_PROGRAM': 'dummy_script.py',
279279
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
280280
'SAGEMAKER_REGION': 'us-west-2',

tests/unit/test_mxnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def test_mxnet(strftime, sagemaker_session, mxnet_version):
149149
expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py2'
150150
environment = {
151151
'Environment': {
152-
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/sagemaker-mxnet-{}/sourcedir.tar.gz'.format(TIMESTAMP),
152+
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/sagemaker-mxnet-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
153153
'SAGEMAKER_PROGRAM': 'dummy_script.py', 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
154154
'SAGEMAKER_REGION': 'us-west-2', 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'
155155
},

0 commit comments

Comments
 (0)