Skip to content

Commit 99afcc8

Browse files
SSRraymond and Raymond Liu authored
fix: update image and hardware validation with inf and graviton (#4299)
Co-authored-by: Raymond Liu <[email protected]>
1 parent 8b73191 commit 99afcc8

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

src/sagemaker/serve/utils/types.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,3 +40,6 @@ def __str__(self) -> str:
4040

4141
CPU = 1
4242
GPU = 2
43+
INFERENTIA_1 = 3
44+
INFERENTIA_2 = 4
45+
GRAVITON = 5

src/sagemaker/serve/validations/check_image_and_hardware_type.py

Lines changed: 31 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,21 @@
1717
}
1818

1919

20+
INF1_INSTANCE_FAMILIES = {"ml.inf1"}
21+
INF2_INSTANCE_FAMILIES = {"ml.inf2"}
22+
23+
GRAVITON_INSTANCE_FAMILIES = {
24+
"ml.c7g",
25+
"ml.m6g",
26+
"ml.m6gd",
27+
"ml.c6g",
28+
"ml.c6gd",
29+
"ml.c6gn",
30+
"ml.r6g",
31+
"ml.r6gd",
32+
}
33+
34+
2035
def validate_image_uri_and_hardware(image_uri: str, instance_type: str, model_server: ModelServer):
2136
"""Placeholder docstring"""
2237
if "xgboost" in image_uri:
@@ -57,6 +72,12 @@ def detect_hardware_type_of_instance(instance_type: str) -> HardwareType:
5772
instance_family = instance_type.rsplit(".", 1)[0]
5873
if instance_family in GPU_INSTANCE_FAMILIES:
5974
return HardwareType.GPU
75+
if instance_family in INF1_INSTANCE_FAMILIES:
76+
return HardwareType.INFERENTIA_1
77+
if instance_family in INF2_INSTANCE_FAMILIES:
78+
return HardwareType.INFERENTIA_2
79+
if instance_family in GRAVITON_INSTANCE_FAMILIES:
80+
return HardwareType.GRAVITON
6081
return HardwareType.CPU
6182

6283

@@ -67,4 +88,13 @@ def detect_triton_image_hardware_type(image_uri: str) -> HardwareType:
6788

6889
def detect_torchserve_image_hardware_type(image_uri: str) -> HardwareType:
6990
"""Placeholder docstring"""
70-
return HardwareType.CPU if "cpu" in image_uri else HardwareType.GPU
91+
if "neuronx" in image_uri:
92+
return HardwareType.INFERENTIA_2
93+
if "neuron" in image_uri:
94+
return HardwareType.INFERENTIA_1
95+
if "graviton" in image_uri:
96+
return HardwareType.GRAVITON
97+
if "cpu" in image_uri:
98+
return HardwareType.CPU
99+
100+
return HardwareType.GPU

tests/unit/sagemaker/serve/validations/test_check_image_and_hardware_type.py

Lines changed: 66 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,9 +31,25 @@
3131
"301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:23.08-py3-cpu"
3232
)
3333
GPU_IMAGE_TRITON = "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:23.08-py3"
34+
GRAVITON_IMAGE_TORCHSERVE = (
35+
"763104351884.dkr.ecr.us-east-1.amazonaws.com/"
36+
"pytorch-inference-graviton:2.1.0-cpu-py310-ubuntu20.04-sagemaker"
37+
)
38+
INF1_IMAGE_TORCHSERVE = (
39+
"763104351884.dkr.ecr.us-west-2.amazonaws.com"
40+
"/pytorch-inference-neuron:1.13.1-neuron-py310-sdk2.15.0-ubuntu20.04"
41+
)
42+
43+
INF2_IMAGE_TORCHSERVE = (
44+
"763104351884.dkr.ecr.us-west-2.amazonaws.com"
45+
"/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.15.0-ubuntu20.04"
46+
)
3447

3548
CPU_INSTANCE = "ml.c5.xlarge"
3649
GPU_INSTANCE = "ml.g4dn.xlarge"
50+
INF1_INSTANCE = "ml.inf1.xlarge"
51+
INF2_INSTANCE = "ml.inf2.xlarge"
52+
GRAVITON_INSTANCE = "ml.c7g.xlarge"
3753

3854

3955
class TestValidateImageAndHardware(unittest.TestCase):
@@ -116,3 +132,53 @@ def test_triton_gpu_image_with_cpu_instance(self):
116132
)
117133

118134
mock_logger.assert_called_once()
135+
136+
def test_torchserve_inf1_image_with_inf1_instance(self):
137+
138+
with patch("logging.Logger.warning") as mock_logger:
139+
validate_image_uri_and_hardware(
140+
image_uri=INF1_IMAGE_TORCHSERVE,
141+
instance_type=INF1_INSTANCE,
142+
model_server=ModelServer.TORCHSERVE,
143+
)
144+
mock_logger.assert_not_called()
145+
146+
def test_torchserve_inf2_image_with_inf2_instance(self):
147+
148+
with patch("logging.Logger.warning") as mock_logger:
149+
validate_image_uri_and_hardware(
150+
image_uri=INF2_IMAGE_TORCHSERVE,
151+
instance_type=INF2_INSTANCE,
152+
model_server=ModelServer.TORCHSERVE,
153+
)
154+
mock_logger.assert_not_called()
155+
156+
def test_torchserve_graviton_image_with_graviton_instance(self):
157+
158+
with patch("logging.Logger.warning") as mock_logger:
159+
validate_image_uri_and_hardware(
160+
image_uri=GRAVITON_IMAGE_TORCHSERVE,
161+
instance_type=GRAVITON_INSTANCE,
162+
model_server=ModelServer.TORCHSERVE,
163+
)
164+
mock_logger.assert_not_called()
165+
166+
def test_torchserve_inf1_image_with_cpu_instance(self):
167+
168+
with patch("logging.Logger.warning") as mock_logger:
169+
validate_image_uri_and_hardware(
170+
image_uri=INF1_IMAGE_TORCHSERVE,
171+
instance_type=CPU_INSTANCE,
172+
model_server=ModelServer.TORCHSERVE,
173+
)
174+
mock_logger.assert_called_once()
175+
176+
def test_torchserve_graviton_image_with_cpu_instance(self):
177+
178+
with patch("logging.Logger.warning") as mock_logger:
179+
validate_image_uri_and_hardware(
180+
image_uri=GRAVITON_IMAGE_TORCHSERVE,
181+
instance_type=CPU_INSTANCE,
182+
model_server=ModelServer.TORCHSERVE,
183+
)
184+
mock_logger.assert_called_once()

0 commit comments

Comments
 (0)