Skip to content

Commit 8711836

Browse files
Pass inference_tool for generating neuron/x image_uris
1 parent c753997 commit 8711836

File tree

4 files changed

+17
-5
lines changed

4 files changed

+17
-5
lines changed

src/sagemaker/huggingface/model.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,11 @@ def register(
448448
)
449449

450450
def prepare_container_def(
451-
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
451+
self,
452+
instance_type=None,
453+
accelerator_type=None,
454+
serverless_inference_config=None,
455+
inference_tool=None,
452456
):
453457
"""A container definition with framework configuration set in model environment variables.
454458
@@ -461,6 +465,8 @@ def prepare_container_def(
461465
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
462466
Specifies configuration related to serverless endpoint. Instance type is
463467
not provided in serverless inference. So this is used to find image URIs.
468+
inference_tool (str): the tool that will be used to aid in the inference.
469+
Valid values: "neuron, neuronx, None" (default: None).
464470
465471
Returns:
466472
dict[str, str]: A container definition object usable with the
@@ -479,6 +485,7 @@ def prepare_container_def(
479485
instance_type,
480486
accelerator_type=accelerator_type,
481487
serverless_inference_config=serverless_inference_config,
488+
inference_tool=inference_tool,
482489
)
483490

484491
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
@@ -500,6 +507,7 @@ def serving_image_uri(
500507
instance_type=None,
501508
accelerator_type=None,
502509
serverless_inference_config=None,
510+
inference_tool=None,
503511
):
504512
"""Create a URI for the serving image.
505513
@@ -513,6 +521,8 @@ def serving_image_uri(
513521
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
514522
Specifies configuration related to serverless endpoint. Instance type is
515523
not provided in serverless inference. So this is used to determine device type.
524+
inference_tool (str): the tool that will be used to aid in the inference.
525+
Valid values: "neuron, neuronx, None" (default: None).
516526
517527
Returns:
518528
str: The appropriate image URI based on the given parameters.
@@ -534,4 +544,5 @@ def serving_image_uri(
534544
image_scope="inference",
535545
base_framework_version=base_framework_version,
536546
serverless_inference_config=serverless_inference_config,
547+
inference_tool=inference_tool,
537548
)

src/sagemaker/image_uris.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def retrieve(
104104
sdk_version (str): the version of python-sdk that will be used in the image retrieval.
105105
(default: None).
106106
inference_tool (str): the tool that will be used to aid in the inference.
107-
Valid values: "neuron, None"
107+
Valid values: "neuron, neuronx, None"
108108
(default: None).
109109
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
110110
Specifies configuration related to serverless endpoint. Instance type is
@@ -158,7 +158,7 @@ def retrieve(
158158
_framework = framework
159159
if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS:
160160
inference_tool = _get_inference_tool(inference_tool, instance_type)
161-
if inference_tool == "neuron":
161+
if inference_tool == "neuron" or inference_tool == "neuronx":
162162
_framework = f"{framework}-{inference_tool}"
163163
final_image_scope = _get_final_image_scope(framework, instance_type, image_scope)
164164
_validate_for_suppported_frameworks_and_instance_type(framework, instance_type)

tests/unit/sagemaker/huggingface/huggingface_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,5 @@ def get_full_neuronx_image_uri(
5050
image_scope="training",
5151
base_framework_version=base_framework_version,
5252
container_version="cu110-ubuntu18.04",
53+
inference_tool="neuronx",
5354
)

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def test_huggingface_neuron(
269269
pytorch_version=huggingface_neuron_latest_inference_pytorch_version,
270270
py_version=huggingface_neuron_latest_inference_py_version,
271271
)
272-
container = huggingface_model.prepare_container_def("ml.inf1.xlarge")
272+
container = huggingface_model.prepare_container_def("ml.inf1.xlarge", inference_tool="neuron")
273273
assert container["Image"]
274274

275275

@@ -289,7 +289,7 @@ def test_huggingface_neuronx(
289289
pytorch_version=huggingface_neuronx_latest_inference_pytorch_version,
290290
py_version=huggingface_neuronx_latest_inference_py_version,
291291
)
292-
container = huggingface_model.prepare_container_def("ml.inf2.xlarge")
292+
container = huggingface_model.prepare_container_def("ml.inf2.xlarge", inference_tool="neuronx")
293293
assert container["Image"]
294294

295295

0 commit comments

Comments
 (0)