Skip to content

Commit d6210d7

Browse files
author
wanyixia
committed
change:add test for compilation
1 parent 5f39c94 commit d6210d7

File tree

3 files changed

+105
-1
lines changed

3 files changed

+105
-1
lines changed

src/sagemaker/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def compile(
653653
if target_instance_family is not None:
654654
if target_instance_family == "ml_eia2":
655655
LOGGER.info("You are using target device ml_eia2...")
656-
elif (target_instance_family != "ml_eia2") and target_instance_family.startswith("ml_"):
656+
elif target_instance_family.startswith("ml_"):
657657
self.image_uri = self._compilation_image_uri(
658658
self.sagemaker_session.boto_region_name,
659659
target_instance_family,

tests/conftest.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,31 @@ def neo_pytorch_cpu_instance_type():
223223
return "ml.c5.xlarge"
224224

225225

226+
@pytest.fixture(scope="module")
227+
def tfs_eia_latest_py_version():
228+
return "py3"
229+
230+
231+
@pytest.fixture(scope="module")
232+
def tfs_eia_latest_version():
233+
return "2.3"
234+
235+
236+
@pytest.fixture(scope="module")
237+
def tfs_eia_target_device():
238+
return "ml_eia2"
239+
240+
241+
@pytest.fixture(scope="module")
242+
def tfs_eia_cpu_instance_type():
243+
return "ml.c5.xlarge"
244+
245+
246+
@pytest.fixture(scope="module")
247+
def tfs_eia_compilation_job_name():
248+
return utils.name_from_base("tfs-eia-compilation")
249+
250+
226251
@pytest.fixture(scope="module")
227252
def xgboost_framework_version(xgboost_version):
228253
if xgboost_version in ("1", "latest"):
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
16+
import os
17+
18+
import sagemaker
19+
import sagemaker.predictor
20+
import sagemaker.utils
21+
import tests.integ
22+
import tests.integ.timeout
23+
import numpy as np
24+
import matplotlib.image as mpimg
25+
from sagemaker.tensorflow.model import TensorFlowModel
26+
from tests.integ import (
27+
DATA_DIR,
28+
)
29+
from tests.integ.timeout import timeout_and_delete_endpoint_by_name
30+
31+
INPUT_MODEL = os.path.join(DATA_DIR, "tensorflow-serving-test-model.tar.gz")
32+
INFERENCE_IMAGE = os.path.join(DATA_DIR, "cuteCat.jpg")
33+
34+
35+
def test_compile_and_deploy_with_accelerator(
36+
sagemaker_session,
37+
tfs_eia_cpu_instance_type,
38+
tfs_eia_latest_version,
39+
tfs_eia_latest_py_version,
40+
tfs_eia_target_device,
41+
tfs_eia_compilation_job_name
42+
):
43+
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
44+
model_data = sagemaker_session.upload_data(
45+
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
46+
key_prefix="tensorflow-serving/compiledmodels",
47+
)
48+
bucket = sagemaker_session.default_bucket()
49+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
50+
model = TensorFlowModel(
51+
model_data=model_data,
52+
role="SageMakerRole",
53+
framework_version=tfs_eia_latest_version,
54+
py_version=tfs_eia_latest_py_version,
55+
sagemaker_session=sagemaker_session,
56+
name=endpoint_name,
57+
)
58+
data_shape = {"input": [1, 224, 224, 3]}
59+
compiled_model_path = "s3://{}/{}/output".format(bucket, tfs_eia_compilation_job_name)
60+
compiled_model = model.compile(
61+
target_instance_family=tfs_eia_target_device,
62+
input_shape=data_shape,
63+
output_path=compiled_model_path,
64+
role="SageMakerRole",
65+
job_name=tfs_eia_compilation_job_name,
66+
framework='tensorflow',
67+
framework_version=tfs_eia_latest_version
68+
)
69+
predictor = compiled_model.deploy(
70+
1, tfs_eia_cpu_instance_type, endpoint_name=endpoint_name, accelerator_type="ml.eia2.large"
71+
)
72+
73+
image_path = os.path.join(tests.integ.DATA_DIR, "cuteCat.jpg")
74+
img = mpimg.imread(image_path)
75+
img = np.resize(img, (224, 224, 3))
76+
img = np.expand_dims(img, axis=0)
77+
input_data = {"inputs": img}
78+
result = predictor.predict(input_data)
79+
print("result", result)

0 commit comments

Comments
 (0)