Skip to content

Unify generation of model uploaded code location #228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
CHANGELOG
=========

1.5.1dev
========
+1.5.1dev
+========

* enhancement: Let Framework models reuse code uploaded by Framework estimators
* enhancement: Unify generation of model uploaded code location

1.5.0
=====
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
from __future__ import absolute_import

import sagemaker
from sagemaker.fw_utils import create_image_uri
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer.defaults import CHAINER_VERSION
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
from sagemaker.utils import name_from_image


class ChainerPredictor(RealTimePredictor):
Expand Down Expand Up @@ -85,7 +84,8 @@ def prepare_container_def(self, instance_type):
region_name = self.sagemaker_session.boto_session.region_name
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
self.framework_version, self.py_version)
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)

deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
Expand Down
20 changes: 20 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from collections import namedtuple
from six.moves.urllib.parse import urlparse

from sagemaker.utils import name_from_image

"""This module contains utility functions shared across ``Framework`` components."""


Expand Down Expand Up @@ -199,3 +201,21 @@ def parse_s3_url(url):
if parsed_url.scheme != "s3":
raise ValueError("Expecting 's3' scheme, got: {} in {}".format(parsed_url.scheme, url))
return parsed_url.netloc, parsed_url.path.lstrip('/')


def model_code_key_prefix(code_location_key_prefix, model_name, image):
"""Returns the s3 key prefix for uploading code during model deployment

The location returned is a potential concatenation of 2 parts
1. code_location_key_prefix if it exists
2. model_name or a name derived from the image

Args:
code_location_key_prefix (str): the s3 key prefix from code_location
model_name (str): the name of the model
image (str): the image from which a default name can be extracted

Returns:
str: the key prefix to be used in uploading code
"""
return '/'.join(filter(None, [code_location_key_prefix, model_name or name_from_image(image)]))
5 changes: 3 additions & 2 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import sagemaker

from sagemaker.local import LocalSession
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, model_code_key_prefix
from sagemaker.session import Session
from sagemaker.utils import name_from_image, get_config_value

Expand Down Expand Up @@ -170,7 +170,8 @@ def prepare_container_def(self, instance_type):
Returns:
dict[str, str]: A container definition object usable with the CreateModel API.
"""
self._upload_code(self.key_prefix or self.name or name_from_image(self.image))
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, self.image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
return sagemaker.container_def(self.image, self.model_data, deploy_env)
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
from __future__ import absolute_import

import sagemaker
from sagemaker.fw_utils import create_image_uri
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet.defaults import MXNET_VERSION
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
from sagemaker.utils import name_from_image


class MXNetPredictor(RealTimePredictor):
Expand Down Expand Up @@ -85,7 +84,8 @@ def prepare_container_def(self, instance_type):
region_name = self.sagemaker_session.boto_session.region_name
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
self.framework_version, self.py_version)
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)

deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
Expand Down
5 changes: 2 additions & 3 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import sagemaker
from sagemaker.fw_utils import create_image_uri
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
from sagemaker.utils import name_from_image


class PyTorchPredictor(RealTimePredictor):
Expand Down Expand Up @@ -84,7 +83,7 @@ def prepare_container_def(self, instance_type):
region_name = self.sagemaker_session.boto_session.region_name
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
self.framework_version, self.py_version)
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
Expand Down
6 changes: 2 additions & 4 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
from __future__ import absolute_import

import sagemaker
from sagemaker.fw_utils import create_image_uri
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import RealTimePredictor
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer
from sagemaker.utils import name_from_image


class TensorFlowPredictor(RealTimePredictor):
Expand Down Expand Up @@ -89,8 +88,7 @@ def prepare_container_def(self, instance_type):
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
self.framework_version, self.py_version)

deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)

deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
Expand Down
41 changes: 39 additions & 2 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,21 @@
from __future__ import absolute_import

import inspect
from mock import Mock
from mock import Mock, patch
import os
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag, \
model_code_key_prefix
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
import pytest

from sagemaker.utils import name_from_image

DATA_DIR = 'data_dir'
BUCKET_NAME = 'mybucket'
ROLE = 'Sagemaker'
REGION = 'us-west-2'
SCRIPT_PATH = 'script.py'
TIMESTAMP = '2017-10-10-14-14-15'


@pytest.fixture()
Expand Down Expand Up @@ -189,3 +192,37 @@ def test_parse_s3_url_fail():
with pytest.raises(ValueError) as error:
parse_s3_url('t3://code_location')
assert 'Expecting \'s3\' scheme' in str(error)


def test_model_code_key_prefix_with_all_values_present():
key_prefix = model_code_key_prefix('prefix', 'model_name', 'image_name')
assert key_prefix == 'prefix/model_name'


def test_model_code_key_prefix_with_no_prefix_and_all_other_values_present():
key_prefix = model_code_key_prefix(None, 'model_name', 'image_name')
assert key_prefix == 'model_name'


@patch('time.strftime', return_value=TIMESTAMP)
def test_model_code_key_prefix_with_only_image_present(time):
key_prefix = model_code_key_prefix(None, None, 'image_name')
assert key_prefix == name_from_image('image_name')


@patch('time.strftime', return_value=TIMESTAMP)
def test_model_code_key_prefix_and_image_present(time):
key_prefix = model_code_key_prefix('prefix', None, 'image_name')
assert key_prefix == 'prefix/' + name_from_image('image_name')


def test_model_code_key_prefix_with_prefix_present_and_others_none_fail():
with pytest.raises(TypeError) as error:
model_code_key_prefix('prefix', None, None)
assert 'expected string' in str(error)


def test_model_code_key_prefix_with_all_none_fail():
with pytest.raises(TypeError) as error:
model_code_key_prefix(None, None, None)
assert 'expected string' in str(error)
2 changes: 1 addition & 1 deletion tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_create_no_defaults(tfopen, exists, isdir, listdir, time, sagemaker_sess

assert model.prepare_container_def(INSTANCE_TYPE) == {
'Environment': {'SAGEMAKER_PROGRAM': 'blah.py',
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://cb/cp/sourcedir.tar.gz',
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://cb/cp/name/sourcedir.tar.gz',
'SAGEMAKER_CONTAINER_LOG_LEVEL': '55',
'SAGEMAKER_REGION': 'us-west-2',
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'true',
Expand Down