Skip to content

Commit 66a5461

Browse files
Xiong Zeng and samruds
authored and committed
Add integ test for model builder with GPU
1 parent 8eea80a commit 66a5461

File tree

1 file changed

+174
-0
lines changed

1 file changed

+174
-0
lines changed
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from sagemaker.serve import Mode
17+
from sagemaker.serve.builder.model_builder import ModelBuilder
18+
from sagemaker.serve.builder.schema_builder import SchemaBuilder
19+
from tests.integ.sagemaker.serve.constants import (
20+
HF_DIR,
21+
PYTHON_VERSION_IS_NOT_310,
22+
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
23+
)
24+
from tests.integ.timeout import timeout
25+
from tests.integ.utils import cleanup_model_resources
26+
import logging
27+
28+
logger = logging.getLogger(__name__)
29+
30+
# Hugging Face model id used by every test in this module (fill-mask task).
model_id = "bert-base-uncased"

# Example fill-mask request; used by SchemaBuilder to derive the input schema.
sample_input = {"inputs": "Hello I'm a [MASK] model."}

# Example fill-mask response for the sample input above: top-5 predictions,
# each with a confidence score, token id, decoded token string, and the
# sentence with the mask filled in. Used by SchemaBuilder for the output schema.
sample_output = [
    {
        "score": 0.10731109976768494,
        "token": 4827,
        "token_str": "fashion",
        "sequence": "hello i'm a fashion model.",
    },
    {
        "score": 0.08774465322494507,
        "token": 2535,
        "token_str": "role",
        "sequence": "hello i'm a role model.",
    },
    {
        "score": 0.05338414013385773,
        "token": 2047,
        "token_str": "new",
        "sequence": "hello i'm a new model.",
    },
    {
        "score": 0.04667224362492561,
        "token": 3565,
        "token_str": "super",
        "sequence": "hello i'm a super model.",
    },
    {
        "score": 0.027096163481473923,
        "token": 2986,
        "token_str": "fine",
        "sequence": "hello i'm a fine model.",
    },
]
66+
67+
68+
@pytest.fixture
def model_input():
    """A fill-mask request payload, distinct from the schema sample above."""
    payload = {"inputs": "The man worked as a [MASK]."}
    return payload
71+
72+
73+
@pytest.fixture
def model_builder_model_schema_builder():
    """ModelBuilder for the module's HF model, with a schema derived from the samples."""
    schema = SchemaBuilder(sample_input, sample_output)
    return ModelBuilder(model_path=HF_DIR, model=model_id, schema_builder=schema)
78+
79+
80+
@pytest.fixture
def model_builder(request):
    """Indirection fixture: resolves the fixture whose name is passed via parametrize."""
    target_fixture_name = request.param
    return request.getfixturevalue(target_fixture_name)
83+
84+
85+
@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="Testing feature",
)
@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
def test_non_text_generation_model_single_GPU(sagemaker_session, model_builder, model_input):
    """Build and deploy the fill-mask model to a single-GPU endpoint.

    Verifies that the endpoint answers a prediction request and that the
    endpoint config really uses the requested ml.g4dn.xlarge instance type.
    Resources are cleaned up in all cases; any caught exception is re-raised
    as a test failure after cleanup.
    """
    iam_client = sagemaker_session.boto_session.client("iam")
    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
    model = None  # bound only if build() succeeds; checked before cleanup
    caught_ex = None
    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        try:
            model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
            logger.info("Running in SAGEMAKER_ENDPOINT mode")
            predictor = model.deploy(
                mode=Mode.SAGEMAKER_ENDPOINT,
                instance_type="ml.g4dn.xlarge",
                initial_instance_count=1,
            )
            logger.info("Endpoint successfully deployed.")
            prediction = predictor.predict(model_input)
            assert prediction is not None

            # Confirm the deployed endpoint landed on the requested instance type.
            endpoint_name = predictor.endpoint_name
            sagemaker_client = sagemaker_session.boto_session.client("sagemaker")
            endpoint_config_name = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)[
                "EndpointConfigName"
            ]
            actual_instance_type = sagemaker_client.describe_endpoint_config(
                EndpointConfigName=endpoint_config_name
            )["ProductionVariants"][0]["InstanceType"]
            assert "ml.g4dn.xlarge" == actual_instance_type
        except Exception as e:  # broad on purpose: fail after cleanup, not before
            caught_ex = e
        finally:
            # Bug fix: the original referenced `model` unconditionally here, so a
            # failure inside build() raised UnboundLocalError in this finally
            # block and masked the real exception. Only clean up what was built.
            if model is not None:
                cleanup_model_resources(
                    sagemaker_session=model_builder.sagemaker_session,
                    model_name=model.name,
                    endpoint_name=model.endpoint_name,
                )
            if caught_ex:
                logger.exception(caught_ex)
                assert (
                    False
                ), f"Exception {caught_ex} was thrown when running model builder single GPU test"
129+
130+
131+
@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="Testing feature",
)
@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
def test_non_text_generation_model_multi_GPU(sagemaker_session, model_builder, model_input):
    """Build and deploy the fill-mask model to a multi-GPU endpoint.

    Verifies that the endpoint answers a prediction request and that the
    endpoint config really uses the requested ml.g4dn.12xlarge instance type.
    Resources are cleaned up in all cases; any caught exception is re-raised
    as a test failure after cleanup.
    """
    iam_client = sagemaker_session.boto_session.client("iam")
    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
    model = None  # bound only if build() succeeds; checked before cleanup
    caught_ex = None
    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        try:
            model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
            logger.info("Running in SAGEMAKER_ENDPOINT mode")
            predictor = model.deploy(
                mode=Mode.SAGEMAKER_ENDPOINT,
                instance_type="ml.g4dn.12xlarge",
                initial_instance_count=1,
            )
            logger.info("Endpoint successfully deployed.")
            prediction = predictor.predict(model_input)
            assert prediction is not None

            # Confirm the deployed endpoint landed on the requested instance type.
            endpoint_name = predictor.endpoint_name
            sagemaker_client = sagemaker_session.boto_session.client("sagemaker")
            endpoint_config_name = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)[
                "EndpointConfigName"
            ]
            actual_instance_type = sagemaker_client.describe_endpoint_config(
                EndpointConfigName=endpoint_config_name
            )["ProductionVariants"][0]["InstanceType"]
            assert "ml.g4dn.12xlarge" == actual_instance_type
        except Exception as e:  # broad on purpose: fail after cleanup, not before
            caught_ex = e
        finally:
            # Bug fix: the original referenced `model` unconditionally here, so a
            # failure inside build() raised UnboundLocalError in this finally
            # block and masked the real exception. Only clean up what was built.
            if model is not None:
                cleanup_model_resources(
                    sagemaker_session=model_builder.sagemaker_session,
                    model_name=model.name,
                    endpoint_name=model.endpoint_name,
                )
            if caught_ex:
                logger.exception(caught_ex)
                assert (
                    False
                ), f"Exception {caught_ex} was thrown when running model builder multi GPU test"

0 commit comments

Comments
 (0)