Skip to content

Commit 7b61b1e

Browse files
committed
Add integration test
1 parent ceb367b commit 7b61b1e

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import pytest
17+
18+
import tests.integ
19+
import tests.integ.timeout
20+
21+
from sagemaker import image_uris
22+
from sagemaker.model import Model
23+
from sagemaker.predictor import Predictor
24+
from sagemaker.serializers import CSVSerializer
25+
from sagemaker.utils import unique_name_from_base
26+
27+
from tests.integ import DATA_DIR
28+
29+
30+
ROLE = "SageMakerRole"
31+
INSTANCE_COUNT = 1
32+
INSTANCE_TYPE = "ml.c5.xlarge"
33+
TEST_CSV_DATA = "42,42,42,42,42,42,42"
34+
XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model")
35+
36+
37+
@pytest.yield_fixture(scope="module")
38+
def endpoint_name(sagemaker_session):
39+
endpoint_name = unique_name_from_base("model-inference-id-integ")
40+
xgb_model_data = sagemaker_session.upload_data(
41+
path=os.path.join(XGBOOST_DATA_PATH, "xgb_model.tar.gz"),
42+
key_prefix="integ-test-data/xgboost/model",
43+
)
44+
45+
xgb_image = image_uris.retrieve(
46+
"xgboost", sagemaker_session.boto_region_name, version="1", image_scope="inference"
47+
)
48+
49+
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
50+
endpoint_name=endpoint_name, sagemaker_session=sagemaker_session, hours=2
51+
):
52+
xgb_model = Model(
53+
model_data=xgb_model_data,
54+
image_uri=xgb_image,
55+
name=endpoint_name, # model name
56+
role=ROLE,
57+
sagemaker_session=sagemaker_session,
58+
)
59+
xgb_model.deploy(
60+
INSTANCE_COUNT,
61+
INSTANCE_TYPE,
62+
endpoint_name=endpoint_name
63+
)
64+
yield endpoint_name
65+
66+
67+
def test_predict_with_inference_id(sagemaker_session, endpoint_name):
68+
predictor = Predictor(
69+
endpoint_name=endpoint_name,
70+
sagemaker_session=sagemaker_session,
71+
serializer=CSVSerializer(),
72+
)
73+
74+
# Validate that no exception is raised when the target_variant is specified.
75+
response = predictor.predict(TEST_CSV_DATA, inference_id="foo")
76+
assert response
77+
78+
79+
def test_invoke_endpoint_with_inference_id(sagemaker_session, endpoint_name):
80+
response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint(
81+
EndpointName=endpoint_name,
82+
Body=TEST_CSV_DATA,
83+
ContentType="text/csv",
84+
Accept="text/csv",
85+
InferenceId="foo"
86+
)
87+
assert response

0 commit comments

Comments
 (0)