|
31 | 31 | "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:23.08-py3-cpu"
|
32 | 32 | )
|
33 | 33 | GPU_IMAGE_TRITON = "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:23.08-py3"
|
| 34 | +GRAVITON_IMAGE_TORCHSERVE = ( |
| 35 | + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" |
| 36 | + "pytorch-inference-graviton:2.1.0-cpu-py310-ubuntu20.04-sagemaker" |
| 37 | +) |
| 38 | +INF1_IMAGE_TORCHSERVE = ( |
| 39 | + "763104351884.dkr.ecr.us-west-2.amazonaws.com" |
| 40 | + "/pytorch-inference-neuron:1.13.1-neuron-py310-sdk2.15.0-ubuntu20.04" |
| 41 | +) |
| 42 | + |
| 43 | +INF2_IMAGE_TORCHSERVE = ( |
| 44 | + "763104351884.dkr.ecr.us-west-2.amazonaws.com" |
| 45 | + "/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.15.0-ubuntu20.04" |
| 46 | +) |
34 | 47 |
|
35 | 48 | CPU_INSTANCE = "ml.c5.xlarge"
|
36 | 49 | GPU_INSTANCE = "ml.g4dn.xlarge"
|
| 50 | +INF1_INSTANCE = "ml.inf1.xlarge" |
| 51 | +INF2_INSTANCE = "ml.inf2.xlarge" |
| 52 | +GRAVITON_INSTANCE = "ml.c7g.xlarge" |
37 | 53 |
|
38 | 54 |
|
39 | 55 | class TestValidateImageAndHardware(unittest.TestCase):
|
@@ -116,3 +132,53 @@ def test_triton_gpu_image_with_cpu_instance(self):
|
116 | 132 | )
|
117 | 133 |
|
118 | 134 | mock_logger.assert_called_once()
|
| 135 | + |
| 136 | + def test_torchserve_inf1_image_with_inf1_instance(self): |
| 137 | + |
| 138 | + with patch("logging.Logger.warning") as mock_logger: |
| 139 | + validate_image_uri_and_hardware( |
| 140 | + image_uri=INF1_IMAGE_TORCHSERVE, |
| 141 | + instance_type=INF1_INSTANCE, |
| 142 | + model_server=ModelServer.TORCHSERVE, |
| 143 | + ) |
| 144 | + mock_logger.assert_not_called() |
| 145 | + |
| 146 | + def test_torchserve_inf2_image_with_inf2_instance(self): |
| 147 | + |
| 148 | + with patch("logging.Logger.warning") as mock_logger: |
| 149 | + validate_image_uri_and_hardware( |
| 150 | + image_uri=INF2_IMAGE_TORCHSERVE, |
| 151 | + instance_type=INF2_INSTANCE, |
| 152 | + model_server=ModelServer.TORCHSERVE, |
| 153 | + ) |
| 154 | + mock_logger.assert_not_called() |
| 155 | + |
| 156 | + def test_torchserve_graviton_image_with_graviton_instance(self): |
| 157 | + |
| 158 | + with patch("logging.Logger.warning") as mock_logger: |
| 159 | + validate_image_uri_and_hardware( |
| 160 | + image_uri=GRAVITON_IMAGE_TORCHSERVE, |
| 161 | + instance_type=GRAVITON_INSTANCE, |
| 162 | + model_server=ModelServer.TORCHSERVE, |
| 163 | + ) |
| 164 | + mock_logger.assert_not_called() |
| 165 | + |
| 166 | + def test_torchserve_inf1_image_with_cpu_instance(self): |
| 167 | + |
| 168 | + with patch("logging.Logger.warning") as mock_logger: |
| 169 | + validate_image_uri_and_hardware( |
| 170 | + image_uri=INF1_IMAGE_TORCHSERVE, |
| 171 | + instance_type=CPU_INSTANCE, |
| 172 | + model_server=ModelServer.TORCHSERVE, |
| 173 | + ) |
| 174 | + mock_logger.assert_called_once() |
| 175 | + |
| 176 | + def test_torchserve_graviton_image_with_cpu_instance(self): |
| 177 | + |
| 178 | + with patch("logging.Logger.warning") as mock_logger: |
| 179 | + validate_image_uri_and_hardware( |
| 180 | + image_uri=GRAVITON_IMAGE_TORCHSERVE, |
| 181 | + instance_type=CPU_INSTANCE, |
| 182 | + model_server=ModelServer.TORCHSERVE, |
| 183 | + ) |
| 184 | + mock_logger.assert_called_once() |
0 commit comments