Skip to content

Commit c1ab360

Browse files
authored
simplify create_image_uri function (#462)
* add failure test case * fix flaky assert
1 parent b6c9b0c commit c1ab360

File tree

5 files changed

+65
-28
lines changed

5 files changed

+65
-28
lines changed

src/sagemaker/fw_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@
3333
'Please add framework_version={} to your constructor to avoid ' \
3434
'an error in the future.'
3535

36+
VALID_PY_VERSIONS = ['py2', 'py3']
3637

37-
def create_image_uri(region, framework, instance_type, framework_version, py_version, account='520713654638',
38+
39+
def create_image_uri(region, framework, instance_type, framework_version, py_version=None, account='520713654638',
3840
optimized_families=[]):
3941
"""Return the ECR URI of an image.
4042
@@ -43,14 +45,18 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
4345
framework (str): framework used by the image.
4446
instance_type (str): SageMaker instance type. Used to determine device type (cpu/gpu/family-specific optimized).
4547
framework_version (str): The version of the framework.
46-
py_version (str): Python version. One of 'py2' or 'py3'.
48+
py_version (str): Optional. Python version. If specified, should be one of 'py2' or 'py3'.
49+
If not specified, image uri will not include a python component.
4750
account (str): AWS account that contains the image. (default: '520713654638')
4851
optimized_families (str): Instance families for which there exist specific optimized images.
4952
5053
Returns:
5154
str: The appropriate image URI based on the given parameters.
5255
"""
5356

57+
if py_version and py_version not in VALID_PY_VERSIONS:
58+
raise ValueError('invalid py_version argument: {}'.format(py_version))
59+
5460
# Handle Account Number for Gov Cloud
5561
if region == 'us-gov-west-1':
5662
account = '246785580436'
@@ -73,7 +79,10 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
7379
else:
7480
device_type = 'cpu'
7581

76-
tag = "{}-{}-{}".format(framework_version, device_type, py_version)
82+
if py_version:
83+
tag = "{}-{}-{}".format(framework_version, device_type, py_version)
84+
else:
85+
tag = "{}-{}".format(framework_version, device_type)
7786
return "{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}:{}" \
7887
.format(account, region, framework, tag)
7988

src/sagemaker/tensorflow/serving.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _get_container_env(self):
134134
return self.env
135135

136136
env = dict(self.env)
137-
env['SAGEMAKER_TFS_NGINX_LOGLEVEL'] = Model.LOG_LEVEL_MAP[self._container_log_level]
137+
env[Model.LOG_LEVEL_PARAM_NAME] = Model.LOG_LEVEL_MAP[self._container_log_level]
138138
return env
139139

140140
def _get_image_uri(self, instance_type):
@@ -143,7 +143,5 @@ def _get_image_uri(self, instance_type):
143143

144144
# reuse standard image uri function, then strip unwanted python component
145145
region_name = self.sagemaker_session.boto_region_name
146-
image = create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type,
147-
self._framework_version, 'py3')
148-
image = image.replace('-py3', '')
149-
return image
146+
return create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type,
147+
self._framework_version)

tests/integ/test_tfs.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import botocore.exceptions
1516
import pytest
1617

