Skip to content

simplify create_image_uri function #462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 7, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
'Please add framework_version={} to your constructor to avoid ' \
'an error in the future.'

VALID_PY_VERSIONS = ['py2', 'py3']

def create_image_uri(region, framework, instance_type, framework_version, py_version, account='520713654638',

def create_image_uri(region, framework, instance_type, framework_version, py_version=None, account='520713654638',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for doing this.

optimized_families=[]):
"""Return the ECR URI of an image.

Expand All @@ -43,14 +45,18 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
framework (str): framework used by the image.
instance_type (str): SageMaker instance type. Used to determine device type (cpu/gpu/family-specific optimized).
framework_version (str): The version of the framework.
py_version (str): Python version. One of 'py2' or 'py3'.
py_version (str): Optional. Python version. If specified, should be one of 'py2' or 'py3'.
If not specified, image uri will not include a python component.
account (str): AWS account that contains the image. (default: '520713654638')
optimized_families (str): Instance families for which there exist specific optimized images.

Returns:
str: The appropriate image URI based on the given parameters.
"""

if py_version and py_version not in VALID_PY_VERSIONS:
raise ValueError('invalid py_version argument: {}'.format(py_version))

# Handle Account Number for Gov Cloud
if region == 'us-gov-west-1':
account = '246785580436'
Expand All @@ -73,7 +79,10 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
else:
device_type = 'cpu'

tag = "{}-{}-{}".format(framework_version, device_type, py_version)
if py_version:
tag = "{}-{}-{}".format(framework_version, device_type, py_version)
else:
tag = "{}-{}".format(framework_version, device_type)
return "{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}:{}" \
.format(account, region, framework, tag)

Expand Down
8 changes: 3 additions & 5 deletions src/sagemaker/tensorflow/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _get_container_env(self):
return self.env

env = dict(self.env)
env['SAGEMAKER_TFS_NGINX_LOGLEVEL'] = Model.LOG_LEVEL_MAP[self._container_log_level]
env[Model.LOG_LEVEL_PARAM_NAME] = Model.LOG_LEVEL_MAP[self._container_log_level]
return env

def _get_image_uri(self, instance_type):
Expand All @@ -143,7 +143,5 @@ def _get_image_uri(self, instance_type):

# reuse standard image uri function, then strip unwanted python component
region_name = self.sagemaker_session.boto_region_name
image = create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type,
self._framework_version, 'py3')
image = image.replace('-py3', '')
return image
return create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type,
self._framework_version)
12 changes: 10 additions & 2 deletions tests/integ/test_tfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import botocore.exceptions
import pytest

import sagemaker
Expand All @@ -29,8 +30,9 @@ def instance_type(request):
@pytest.fixture(scope='module')
def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
endpoint_name = sagemaker.utils.name_from_base('sagemaker-tensorflow-serving')
model_data = sagemaker_session.upload_data(path='tests/data/tensorflow-serving-test-model.tar.gz',
key_prefix='tensorflow-serving/models')
model_data = sagemaker_session.upload_data(
path='tests/data/tensorflow-serving-test-model.tar.gz',
key_prefix='tensorflow-serving/models')
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
model = Model(model_data=model_data, role='SageMakerRole',
framework_version=tf_full_version,
Expand Down Expand Up @@ -80,3 +82,9 @@ def test_predict_csv(tfs_predictor):

result = predictor.predict(input_data)
assert expected_result == result


def test_predict_bad_input(tfs_predictor):
input_data = {'junk': 'data'}
with pytest.raises(botocore.exceptions.ClientError):
tfs_predictor.predict(input_data)
44 changes: 30 additions & 14 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
from __future__ import absolute_import

import inspect
from mock import Mock, patch
import os
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag, \
model_code_key_prefix
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir

import pytest
from mock import Mock, patch

from sagemaker.fw_utils import create_image_uri, framework_name_from_image, \
framework_version_from_tag, \
model_code_key_prefix
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
from sagemaker.utils import name_from_image

DATA_DIR = 'data_dir'
Expand All @@ -33,12 +35,12 @@
@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name='boto_session', region_name=REGION)
ims = Mock(name='sagemaker_session', boto_session=boto_mock)
ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
ims.expand_role = Mock(name="expand_role", return_value=ROLE)
ims.sagemaker_client.describe_training_job = Mock(return_value={'ModelArtifacts':
{'S3ModelArtifacts': 's3://m/m.tar.gz'}})
return ims
session_mock = Mock(name='sagemaker_session', boto_session=boto_mock)
session_mock.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
session_mock.expand_role = Mock(name="expand_role", return_value=ROLE)
session_mock.sagemaker_client.describe_training_job = \
Mock(return_value={'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}})
return session_mock


def test_create_image_uri_cpu():
Expand All @@ -49,6 +51,16 @@ def test_create_image_uri_cpu():
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'


def test_create_image_uri_no_python():
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.c4.large', '1.0rc', account='23')
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu'


def test_create_image_uri_bad_python():
with pytest.raises(ValueError):
create_image_uri('mars-south-3', 'mlfw', 'ml.c4.large', '1.0rc', 'py0')


def test_create_image_uri_gpu():
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.p3.2xlarge', '1.0rc', 'py3', '23')
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3'
Expand Down Expand Up @@ -127,7 +139,8 @@ def test_tar_and_upload_dir_not_s3(sagemaker_session):
script = os.path.basename(__file__)
directory = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
result = tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory)
assert result == UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix), script)
assert result == UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix),
script)


def test_framework_name_from_image_mxnet():
Expand All @@ -149,21 +162,24 @@ def test_legacy_name_from_framework_image():


def test_legacy_name_from_wrong_framework():
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1')
framework, py_ver, tag = framework_name_from_image(
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1')
assert framework is None
assert py_ver is None
assert tag is None


def test_legacy_name_from_wrong_python():
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
framework, py_ver, tag = framework_name_from_image(
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
assert framework is None
assert py_ver is None
assert tag is None


def test_legacy_name_from_wrong_device():
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
framework, py_ver, tag = framework_name_from_image(
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
assert framework is None
assert py_ver is None
assert tag is None
Expand Down
14 changes: 10 additions & 4 deletions tests/unit/test_tfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from mock import Mock

from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.serving import Model, Predictor
from sagemaker.tensorflow.predictor import csv_serializer
from sagemaker.tensorflow.serving import Model, Predictor

JSON_CONTENT_TYPE = 'application/json'
CSV_CONTENT_TYPE = 'text/csv'
Expand Down Expand Up @@ -94,7 +94,8 @@ def test_estimator_deploy(sagemaker_session):

job_name = 'doing something'
tf.fit(inputs='s3://mybucket/train', job_name=job_name)
predictor = tf.deploy(INSTANCE_COUNT, INSTANCE_TYPE, 'endpoint', endpoint_type='tensorflow-serving')
predictor = tf.deploy(INSTANCE_COUNT, INSTANCE_TYPE, 'endpoint',
endpoint_type='tensorflow-serving')
assert isinstance(predictor, Predictor)


Expand Down Expand Up @@ -190,8 +191,13 @@ def test_predictor_classify_bad_content_type():
predictor.classify(CLASSIFY_INPUT)


def assert_invoked(sagemaker_session, **args):
sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(**args)
def assert_invoked(sagemaker_session, **kwargs):
call = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
cargs, ckwargs = call
assert not cargs
assert len(kwargs) == len(ckwargs)
for k in ckwargs:
assert kwargs[k] == ckwargs[k]


def mock_response(expected_response, sagemaker_session, content_type=JSON_CONTENT_TYPE):
Expand Down