Skip to content

Commit 85a70f1

Browse files
committed
Unify generation of model uploaded code location
- Extract repeated key_prefix calculation in different models (base, mxnet, tensorflow, chainer, pytorch) into a common utility method in fw_utils.py - The generated location now always respects the code_location prefix if present but it also concatenates either the model name or image-derived name to it - This is related to 2.b in #226
1 parent 92eb47d commit 85a70f1

File tree

9 files changed

+78
-18
lines changed

9 files changed

+78
-18
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
+1.5.1dev
6+
+========
7+
8+
* enhancement: Unify generation of model uploaded code location
9+
510
1.5.0
611
=====
712
* feature: Add Support for PyTorch Framework

src/sagemaker/chainer/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@
1313
from __future__ import absolute_import
1414

1515
import sagemaker
16-
from sagemaker.fw_utils import create_image_uri
16+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
1717
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
1818
from sagemaker.chainer.defaults import CHAINER_VERSION
1919
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
20-
from sagemaker.utils import name_from_image
2120

2221

2322
class ChainerPredictor(RealTimePredictor):
@@ -85,7 +84,8 @@ def prepare_container_def(self, instance_type):
8584
region_name = self.sagemaker_session.boto_session.region_name
8685
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
8786
self.framework_version, self.py_version)
88-
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
87+
88+
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
8989
self._upload_code(deploy_key_prefix)
9090
deploy_env = dict(self.env)
9191
deploy_env.update(self._framework_env_vars())

src/sagemaker/fw_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from collections import namedtuple
2020
from six.moves.urllib.parse import urlparse
2121

22+
from sagemaker.utils import name_from_image
23+
2224
"""This module contains utility functions shared across ``Framework`` components."""
2325

2426

@@ -199,3 +201,21 @@ def parse_s3_url(url):
199201
if parsed_url.scheme != "s3":
200202
raise ValueError("Expecting 's3' scheme, got: {} in {}".format(parsed_url.scheme, url))
201203
return parsed_url.netloc, parsed_url.path.lstrip('/')
204+
205+
206+
def model_code_key_prefix(code_location_key_prefix, model_name, image):
207+
"""Returns the s3 key prefix for uploading code during model deployment
208+
209+
The location returned is a potential concatenation of 2 parts
210+
1. code_location_key_prefix if it exists
211+
2. model_name or a name derived from the image
212+
213+
Args:
214+
code_location_key_prefix (str): the s3 key prefix from code_location
215+
model_name (str): the name of the model
216+
image (str): the image from which a default name can be extracted
217+
218+
Returns:
219+
str: the key prefix to be used in uploading code
220+
"""
221+
return '/'.join(filter(None, [code_location_key_prefix, model_name or name_from_image(image)]))

src/sagemaker/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import sagemaker
1818

1919
from sagemaker.local import LocalSession
20-
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url
20+
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, model_code_key_prefix
2121
from sagemaker.session import Session
2222
from sagemaker.utils import name_from_image, get_config_value
2323