1718
import sagemaker
@@ -29,8 +30,9 @@ def instance_type(request):
2930
@pytest.fixture(scope='module')
3031
def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
3132
endpoint_name = sagemaker.utils.name_from_base('sagemaker-tensorflow-serving')
32-
model_data = sagemaker_session.upload_data(path='tests/data/tensorflow-serving-test-model.tar.gz',
33-
key_prefix='tensorflow-serving/models')
33+
model_data = sagemaker_session.upload_data(
34+
path='tests/data/tensorflow-serving-test-model.tar.gz',
35+
key_prefix='tensorflow-serving/models')
3436
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
3537
model = Model(model_data=model_data, role='SageMakerRole',
3638
framework_version=tf_full_version,
@@ -80,3 +82,9 @@ def test_predict_csv(tfs_predictor):
8082

8183
result = predictor.predict(input_data)
8284
assert expected_result == result
85+
86+
87+
def test_predict_bad_input(tfs_predictor):
88+
input_data = {'junk': 'data'}
89+
with pytest.raises(botocore.exceptions.ClientError):
90+
tfs_predictor.predict(input_data)

tests/unit/test_fw_utils.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
from __future__ import absolute_import
1414

1515
import inspect
16-
from mock import Mock, patch
1716
import os
18-
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag, \
19-
model_code_key_prefix
20-
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
17+
2118
import pytest
19+
from mock import Mock, patch
2220

21+
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, \
22+
framework_version_from_tag, \
23+
model_code_key_prefix
24+
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
2325
from sagemaker.utils import name_from_image
2426

2527
DATA_DIR = 'data_dir'
@@ -33,12 +35,12 @@
3335
@pytest.fixture()
3436
def sagemaker_session():
3537
boto_mock = Mock(name='boto_session', region_name=REGION)
36-
ims = Mock(name='sagemaker_session', boto_session=boto_mock)
37-
ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
38-
ims.expand_role = Mock(name="expand_role", return_value=ROLE)
39-
ims.sagemaker_client.describe_training_job = Mock(return_value={'ModelArtifacts':
40-
{'S3ModelArtifacts': 's3://m/m.tar.gz'}})
41-
return ims
38+
session_mock = Mock(name='sagemaker_session', boto_session=boto_mock)
39+
session_mock.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
40+
session_mock.expand_role = Mock(name="expand_role", return_value=ROLE)
41+
session_mock.sagemaker_client.describe_training_job = \
42+
Mock(return_value={'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}})
43+
return session_mock
4244

4345

4446
def test_create_image_uri_cpu():
@@ -49,6 +51,16 @@ def test_create_image_uri_cpu():
4951
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'
5052

5153

54+
def test_create_image_uri_no_python():
55+
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.c4.large', '1.0rc', account='23')
56+
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu'
57+
58+
59+
def test_create_image_uri_bad_python():
60+
with pytest.raises(ValueError):
61+
create_image_uri('mars-south-3', 'mlfw', 'ml.c4.large', '1.0rc', 'py0')
62+
63+
5264
def test_create_image_uri_gpu():
5365
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.p3.2xlarge', '1.0rc', 'py3', '23')
5466
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3'
@@ -127,7 +139,8 @@ def test_tar_and_upload_dir_not_s3(sagemaker_session):
127139
script = os.path.basename(__file__)
128140
directory = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
129141
result = tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory)
130-
assert result == UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix), script)
142+
assert result == UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix),
143+
script)
131144

132145

133146
def test_framework_name_from_image_mxnet():
@@ -149,21 +162,24 @@ def test_legacy_name_from_framework_image():
149162

150163

151164
def test_legacy_name_from_wrong_framework():
152-
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1')
165+
framework, py_ver, tag = framework_name_from_image(
166+
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1')
153167
assert framework is None
154168
assert py_ver is None
155169
assert tag is None
156170

157171

158172
def test_legacy_name_from_wrong_python():
159-
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
173+
framework, py_ver, tag = framework_name_from_image(
174+
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
160175
assert framework is None
161176
assert py_ver is None
162177
assert tag is None
163178

164179

165180
def test_legacy_name_from_wrong_device():
166-
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
181+
framework, py_ver, tag = framework_name_from_image(
182+
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
167183
assert framework is None
168184
assert py_ver is None
169185
assert tag is None

tests/unit/test_tfs.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from mock import Mock
2121

2222
from sagemaker.tensorflow import TensorFlow
23-
from sagemaker.tensorflow.serving import Model, Predictor
2423
from sagemaker.tensorflow.predictor import csv_serializer
24+
from sagemaker.tensorflow.serving import Model, Predictor
2525

2626
JSON_CONTENT_TYPE = 'application/json'
2727
CSV_CONTENT_TYPE = 'text/csv'
@@ -94,7 +94,8 @@ def test_estimator_deploy(sagemaker_session):
9494

9595
job_name = 'doing something'
9696
tf.fit(inputs='s3://mybucket/train', job_name=job_name)
97-
predictor = tf.deploy(INSTANCE_COUNT, INSTANCE_TYPE, 'endpoint', endpoint_type='tensorflow-serving')
97+
predictor = tf.deploy(INSTANCE_COUNT, INSTANCE_TYPE, 'endpoint',
98+
endpoint_type='tensorflow-serving')
9899
assert isinstance(predictor, Predictor)
99100

100101

@@ -190,8 +191,13 @@ def test_predictor_classify_bad_content_type():
190191
predictor.classify(CLASSIFY_INPUT)
191192

192193

193-
def assert_invoked(sagemaker_session, **args):
194-
sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(**args)
194+
def assert_invoked(sagemaker_session, **kwargs):
195+
call = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
196+
cargs, ckwargs = call
197+
assert not cargs
198+
assert len(kwargs) == len(ckwargs)
199+
for k in ckwargs:
200+
assert kwargs[k] == ckwargs[k]
195201

196202

197203
def mock_response(expected_response, sagemaker_session, content_type=JSON_CONTENT_TYPE):

0 commit comments

Comments
 (0)