
Commit 566f17b

feat: Enable galactus integ tests
1 parent: 086c946

File tree

requirements/extras/test_requirements.txt
tests/integ/sagemaker/serve/constants.py
tests/integ/sagemaker/serve/test_serve_pt_happy.py

3 files changed: +190 -158 lines changed


requirements/extras/test_requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -39,3 +39,6 @@ tritonclient[http]<2.37.0
 onnx==1.14.1
 # tf2onnx==1.15.1
 nbformat>=5.9,<6
+torch==2.0.1
+torchvision==0.15.2
+torchaudio==2.0.2
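
These three pins are the matched release set for the PyTorch 2.0.1 line (torchvision 0.15.2 and torchaudio 2.0.2 are built against torch 2.0.1). A minimal sanity check, assuming the pinned wheels are what actually got installed in the test environment:

import torch
import torchaudio
import torchvision

# Installed versions can carry a build suffix such as "2.0.1+cpu",
# so compare by prefix rather than strict equality.
assert torch.__version__.startswith("2.0.1"), torch.__version__
assert torchvision.__version__.startswith("0.15.2"), torchvision.__version__
assert torchaudio.__version__.startswith("2.0.2"), torchaudio.__version__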

tests/integ/sagemaker/serve/constants.py

Lines changed: 13 additions & 13 deletions
@@ -12,24 +12,24 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-# import os
+import os
 import platform
 
-# from tests.integ import DATA_DIR
+from tests.integ import DATA_DIR
 
-# SERVE_IN_PROCESS_TIMEOUT = 5
-# SERVE_MODEL_PACKAGE_TIMEOUT = 10
-# SERVE_LOCAL_CONTAINER_TIMEOUT = 10
+SERVE_IN_PROCESS_TIMEOUT = 5
+SERVE_MODEL_PACKAGE_TIMEOUT = 10
+SERVE_LOCAL_CONTAINER_TIMEOUT = 10
 SERVE_SAGEMAKER_ENDPOINT_TIMEOUT = 15
-# SERVE_SAVE_TIMEOUT = 2
+SERVE_SAVE_TIMEOUT = 2
 
-# NOT_RUNNING_ON_PY38 = platform.python_version_tuple()[1] != "8"
+NOT_RUNNING_ON_PY38 = platform.python_version_tuple()[1] != "8"
 NOT_RUNNING_ON_PY310 = platform.python_version_tuple()[1] != "10"
-# NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE = os.getenv("TEST_OWNER") != "INF_EXP_DEV"
+NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE = os.getenv("TEST_OWNER") != "INF_EXP_DEV"
 
-# XGB_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "xgboost")
-# PYTORCH_SQUEEZENET_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "pytorch")
-# TF_EFFICIENT_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "tensorflow")
-# HF_DIR = os.path.join(DATA_DIR, "serve_resources", "hf")
+XGB_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "xgboost")
+PYTORCH_SQUEEZENET_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "pytorch")
+TF_EFFICIENT_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "tensorflow")
+HF_DIR = os.path.join(DATA_DIR, "serve_resources", "hf")
 
-# BYOC_IMAGE_URI_TEMPLATE = "661407751302.dkr.ecr.{}.amazonaws.com/byoc-integ-test-images:{}"
+BYOC_IMAGE_URI_TEMPLATE = "661407751302.dkr.ecr.{}.amazonaws.com/byoc-integ-test-images:{}"
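
A side note on the re-enabled gates: platform.python_version_tuple() returns strings, so the minor version is compared against "10" rather than the integer 10. A standalone sketch (not part of the commit) of what the PY310 gate evaluates to:

import platform

# e.g. ("3", "10", "12") on CPython 3.10.12 -- every element is a str
_, minor, _ = platform.python_version_tuple()

# False on any 3.10.x interpreter, True everywhere else
print(minor != "10")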

tests/integ/sagemaker/serve/test_serve_pt_happy.py

Lines changed: 174 additions & 145 deletions
@@ -10,112 +10,141 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-# from __future__ import absolute_import
-
-# import pytest
-# import torch
-# from PIL import Image
-# import os
-
-# from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
-# from sagemaker.serve.builder.schema_builder import SchemaBuilder
-# from sagemaker.serve.spec.inference_spec import InferenceSpec
-# from torchvision.transforms import transforms
-# from torchvision.models.squeezenet import squeezenet1_1
-
-# from tests.integ.sagemaker.serve.constants import (
-#     PYTORCH_SQUEEZENET_RESOURCE_DIR,
-#     SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
-#     NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE,
-#     NOT_RUNNING_ON_PY310,
-# )
-# from tests.integ.timeout import timeout
-# from tests.integ.utils import cleanup_model_resources
-# import logging
-
-# logger = logging.getLogger(__name__)
-
-# ROLE_NAME = "Admin"
-
-# GH_USER_NAME = os.getenv("GH_USER_NAME")
-# GH_ACCESS_TOKEN = os.getenv("GH_ACCESS_TOKEN")
-
-
-# @pytest.fixture
-# def pt_dependencies():
-#     return {
-#         "auto": True,
-#         "custom": [
-#             "boto3==1.26.*",
-#             "botocore==1.29.*",
-#             "s3transfer==0.6.*",
-#             (
-#                 f"git+https://{GH_USER_NAME}:{GH_ACCESS_TOKEN}@github.com"
-#                 "/aws/sagemaker-python-sdk-staging.git@inference-experience-dev"
-#             ),
-#         ],
-#     }
-
-
-# @pytest.fixture
-# def test_image():
-#     return Image.open(str(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg")))
-
-
-# @pytest.fixture
-# def squeezenet_inference_spec():
-#     class MySqueezeNetModel(InferenceSpec):
-#         def __init__(self) -> None:
-#             super().__init__()
-#             self.transform = transforms.Compose(
-#                 [
-#                     transforms.Resize(256),
-#                     transforms.CenterCrop(224),
-#                     transforms.ToTensor(),
-#                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-#                 ]
-#             )
-
-#         def invoke(self, input_object: object, model: object):
-#             # transform
-#             image_tensor = self.transform(input_object)
-#             input_batch = image_tensor.unsqueeze(0)
-#             # invoke
-#             with torch.no_grad():
-#                 output = model(input_batch)
-#             return output
-
-#         def load(self, model_dir: str):
-#             model = squeezenet1_1()
-#             model.load_state_dict(torch.load(model_dir + "/model.pth"))
-#             model.eval()
-#             return model
-
-#     return MySqueezeNetModel()
-
-
-# @pytest.fixture
-# def squeezenet_schema():
-#     input_image = Image.open(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg"))
-#     output_tensor = torch.rand(3, 4)
-#     return SchemaBuilder(sample_input=input_image, sample_output=output_tensor)
-
-
-# @pytest.fixture
-# def model_builder_inference_spec_schema_builder(
-#     squeezenet_inference_spec, squeezenet_schema, pt_dependencies
-# ):
-#     return ModelBuilder(
-#         model_path=PYTORCH_SQUEEZENET_RESOURCE_DIR,
-#         inference_spec=squeezenet_inference_spec,
-#         schema_builder=squeezenet_schema,
-#         dependencies=pt_dependencies,
-#     )
-
-
-# @pytest.fixture
-# def model_builder(request):
-#     return request.getfixturevalue(request.param)
+from __future__ import absolute_import
+
+import pytest
+import torch
+from PIL import Image
+import os
+import io
+import numpy as np
+
+from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
+from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from torchvision.transforms import transforms
+from torchvision.models.squeezenet import squeezenet1_1
+
+from tests.integ.sagemaker.serve.constants import (
+    PYTORCH_SQUEEZENET_RESOURCE_DIR,
+    SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
+    NOT_RUNNING_ON_PY310,
+)
+from tests.integ.timeout import timeout
+from tests.integ.utils import cleanup_model_resources
+import logging
+
+logger = logging.getLogger(__name__)
+
+ROLE_NAME = "SageMakerRole"
+
+
+@pytest.fixture
+def test_image():
+    return Image.open(str(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg")))
+
+
+@pytest.fixture
+def squeezenet_inference_spec():
+    class MySqueezeNetModel(InferenceSpec):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def invoke(self, input_object: object, model: object):
+            with torch.no_grad():
+                output = model(input_object)
+            return output
+
+        def load(self, model_dir: str):
+            model = squeezenet1_1()
+            model.load_state_dict(torch.load(model_dir + "/model.pth"))
+            model.eval()
+            return model
+
+    return MySqueezeNetModel()
+
+
+@pytest.fixture
+def custom_request_translator():
+    # request translator
+    class MyRequestTranslator(CustomPayloadTranslator):
+        def __init__(self):
+            super().__init__()
+            # Define image transformation
+            self.transform = transforms.Compose(
+                [
+                    transforms.Resize(256),
+                    transforms.CenterCrop(224),
+                    transforms.ToTensor(),
+                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+                ]
+            )
+
+        # This function converts the payload to bytes - happens on client side
+        def serialize_payload_to_bytes(self, payload: object) -> bytes:
+            # converts an image to bytes
+            image_tensor = self.transform(payload)
+            input_batch = image_tensor.unsqueeze(0)
+            input_ndarray = input_batch.numpy()
+            return self._convert_numpy_to_bytes(input_ndarray)
+
+        # This function converts the bytes to payload - happens on server side
+        def deserialize_payload_from_stream(self, stream) -> torch.Tensor:
+            # convert payload back to torch.Tensor
+            np_array = np.load(io.BytesIO(stream.read()))
+            return torch.from_numpy(np_array)
+
+        def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes:
+            buffer = io.BytesIO()
+            np.save(buffer, np_array)
+            return buffer.getvalue()
+
+    return MyRequestTranslator()
+
+
+@pytest.fixture
+def custom_response_translator():
+    # response translator
+    class MyResponseTranslator(CustomPayloadTranslator):
+        # This function converts the payload to bytes - happens on server side
+        def serialize_payload_to_bytes(self, payload: torch.Tensor) -> bytes:
+            return self._convert_numpy_to_bytes(payload.numpy())
+
+        # This function converts the bytes to payload - happens on client side
+        def deserialize_payload_from_stream(self, stream) -> object:
+            return torch.from_numpy(np.load(io.BytesIO(stream.read())))
+
+        def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes:
+            buffer = io.BytesIO()
+            np.save(buffer, np_array)
+            return buffer.getvalue()
+
+    return MyResponseTranslator()
+
+
+@pytest.fixture
+def squeezenet_schema(custom_request_translator, custom_response_translator):
+    input_image = Image.open(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg"))
+    output_tensor = torch.rand(3, 4)
+    return SchemaBuilder(
+        sample_input=input_image,
+        sample_output=output_tensor,
+        input_translator=custom_request_translator,
+        output_translator=custom_response_translator,
+    )
+
+@pytest.fixture
+def model_builder_inference_spec_schema_builder(squeezenet_inference_spec, squeezenet_schema):
+    return ModelBuilder(
+        model_path=PYTORCH_SQUEEZENET_RESOURCE_DIR,
+        inference_spec=squeezenet_inference_spec,
+        schema_builder=squeezenet_schema,
+    )
+
+
+@pytest.fixture
+def model_builder(request):
+    return request.getfixturevalue(request.param)
 
 
 # @pytest.mark.skipif(
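
Both translators above ship tensors as np.save bytes: MyRequestTranslator serializes on the client and deserializes on the server, MyResponseTranslator does the reverse. A standalone sketch (not part of the commit) showing that this scheme round-trips a tensor losslessly:

import io

import numpy as np
import torch


def to_bytes(tensor: torch.Tensor) -> bytes:
    # Same encoding as _convert_numpy_to_bytes: np.save into an in-memory buffer.
    buffer = io.BytesIO()
    np.save(buffer, tensor.numpy())
    return buffer.getvalue()


def from_bytes(payload: bytes) -> torch.Tensor:
    # Mirror of deserialize_payload_from_stream.
    return torch.from_numpy(np.load(io.BytesIO(payload)))


batch = torch.rand(1, 3, 224, 224)  # shape of one transformed squeezenet input
assert torch.equal(batch, from_bytes(to_bytes(batch)))

The diff continues with the second hunk, which enables the SageMaker endpoint test itself: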
@@ -149,45 +178,45 @@
 #     ), f"{caught_ex} was thrown when running pytorch squeezenet local container test"
 
 
-# @pytest.mark.skipif(
-#     NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE or NOT_RUNNING_ON_PY310,
-#     reason="The goal of these test are to test the serving components of our feature",
-# )
-# @pytest.mark.parametrize(
-#     "model_builder", ["model_builder_inference_spec_schema_builder"], indirect=True
-# )
-# def test_happy_pytorch_sagemaker_endpoint(
-#     sagemaker_session, model_builder, cpu_instance_type, test_image
-# ):
-#     logger.info("Running in SAGEMAKER_ENDPOINT mode...")
-#     caught_ex = None
-
-#     iam_client = sagemaker_session.boto_session.client("iam")
-#     role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
-
-#     model = model_builder.build(
-#         mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
-#     )
-
-#     with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
-#         try:
-#             logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
-#             predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1)
-#             logger.info("Endpoint successfully deployed.")
-#             predictor.predict(test_image)
-#         except Exception as e:
-#             caught_ex = e
-#         finally:
-#             cleanup_model_resources(
-#                 sagemaker_session=model_builder.sagemaker_session,
-#                 model_name=model.name,
-#                 endpoint_name=model.endpoint_name,
-#             )
-#             if caught_ex:
-#                 logger.exception(caught_ex)
-#                 assert (
-#                     False
-#                 ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test"
+@pytest.mark.skipif(
+    NOT_RUNNING_ON_PY310,  # or NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE,
+    reason="The goal of these test are to test the serving components of our feature",
+)
+@pytest.mark.parametrize(
+    "model_builder", ["model_builder_inference_spec_schema_builder"], indirect=True
+)
+def test_happy_pytorch_sagemaker_endpoint(
+    sagemaker_session, model_builder, cpu_instance_type, test_image
+):
+    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
+    caught_ex = None
+
+    iam_client = sagemaker_session.boto_session.client("iam")
+    role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
+
+    model = model_builder.build(
+        mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
+    )
+
+    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+        try:
+            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
+            predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1)
+            logger.info("Endpoint successfully deployed.")
+            predictor.predict(test_image)
+        except Exception as e:
+            caught_ex = e
+        finally:
+            cleanup_model_resources(
+                sagemaker_session=model_builder.sagemaker_session,
+                model_name=model.name,
+                endpoint_name=model.endpoint_name,
+            )
+            if caught_ex:
+                logger.exception(caught_ex)
+                assert (
+                    False
+                ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test"
 
 
 # @pytest.mark.skipif(
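
The enabled test relies on pytest's indirect parametrization: the parametrized string names another fixture, and the model_builder fixture resolves it with request.getfixturevalue, so further builder variants can be added to the list without touching the test body. A standalone illustration of the pattern, with hypothetical fixture names:

import pytest


@pytest.fixture
def concrete_builder():
    return "built-object"


@pytest.fixture
def builder(request):
    # request.param is the *name* of another fixture, passed via indirect=True.
    return request.getfixturevalue(request.param)


@pytest.mark.parametrize("builder", ["concrete_builder"], indirect=True)
def test_indirect_lookup(builder):
    assert builder == "built-object"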
