|
@@ -10,112 +10,141 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-# from __future__ import absolute_import
-
-# import pytest
-# import torch
-# from PIL import Image
-# import os
-
-# from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
-# from sagemaker.serve.builder.schema_builder import SchemaBuilder
-# from sagemaker.serve.spec.inference_spec import InferenceSpec
-# from torchvision.transforms import transforms
-# from torchvision.models.squeezenet import squeezenet1_1
-
-# from tests.integ.sagemaker.serve.constants import (
-#     PYTORCH_SQUEEZENET_RESOURCE_DIR,
-#     SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
-#     NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE,
-#     NOT_RUNNING_ON_PY310,
-# )
-# from tests.integ.timeout import timeout
-# from tests.integ.utils import cleanup_model_resources
-# import logging
-
-# logger = logging.getLogger(__name__)
-
-# ROLE_NAME = "Admin"
-
-# GH_USER_NAME = os.getenv("GH_USER_NAME")
-# GH_ACCESS_TOKEN = os.getenv("GH_ACCESS_TOKEN")
-
-
-# @pytest.fixture
-# def pt_dependencies():
-#     return {
-#         "auto": True,
-#         "custom": [
-#             "boto3==1.26.*",
-#             "botocore==1.29.*",
-#             "s3transfer==0.6.*",
-#             (
-#                 f"git+https://{GH_USER_NAME}:{GH_ACCESS_TOKEN}@github.com"
-#                 "/aws/sagemaker-python-sdk-staging.git@inference-experience-dev"
-#             ),
-#         ],
-#     }
-
-
-# @pytest.fixture
-# def test_image():
-#     return Image.open(str(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg")))
-
-
-# @pytest.fixture
-# def squeezenet_inference_spec():
-#     class MySqueezeNetModel(InferenceSpec):
-#         def __init__(self) -> None:
-#             super().__init__()
-#             self.transform = transforms.Compose(
-#                 [
-#                     transforms.Resize(256),
-#                     transforms.CenterCrop(224),
-#                     transforms.ToTensor(),
-#                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-#                 ]
-#             )
-
-#         def invoke(self, input_object: object, model: object):
-#             # transform
-#             image_tensor = self.transform(input_object)
-#             input_batch = image_tensor.unsqueeze(0)
-#             # invoke
-#             with torch.no_grad():
-#                 output = model(input_batch)
-#             return output
-
-#         def load(self, model_dir: str):
-#             model = squeezenet1_1()
-#             model.load_state_dict(torch.load(model_dir + "/model.pth"))
-#             model.eval()
-#             return model
-
-#     return MySqueezeNetModel()
-
-
-# @pytest.fixture
-# def squeezenet_schema():
-#     input_image = Image.open(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg"))
-#     output_tensor = torch.rand(3, 4)
-#     return SchemaBuilder(sample_input=input_image, sample_output=output_tensor)
-
-
-# @pytest.fixture
-# def model_builder_inference_spec_schema_builder(
-#     squeezenet_inference_spec, squeezenet_schema, pt_dependencies
-# ):
-#     return ModelBuilder(
-#         model_path=PYTORCH_SQUEEZENET_RESOURCE_DIR,
-#         inference_spec=squeezenet_inference_spec,
-#         schema_builder=squeezenet_schema,
-#         dependencies=pt_dependencies,
-#     )
-
-
-# @pytest.fixture
-# def model_builder(request):
-#     return request.getfixturevalue(request.param)
+from __future__ import absolute_import
+
+import pytest
+import torch
+from PIL import Image
+import os
+import io
+import numpy as np
+
+from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
+from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from torchvision.transforms import transforms
+from torchvision.models.squeezenet import squeezenet1_1
+
+from tests.integ.sagemaker.serve.constants import (
+    PYTORCH_SQUEEZENET_RESOURCE_DIR,
+    SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
+    NOT_RUNNING_ON_PY310,
+)
+from tests.integ.timeout import timeout
+from tests.integ.utils import cleanup_model_resources
+import logging
+
+logger = logging.getLogger(__name__)
+
+ROLE_NAME = "SageMakerRole"
+
+
+@pytest.fixture
+def test_image():
+    return Image.open(str(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg")))
+
+
+@pytest.fixture
+def squeezenet_inference_spec():
+    class MySqueezeNetModel(InferenceSpec):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def invoke(self, input_object: object, model: object):
+            with torch.no_grad():
+                output = model(input_object)
+            return output
+
+        def load(self, model_dir: str):
+            model = squeezenet1_1()
+            model.load_state_dict(torch.load(model_dir + "/model.pth"))
+            model.eval()
+            return model
+
+    return MySqueezeNetModel()
+
+
+@pytest.fixture
+def custom_request_translator():
+    # request translator
+    class MyRequestTranslator(CustomPayloadTranslator):
+        def __init__(self):
+            super().__init__()
+            # Define image transformation
+            self.transform = transforms.Compose(
+                [
+                    transforms.Resize(256),
+                    transforms.CenterCrop(224),
+                    transforms.ToTensor(),
+                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+                ]
+            )
+
+        # This function converts the payload to bytes - happens on client side
+        def serialize_payload_to_bytes(self, payload: object) -> bytes:
+            # converts an image to bytes
+            image_tensor = self.transform(payload)
+            input_batch = image_tensor.unsqueeze(0)
+            input_ndarray = input_batch.numpy()
+            return self._convert_numpy_to_bytes(input_ndarray)
+
+        # This function converts the bytes to payload - happens on server side
+        def deserialize_payload_from_stream(self, stream) -> torch.Tensor:
+            # convert payload back to torch.Tensor
+            np_array = np.load(io.BytesIO(stream.read()))
+            return torch.from_numpy(np_array)
+
+        def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes:
+            buffer = io.BytesIO()
+            np.save(buffer, np_array)
+            return buffer.getvalue()
+
+    return MyRequestTranslator()
+
+
+@pytest.fixture
+def custom_response_translator():
+    # response translator
+    class MyResponseTranslator(CustomPayloadTranslator):
+        # This function converts the payload to bytes - happens on server side
+        def serialize_payload_to_bytes(self, payload: torch.Tensor) -> bytes:
+            return self._convert_numpy_to_bytes(payload.numpy())
+
+        # This function converts the bytes to payload - happens on client side
+        def deserialize_payload_from_stream(self, stream) -> object:
+            return torch.from_numpy(np.load(io.BytesIO(stream.read())))
+
+        def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes:
+            buffer = io.BytesIO()
+            np.save(buffer, np_array)
+            return buffer.getvalue()
+
+    return MyResponseTranslator()
+
+
+@pytest.fixture
+def squeezenet_schema(custom_request_translator, custom_response_translator):
+    input_image = Image.open(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg"))
+    output_tensor = torch.rand(3, 4)
+    return SchemaBuilder(
+        sample_input=input_image,
+        sample_output=output_tensor,
+        input_translator=custom_request_translator,
+        output_translator=custom_response_translator,
+    )
+
+@pytest.fixture
+def model_builder_inference_spec_schema_builder(squeezenet_inference_spec, squeezenet_schema):
+    return ModelBuilder(
+        model_path=PYTORCH_SQUEEZENET_RESOURCE_DIR,
+        inference_spec=squeezenet_inference_spec,
+        schema_builder=squeezenet_schema,
+    )
+
+
+@pytest.fixture
+def model_builder(request):
+    return request.getfixturevalue(request.param)
 
 
 # @pytest.mark.skipif(
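For reference, the request/response translators added in the hunk above reduce to a simple .npy round trip. Below is a minimal, endpoint-free sketch of that logic; it is not part of this diff, the Image.new stand-in is assumed in place of the zidane.jpeg fixture, and torch, torchvision, numpy, and Pillow are required:

import io

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import transforms

# Same preprocessing pipeline MyRequestTranslator builds in __init__.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Client side (serialize_payload_to_bytes): image -> batched tensor -> .npy bytes.
image = Image.new("RGB", (640, 480))  # hypothetical stand-in for zidane.jpeg
buffer = io.BytesIO()
np.save(buffer, transform(image).unsqueeze(0).numpy())
payload_bytes = buffer.getvalue()

# Server side (deserialize_payload_from_stream): .npy bytes -> torch.Tensor.
restored = torch.from_numpy(np.load(io.BytesIO(payload_bytes)))
assert restored.shape == (1, 3, 224, 224)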
|
|
@@ -149,45 +178,45 @@
 # ), f"{caught_ex} was thrown when running pytorch squeezenet local container test"
 
 
-# @pytest.mark.skipif(
-#     NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE or NOT_RUNNING_ON_PY310,
-#     reason="The goal of these test are to test the serving components of our feature",
-# )
-# @pytest.mark.parametrize(
-#     "model_builder", ["model_builder_inference_spec_schema_builder"], indirect=True
-# )
-# def test_happy_pytorch_sagemaker_endpoint(
-#     sagemaker_session, model_builder, cpu_instance_type, test_image
-# ):
-#     logger.info("Running in SAGEMAKER_ENDPOINT mode...")
-#     caught_ex = None
-
-#     iam_client = sagemaker_session.boto_session.client("iam")
-#     role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
-
-#     model = model_builder.build(
-#         mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
-#     )
-
-#     with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
-#         try:
-#             logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
-#             predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1)
-#             logger.info("Endpoint successfully deployed.")
-#             predictor.predict(test_image)
-#         except Exception as e:
-#             caught_ex = e
-#         finally:
-#             cleanup_model_resources(
-#                 sagemaker_session=model_builder.sagemaker_session,
-#                 model_name=model.name,
-#                 endpoint_name=model.endpoint_name,
-#             )
-#             if caught_ex:
-#                 logger.exception(caught_ex)
-#                 assert (
-#                     False
-#                 ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test"
+@pytest.mark.skipif(
+    NOT_RUNNING_ON_PY310,  # or NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE,
+    reason="The goal of these tests is to test the serving components of our feature",
+)
+@pytest.mark.parametrize(
+    "model_builder", ["model_builder_inference_spec_schema_builder"], indirect=True
+)
+def test_happy_pytorch_sagemaker_endpoint(
+    sagemaker_session, model_builder, cpu_instance_type, test_image
+):
+    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
+    caught_ex = None
+
+    iam_client = sagemaker_session.boto_session.client("iam")
+    role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
+
+    model = model_builder.build(
+        mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
+    )
+
+    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+        try:
+            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
+            predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1)
+            logger.info("Endpoint successfully deployed.")
+            predictor.predict(test_image)
+        except Exception as e:
+            caught_ex = e
+        finally:
+            cleanup_model_resources(
+                sagemaker_session=model_builder.sagemaker_session,
+                model_name=model.name,
+                endpoint_name=model.endpoint_name,
+            )
+            if caught_ex:
+                logger.exception(caught_ex)
+                assert (
+                    False
+                ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test"
 
 
 # @pytest.mark.skipif(
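Taken together with the fixtures from the first hunk, the uncommented endpoint test exercises the new translators end to end: per the comments in the translator classes, predictor.predict(test_image) runs MyRequestTranslator.serialize_payload_to_bytes on the client, the endpoint decodes the bytes back into a tensor for MySqueezeNetModel.invoke, and the output tensor travels back through MyResponseTranslator's serialize/deserialize pair.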
|
|