Skip to content

Commit 98d8e93

Browse files
author
Jonathan Esterhazy
committed
add failure test case
1 parent b6c9b0c commit 98d8e93

File tree

4 files changed

+57
-25
lines changed

4 files changed

+57
-25
lines changed

src/sagemaker/fw_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@
3333
'Please add framework_version={} to your constructor to avoid ' \
3434
'an error in the future.'
3535

36+
VALID_PY_VERSIONS = ['py2', 'py3']
3637

37-
def create_image_uri(region, framework, instance_type, framework_version, py_version, account='520713654638',
38+
39+
def create_image_uri(region, framework, instance_type, framework_version, py_version=None, account='520713654638',
3840
optimized_families=[]):
3941
"""Return the ECR URI of an image.
4042
@@ -43,14 +45,18 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
4345
framework (str): framework used by the image.
4446
instance_type (str): SageMaker instance type. Used to determine device type (cpu/gpu/family-specific optimized).
4547
framework_version (str): The version of the framework.
46-
py_version (str): Python version. One of 'py2' or 'py3'.
48+
py_version (str): Optional. Python version. If specified, should be one of 'py2' or 'py3'.
49+
If not specified, image uri will not include a python component.
4750
account (str): AWS account that contains the image. (default: '520713654638')
4851
optimized_families (str): Instance families for which there exist specific optimized images.
4952
5053
Returns:
5154
str: The appropriate image URI based on the given parameters.
5255
"""
5356

57+
if py_version and py_version not in VALID_PY_VERSIONS:
58+
raise ValueError('invalid py_version argument: {}'.format(py_version))
59+
5460
# Handle Account Number for Gov Cloud
5561
if region == 'us-gov-west-1':
5662
account = '246785580436'
@@ -73,7 +79,10 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
7379
else:
7480
device_type = 'cpu'
7581

76-
tag = "{}-{}-{}".format(framework_version, device_type, py_version)
82+
if py_version:
83+
tag = "{}-{}-{}".format(framework_version, device_type, py_version)
84+
else:
85+
tag = "{}-{}".format(framework_version, device_type)
7786
return "{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}:{}" \
7887
.format(account, region, framework, tag)
7988

src/sagemaker/tensorflow/serving.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _get_container_env(self):
134134
return self.env
135135

136136
env = dict(self.env)
137-
env['SAGEMAKER_TFS_NGINX_LOGLEVEL'] = Model.LOG_LEVEL_MAP[self._container_log_level]
137+
env[Model.LOG_LEVEL_PARAM_NAME] = Model.LOG_LEVEL_MAP[self._container_log_level]
138138
return env
139139

140140
def _get_image_uri(self, instance_type):
@@ -143,7 +143,5 @@ def _get_image_uri(self, instance_type):
143143

144144
# reuse standard image uri function, then strip unwanted python component
145145
region_name = self.sagemaker_session.boto_region_name
146-
image = create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type,
147-
self._framework_version, 'py3')
148-
image = image.replace('-py3', '')
149-
return image
146+
return create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type,
147+
self._framework_version)

tests/integ/test_tfs.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import botocore.exceptions
1516
import pytest
1617

1718
import sagemaker
@@ -28,11 +29,13 @@ def instance_type(request):
2829

2930
@pytest.fixture(scope='module')
3031
def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
32+
image = '237082650222.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-serving:1.11-cpu'
3133
endpoint_name = sagemaker.utils.name_from_base('sagemaker-tensorflow-serving')
32-
model_data = sagemaker_session.upload_data(path='tests/data/tensorflow-serving-test-model.tar.gz',
33-
key_prefix='tensorflow-serving/models')
34+
model_data = sagemaker_session.upload_data(
35+
path='tests/data/tensorflow-serving-test-model.tar.gz',
36+
key_prefix='tensorflow-serving/models')
3437
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
35-
model = Model(model_data=model_data, role='SageMakerRole',
38+
model = Model(model_data=model_data, role='SageMakerRole', image=image,
3639
framework_version=tf_full_version,
3740
sagemaker_session=sagemaker_session)
3841
predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name)
@@ -80,3 +83,9 @@ def test_predict_csv(tfs_predictor):
8083

8184
result = predictor.predict(input_data)
8285
assert expected_result == result
86+
87+
88+
def test_predict_bad_input(tfs_predictor):
89+
input_data = {'junk': 'data'}
90+
with pytest.raises(botocore.exceptions.ClientError):
91+
tfs_predictor.predict(input_data)

tests/unit/test_fw_utils.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
from __future__ import absolute_import
1414

1515
import inspect
16-
from mock import Mock, patch
1716
import os
18-
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag, \
19-
model_code_key_prefix
20-
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
17+
2118
import pytest
19+
from mock import Mock, patch
2220

21+
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, \
22+
framework_version_from_tag, \
23+
model_code_key_prefix
24+
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
2325
from sagemaker.utils import name_from_image
2426

2527
DATA_DIR = 'data_dir'
@@ -33,12 +35,12 @@
3335
@pytest.fixture()
3436
def sagemaker_session():
3537
boto_mock = Mock(name='boto_session', region_name=REGION)
36-
ims = Mock(name='sagemaker_session', boto_session=boto_mock)
37-
ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
38-
ims.expand_role = Mock(name="expand_role", return_value=ROLE)
39-
ims.sagemaker_client.describe_training_job = Mock(return_value={'ModelArtifacts':
40-
{'S3ModelArtifacts': 's3://m/m.tar.gz'}})
41-
return ims
38+
session_mock = Mock(name='sagemaker_session', boto_session=boto_mock)
39+
session_mock.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
40+
session_mock.expand_role = Mock(name="expand_role", return_value=ROLE)
41+
session_mock.sagemaker_client.describe_training_job = \
42+
Mock(return_value={'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}})
43+
return session_mock
4244

4345

4446
def test_create_image_uri_cpu():
@@ -49,6 +51,16 @@ def test_create_image_uri_cpu():
4951
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'
5052

5153

54+
def test_create_image_uri_no_python():
55+
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.c4.large', '1.0rc', account='23')
56+
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu'
57+
58+
59+
def test_create_image_uri_bad_python():
60+
with pytest.raises(ValueError):
61+
create_image_uri('mars-south-3', 'mlfw', 'ml.c4.large', '1.0rc', 'py0')
62+
63+
5264
def test_create_image_uri_gpu():
5365
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.p3.2xlarge', '1.0rc', 'py3', '23')
5466
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3'
@@ -127,7 +139,8 @@ def test_tar_and_upload_dir_not_s3(sagemaker_session):
127139
script = os.path.basename(__file__)
128140
directory = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
129141
result = tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory)
130-
assert result == UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix), script)
142+
assert result == UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix),
143+
script)
131144

132145

133146
def test_framework_name_from_image_mxnet():
@@ -149,21 +162,24 @@ def test_legacy_name_from_framework_image():
149162

150163

151164
def test_legacy_name_from_wrong_framework():
152-
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1')
165+
framework, py_ver, tag = framework_name_from_image(
166+
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1')
153167
assert framework is None
154168
assert py_ver is None
155169
assert tag is None
156170

157171

158172
def test_legacy_name_from_wrong_python():
159-
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
173+
framework, py_ver, tag = framework_name_from_image(
174+
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
160175
assert framework is None
161176
assert py_ver is None
162177
assert tag is None
163178

164179

165180
def test_legacy_name_from_wrong_device():
166-
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
181+
framework, py_ver, tag = framework_name_from_image(
182+
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
167183
assert framework is None
168184
assert py_ver is None
169185
assert tag is None

0 commit comments

Comments
 (0)