@@ -169,7 +169,8 @@ def prepare_container_def(self, instance_type):
169169
Returns:
170170
dict[str, str]: A container definition object usable with the CreateModel API.
171171
"""
172-
self._upload_code(self.key_prefix or self.name or name_from_image(self.image))
172+
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, self.image)
173+
self._upload_code(deploy_key_prefix)
173174
deploy_env = dict(self.env)
174175
deploy_env.update(self._framework_env_vars())
175176
return sagemaker.container_def(self.image, self.model_data, deploy_env)

src/sagemaker/mxnet/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@
1313
from __future__ import absolute_import
1414

1515
import sagemaker
16-
from sagemaker.fw_utils import create_image_uri
16+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
1717
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
1818
from sagemaker.mxnet.defaults import MXNET_VERSION
1919
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
20-
from sagemaker.utils import name_from_image
2120

2221

2322
class MXNetPredictor(RealTimePredictor):
@@ -85,7 +84,8 @@ def prepare_container_def(self, instance_type):
8584
region_name = self.sagemaker_session.boto_session.region_name
8685
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
8786
self.framework_version, self.py_version)
88-
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
87+
88+
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
8989
self._upload_code(deploy_key_prefix)
9090
deploy_env = dict(self.env)
9191
deploy_env.update(self._framework_env_vars())

src/sagemaker/pytorch/model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414
import sagemaker
15-
from sagemaker.fw_utils import create_image_uri
15+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
1616
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
1717
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
1818
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
19-
from sagemaker.utils import name_from_image
2019

2120

2221
class PyTorchPredictor(RealTimePredictor):
@@ -84,7 +83,7 @@ def prepare_container_def(self, instance_type):
8483
region_name = self.sagemaker_session.boto_session.region_name
8584
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
8685
self.framework_version, self.py_version)
87-
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
86+
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
8887
self._upload_code(deploy_key_prefix)
8988
deploy_env = dict(self.env)
9089
deploy_env.update(self._framework_env_vars())

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
from __future__ import absolute_import
1414

1515
import sagemaker
16-
from sagemaker.fw_utils import create_image_uri
16+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
1717
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
1818
from sagemaker.predictor import RealTimePredictor
1919
from sagemaker.tensorflow.defaults import TF_VERSION
2020
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer
21-
from sagemaker.utils import name_from_image
2221

2322

2423
class TensorFlowPredictor(RealTimePredictor):
@@ -89,8 +88,7 @@ def prepare_container_def(self, instance_type):
8988
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
9089
self.framework_version, self.py_version)
9190

92-
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
93-
91+
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
9492
self._upload_code(deploy_key_prefix)
9593
deploy_env = dict(self.env)
9694
deploy_env.update(self._framework_env_vars())

tests/unit/test_fw_utils.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,21 @@
1313
from __future__ import absolute_import
1414

1515
import inspect
16-
from mock import Mock
16+
from mock import Mock, patch
1717
import os
18-
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
18+
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag, \
19+
model_code_key_prefix
1920
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
2021
import pytest
2122

23+
from sagemaker.utils import name_from_image
2224

2325
DATA_DIR = 'data_dir'
2426
BUCKET_NAME = 'mybucket'
2527
ROLE = 'Sagemaker'
2628
REGION = 'us-west-2'
2729
SCRIPT_PATH = 'script.py'
30+
TIMESTAMP = '2017-10-10-14-14-15'
2831

2932

3033
@pytest.fixture()
@@ -189,3 +192,37 @@ def test_parse_s3_url_fail():
189192
with pytest.raises(ValueError) as error:
190193
parse_s3_url('t3://code_location')
191194
assert 'Expecting \'s3\' scheme' in str(error)
195+
196+
197+
def test_model_code_key_prefix_with_all_values_present():
198+
key_prefix = model_code_key_prefix('prefix', 'model_name', 'image_name')
199+
assert key_prefix == 'prefix/model_name'
200+
201+
202+
def test_model_code_key_prefix_with_no_prefix_and_all_other_values_present():
203+
key_prefix = model_code_key_prefix(None, 'model_name', 'image_name')
204+
assert key_prefix == 'model_name'
205+
206+
207+
@patch('time.strftime', return_value=TIMESTAMP)
208+
def test_model_code_key_prefix_with_only_image_present(time):
209+
key_prefix = model_code_key_prefix(None, None, 'image_name')
210+
assert key_prefix == name_from_image('image_name')
211+
212+
213+
@patch('time.strftime', return_value=TIMESTAMP)
214+
def test_model_code_key_prefix_and_image_present(time):
215+
key_prefix = model_code_key_prefix('prefix', None, 'image_name')
216+
assert key_prefix == 'prefix/' + name_from_image('image_name')
217+
218+
219+
def test_model_code_key_prefix_with_prefix_present_and_others_none_fail():
220+
with pytest.raises(TypeError) as error:
221+
model_code_key_prefix('prefix', None, None)
222+
assert 'expected string' in str(error)
223+
224+
225+
def test_model_code_key_prefix_with_all_none_fail():
226+
with pytest.raises(TypeError) as error:
227+
model_code_key_prefix(None, None, None)
228+
assert 'expected string' in str(error)

tests/unit/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_create_no_defaults(tfopen, exists, isdir, listdir, time, sagemaker_sess
8080

8181
assert model.prepare_container_def(INSTANCE_TYPE) == {
8282
'Environment': {'SAGEMAKER_PROGRAM': 'blah.py',
83-
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://cb/cp/sourcedir.tar.gz',
83+
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://cb/cp/name/sourcedir.tar.gz',
8484
'SAGEMAKER_CONTAINER_LOG_LEVEL': '55',
8585
'SAGEMAKER_REGION': 'us-west-2',
8686
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'true',

0 commit comments

Comments
 (0)