Skip to content

Commit 388ea2a

Browse files
author
Jonathan Makunga
committed
Add TEI Serving
1 parent 776e006 commit 388ea2a

File tree

3 files changed: +54 −2 lines changed

src/sagemaker/serve/builder/tei_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
_get_nb_instance,
2626
)
2727
from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
28-
from sagemaker.serve.utils.predictors import TgiLocalModePredictor
28+
from sagemaker.serve.utils.predictors import TeiLocalModePredictor
2929
from sagemaker.serve.utils.types import ModelServer
3030
from sagemaker.serve.mode.function_pointers import Mode
3131
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
@@ -142,7 +142,7 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
142142
if self.mode == Mode.LOCAL_CONTAINER:
143143
timeout = kwargs.get("model_data_download_timeout")
144144

145-
predictor = TgiLocalModePredictor(
145+
predictor = TeiLocalModePredictor(
146146
self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
147147
)
148148

src/sagemaker/serve/mode/local_container_mode.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,15 @@ def create_server(
158158
env_vars=env_vars if env_vars else self.env_vars,
159159
)
160160
self._ping_container = self._tensorflow_serving_deep_ping
161+
elif self.model_server == ModelServer.TEI:
162+
self._start_tei_serving(
163+
client=self.client,
164+
image=image,
165+
model_path=model_path if model_path else self.model_path,
166+
secret_key=secret_key,
167+
env_vars=env_vars if env_vars else self.env_vars,
168+
)
169+
self._ping_container = self._tei_deep_ping
161170

162171
# allow some time for container to be ready
163172
time.sleep(10)

src/sagemaker/serve/utils/predictors.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,49 @@ def delete_predictor(self):
209209
self._mode_obj.destroy_server()
210210

211211

212+
class TeiLocalModePredictor(PredictorBase):
    """Lightweight Tei predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes"""

    def __init__(
        self,
        mode_obj: Type[LocalContainerMode],
        serializer=JSONSerializer(),
        deserializer=JSONDeserializer(),
    ):
        # Hold onto the local-container mode object so predict()/delete can
        # reach the running TEI container through it.
        self._mode_obj = mode_obj
        self.serializer = serializer
        self.deserializer = deserializer

    def predict(self, data):
        """Serialize *data*, invoke the local TEI server, and deserialize the response.

        Returns a single-element list containing the deserialized payload.
        """
        request_body = self.serializer.serialize(data)
        raw_response = self._mode_obj._invoke_tei_serving(
            request_body,
            self.content_type,
            self.deserializer.ACCEPT[0],
        )
        # Wrap the raw bytes in a stream so the deserializer can consume it.
        deserialized = self.deserializer.deserialize(
            io.BytesIO(raw_response),
            self.content_type,
        )
        return [deserialized]

    @property
    def content_type(self):
        """The MIME type of the data sent to the inference endpoint."""
        return self.serializer.CONTENT_TYPE

    @property
    def accept(self):
        """The content type(s) that are expected from the inference endpoint."""
        return self.deserializer.ACCEPT

    def delete_predictor(self):
        """Shut down and remove the container that you created in LOCAL_CONTAINER mode"""
        self._mode_obj.destroy_server()
254+
212255
class TensorflowServingLocalPredictor(PredictorBase):
213256
"""Lightweight predictor for local deployment in LOCAL_CONTAINER modes"""
214257

0 commit comments

Comments
 (0)