Skip to content

Commit 81ca3a0

Browse files
SSRraymond and Raymond Liu authored
feature: support model.register() with triton model (#4305)
Co-authored-by: Raymond Liu <[email protected]>
1 parent 99afcc8 commit 81ca3a0

File tree

3 files changed

+17
-0
lines changed

3 files changed

+17
-0
lines changed

src/sagemaker/serve/marshalling/triton_translator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ def __init__(self) -> None:
1616
import torch
1717

1818
self.convert_from_numpy = torch.from_numpy # pylint: disable=E1101
19+
self.CONTENT_TYPE = "tensor/pt"
20+
self.ACCEPT = "tensor/pt"
1921

2022
def serialize(self, data, content_type: str = "tensor/pt"):
2123
"""Translate torch.Tensor to numpy ndarray"""
@@ -45,6 +47,8 @@ def __init__(self) -> None:
4547
import tensorflow as tf
4648

4749
self.convert_to_tensor = tf.convert_to_tensor
50+
self.CONTENT_TYPE = "tensor/tf"
51+
self.ACCEPT = "tensor/tf"
4852

4953
def serialize(self, data, content_type: str = "tensor/tf"):
5054
"""Translate tf.Tensor to numpy ndarray"""
@@ -70,6 +74,10 @@ def _deserializer(self):
7074
class NumpyTranslator:
7175
"""A dummy class to make sure the translator interface is aligned"""
7276

77+
def __init__(self) -> None:
78+
self.CONTENT_TYPE = "application/x-npy"
79+
self.ACCEPT = "application/x-npy"
80+
7381
def serialize(self, data, content_type: str = "application/x-npy"):
7482
"""Placeholder docstring"""
7583
return data
@@ -86,6 +94,10 @@ def _deserializer(self):
8694
class ListTranslator:
8795
"""Translate python list from and to numpy.ndarray"""
8896

97+
def __init__(self) -> None:
98+
self.CONTENT_TYPE = "application/list"
99+
self.ACCEPT = "application/list"
100+
89101
def serialize(self, data, content_type: str = "application/list"):
90102
"""Placeholder docstring"""
91103
try:

src/sagemaker/serve/model_server/triton/triton_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,9 @@ def _create_triton_model(self) -> Type[Model]:
430430
# unique method to models created via ModelBuilder()
431431
self._original_deploy = self.pysdk_model.deploy
432432
self.pysdk_model.deploy = self._model_builder_deploy_wrapper
433+
self._original_register = self.pysdk_model.register
434+
self.pysdk_model.register = self._model_builder_register_wrapper
435+
self.model_package = None
433436
return self.pysdk_model
434437

435438
def _get_triton_predictor(self, endpoint_name: str, sagemaker_session: Session) -> Predictor:

tests/unit/sagemaker/serve/model_server/triton/test_triton_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MOCK_SESSION = Mock()
3434
MOCK_MODES = Mock()
3535
MOCK_DEPLOY_WRAPPER = Mock()
36+
MOCK_RESIGTER_WRAPPER = Mock()
3637

3738

3839
class pytorch:
@@ -56,6 +57,7 @@ def prepare_triton_builder_for_model(self, triton_builder: Triton) -> Triton:
5657
triton_builder.sagemaker_session = MOCK_SESSION
5758
triton_builder.modes = MOCK_MODES
5859
triton_builder._model_builder_deploy_wrapper = MOCK_DEPLOY_WRAPPER
60+
triton_builder._model_builder_register_wrapper = MOCK_RESIGTER_WRAPPER
5961
triton_builder.inference_spec = None
6062

6163
mock_export = Mock()

0 commit comments

Comments
 (0)