Skip to content

Commit 5f305cd

Browse files
author
wanyixia
committed
change: revise compilation test
1 parent d6210d7 commit 5f305cd

File tree

3 files changed

+53
-104
lines changed

3 files changed

+53
-104
lines changed

tests/conftest.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -223,31 +223,6 @@ def neo_pytorch_cpu_instance_type():
223223
return "ml.c5.xlarge"
224224

225225

226-
@pytest.fixture(scope="module")
227-
def tfs_eia_latest_py_version():
228-
return "py3"
229-
230-
231-
@pytest.fixture(scope="module")
232-
def tfs_eia_latest_version():
233-
return "2.3"
234-
235-
236-
@pytest.fixture(scope="module")
237-
def tfs_eia_target_device():
238-
return "ml_eia2"
239-
240-
241-
@pytest.fixture(scope="module")
242-
def tfs_eia_cpu_instance_type():
243-
return "ml.c5.xlarge"
244-
245-
246-
@pytest.fixture(scope="module")
247-
def tfs_eia_compilation_job_name():
248-
return utils.name_from_base("tfs-eia-compilation")
249-
250-
251226
@pytest.fixture(scope="module")
252227
def xgboost_framework_version(xgboost_version):
253228
if xgboost_version in ("1", "latest"):

tests/integ/test_tfs.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,42 @@ def tfs_predictor_with_accelerator(
138138
yield predictor
139139

140140

141+
@pytest.fixture(scope="module")
142+
def tfs_trt_predictor_with_accelerator(
143+
sagemaker_session, tensorflow_eia_latest_version, cpu_instance_type
144+
):
145+
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
146+
model_data = sagemaker_session.upload_data(
147+
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
148+
key_prefix="tensorflow-serving/compiledmodels",
149+
)
150+
bucket = sagemaker_session.default_bucket()
151+
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
152+
model = TensorFlowModel(
153+
model_data=model_data,
154+
role="SageMakerRole",
155+
framework_version=tensorflow_eia_latest_version,
156+
sagemaker_session=sagemaker_session,
157+
name=endpoint_name,
158+
)
159+
data_shape = {"input": [1, 224, 224, 3]}
160+
tfs_eia_compilation_job_name = "tfs_eia_compilation_job_name"
161+
compiled_model_path = "s3://{}/{}/output".format(bucket, tfs_eia_compilation_job_name)
162+
compiled_model = model.compile(
163+
target_instance_family='ml_eia2',
164+
input_shape=data_shape,
165+
output_path=compiled_model_path,
166+
role="SageMakerRole",
167+
job_name=tfs_eia_compilation_job_name,
168+
framework='tensorflow',
169+
framework_version='2.3'
170+
)
171+
predictor = compiled_model.deploy(
172+
1, cpu_instance_type, endpoint_name=endpoint_name, accelerator_type="ml.eia2.large"
173+
)
174+
yield predictor
175+
176+
141177
@pytest.mark.release
142178
def test_predict(tfs_predictor):
143179
input_data = {"instances": [1.0, 2.0, 5.0]}
@@ -160,6 +196,23 @@ def test_predict_with_accelerator(tfs_predictor_with_accelerator):
160196
assert expected_result == result
161197

162198

199+
@pytest.mark.skipif(
200+
tests.integ.test_region() not in tests.integ.EI_SUPPORTED_REGIONS,
201+
reason="EI is not supported in region {}".format(tests.integ.test_region()),
202+
)
203+
@pytest.mark.release
204+
def test_trt_predict_with_accelerator(tfs_predictor_with_accelerator):
205+
import numpy as np
206+
import matplotlib.image as mpimg
207+
path = os.path.join(tests.integ.DATA_DIR, "cuteCat.jpg")
208+
img = mpimg.imread(path)
209+
img = np.resize(img, (224, 224, 3))
210+
img = np.expand_dims(img, axis=0)
211+
input_data = {"inputs": img}
212+
result = tfs_trt_predictor_with_accelerator.predict(input_data)
213+
print("trt predictor result is: " + result)
214+
215+
163216
@pytest.mark.local_mode
164217
def test_predict_with_entry_point(tfs_predictor_with_model_and_entry_point_same_tar):
165218
input_data = {"instances": [1.0, 2.0, 5.0]}

tests/integ/test_tfs_eia_compilation.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

0 commit comments

Comments
 (0)