Skip to content

Commit a37a54a

Browse files
authored
Merge branch 'master' into model-code-location
2 parents 85a70f1 + 3feb2eb commit a37a54a

File tree

10 files changed

+19
-8
lines changed

10 files changed

+19
-8
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ CHANGELOG
55
+1.5.1dev
66
+========
77

8+
* enhancement: Let Framework models reuse code uploaded by Framework estimators
89
* enhancement: Unify generation of model uploaded code location
910

1011
1.5.0
1112
=====
13+
1214
* feature: Add Support for PyTorch Framework
1315
* feature: Estimators: add support for TensorFlow 1.7.0
1416
* feature: Estimators: add support for TensorFlow 1.8.0

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def create_model(self, model_server_workers=None):
116116
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel`` object.
117117
See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
118118
"""
119-
return ChainerModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
119+
return ChainerModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
120120
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
121121
container_log_level=self.container_log_level, code_location=self.code_location,
122122
py_version=self.py_version, framework_version=self.framework_version,

src/sagemaker/estimator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def _prepare_for_training(self, job_name=None):
561561
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
562562

563563
def _stage_user_code_in_s3(self):
564-
""" Upload the user training script to s3 and return the location.
564+
"""Upload the user training script to s3 and return the location.
565565
566566
Returns: s3 uri
567567
@@ -579,6 +579,14 @@ def _stage_user_code_in_s3(self):
579579
script=self.entry_point,
580580
directory=self.source_dir)
581581

582+
def _model_source_dir(self):
583+
"""Get the appropriate value to pass as source_dir to model constructor on deploying
584+
585+
Returns:
586+
str: Either a local or an S3 path pointing to the source_dir to be used for code by the model to be deployed
587+
"""
588+
return self.source_dir if self.sagemaker_session.local_mode else self.uploaded_code.s3_prefix
589+
582590
def hyperparameters(self):
583591
"""Return the hyperparameters as a dictionary to use for training.
584592

src/sagemaker/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic
132132
source_dir (str): Path (absolute or relative) to a directory with any other training
133133
source code dependencies aside from tne entry point file (default: None). Structure within this
134134
directory will be preserved when training on SageMaker.
135+
If the directory points to S3, no code will be uploaded and the S3 location will be used instead.
135136
predictor_cls (callable[string, sagemaker.session.Session]): A function to call to create
136137
a predictor (default: None). If not None, ``deploy`` will return the result of invoking
137138
this function on the created endpoint name.

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def create_model(self, model_server_workers=None):
8282
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
8383
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
8484
"""
85-
return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
85+
return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
8686
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
8787
container_log_level=self.container_log_level, code_location=self.code_location,
8888
py_version=self.py_version, framework_version=self.framework_version,

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def create_model(self, model_server_workers=None):
8181
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel`` object.
8282
See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
8383
"""
84-
return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
84+
return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
8585
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
8686
container_log_level=self.container_log_level, code_location=self.code_location,
8787
py_version=self.py_version, framework_version=self.framework_version,

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def create_model(self, model_server_workers=None):
300300
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
301301
"""
302302
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}
303-
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
303+
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self._model_source_dir(),
304304
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env,
305305
name=self._current_job_name, container_log_level=self.container_log_level,
306306
code_location=self.code_location, py_version=self.py_version,

tests/unit/test_chainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
274274
expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-gpu-{}'
275275
assert {'Environment':
276276
{'SAGEMAKER_SUBMIT_DIRECTORY':
277-
's3://mybucket/sagemaker-chainer-{}/sourcedir.tar.gz'.format(TIMESTAMP),
277+
's3://mybucket/sagemaker-chainer-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
278278
'SAGEMAKER_PROGRAM': 'dummy_script.py',
279279
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
280280
'SAGEMAKER_REGION': 'us-west-2',

tests/unit/test_mxnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def test_mxnet(strftime, sagemaker_session, mxnet_version):
149149
expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py2'
150150
environment = {
151151
'Environment': {
152-
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/sagemaker-mxnet-{}/sourcedir.tar.gz'.format(TIMESTAMP),
152+
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/sagemaker-mxnet-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
153153
'SAGEMAKER_PROGRAM': 'dummy_script.py', 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
154154
'SAGEMAKER_REGION': 'us-west-2', 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'
155155
},

tests/unit/test_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def test_pytorch(strftime, sagemaker_session, pytorch_version):
167167
expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:{}-gpu-{}'
168168
assert {'Environment':
169169
{'SAGEMAKER_SUBMIT_DIRECTORY':
170-
's3://mybucket/sagemaker-pytorch-{}/sourcedir.tar.gz'.format(TIMESTAMP),
170+
's3://mybucket/sagemaker-pytorch-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
171171
'SAGEMAKER_PROGRAM': 'dummy_script.py',
172172
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
173173
'SAGEMAKER_REGION': 'us-west-2',

0 commit comments

Comments
 (0)