Skip to content

Commit 776e006

Browse files
author
Jonathan Makunga
committed
Add TEI Serving
1 parent c9b55a4 commit 776e006

File tree

7 files changed

+168
-4
lines changed

7 files changed

+168
-4
lines changed

src/sagemaker/serve/builder/tei_builder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,16 @@ def _prepare_for_mode(self):
7474
def _get_client_translators(self):
7575
"""Placeholder docstring"""
7676

77-
def _set_to_tei(self):
    """Force the model server to TEI, warning if another server was configured.

    HuggingFace Model ID deployment in this builder only supports the TEI
    model server; any other configured ``self.model_server`` is overridden.
    """
    if self.model_server != ModelServer.TEI:
        messaging = (
            "HuggingFace Model ID support on model server: "
            f"{self.model_server} is not currently supported. "
            f"Defaulting to {ModelServer.TEI}"
        )
        logger.warning(messaging)
        # Override the unsupported choice rather than failing the build.
        self.model_server = ModelServer.TEI
8787

8888
def _create_tei_model(self, **kwargs) -> Type[Model]:
8989
"""Placeholder docstring"""

src/sagemaker/serve/mode/local_container_mode.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
2222
from sagemaker.serve.model_server.triton.server import LocalTritonServer
2323
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
24+
from sagemaker.serve.model_server.tei.server import LocalTeiServing
2425
from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer
2526
from sagemaker.session import Session
2627

@@ -41,6 +42,7 @@ class LocalContainerMode(
4142
LocalTgiServing,
4243
LocalMultiModelServer,
4344
LocalTensorflowServing,
45+
LocalTeiServing,
4446
):
4547
"""A class that holds methods to deploy model to a container in local environment"""
4648

src/sagemaker/serve/model_server/tei/__init__.py

Whitespace-only changes.

src/sagemaker/serve/model_server/tei/prepare.py

Whitespace-only changes.
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""Module for Local TEI Serving"""
2+
3+
from __future__ import absolute_import
4+
5+
import requests
6+
import logging
7+
from pathlib import Path
8+
from docker.types import DeviceRequest
9+
from sagemaker import Session, fw_utils
10+
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
11+
from sagemaker.base_predictor import PredictorBase
12+
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join
13+
from sagemaker.s3 import S3Uploader
14+
from sagemaker.local.utils import get_docker_host
15+
16+
17+
# Container-side mount point for the model artifacts volume.
MODE_DIR_BINDING = "/opt/ml/model/"
# Shared-memory size handed to the container (large tensors use /dev/shm).
_SHM_SIZE = "2G"
# Point the HuggingFace caches at the mounted model directory so downloads
# land on the host-mapped volume instead of inside the container layer.
_DEFAULT_ENV_VARS = {
    "TRANSFORMERS_CACHE": "/opt/ml/model/",
    "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
}

logger = logging.getLogger(__name__)
25+
26+
27+
class LocalTeiServing:
    """Run and invoke a Text Embeddings Inference (TEI) container locally."""

    def _start_tei_serving(
        self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
    ):
        """Start a local TEI serving container.

        Args:
            client: Docker client used to launch the container.
            image: Container image to run.
            model_path: Local path whose ``code`` subdirectory is mounted
                read-write into the container at ``MODE_DIR_BINDING``.
            secret_key: Secret injected as ``SAGEMAKER_SERVE_SECRET_KEY``
                (only when ``env_vars`` is also truthy, matching prior behavior).
            env_vars: Extra environment variables merged over the defaults.
        """
        if env_vars and secret_key:
            env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

        self.container = client.containers.run(
            image,
            shm_size=_SHM_SIZE,
            # count=-1 requests all available GPUs for the container.
            device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
            network_mode="host",
            detach=True,
            auto_remove=True,
            volumes={
                Path(model_path).joinpath("code"): {
                    "bind": MODE_DIR_BINDING,
                    "mode": "rw",
                },
            },
            environment=_update_env_vars(env_vars),
        )

    def _invoke_tei_serving(self, request: object, content_type: str, accept: str):
        """POST one invocation request to the local TEI container.

        Args:
            request: Raw request payload to send.
            content_type: Value for the ``Content-Type`` header.
            accept: Value for the ``Accept`` header.

        Returns:
            The raw response body bytes.

        Raises:
            Exception: If the request fails or the server returns an
                error status; the original error is chained as the cause.
        """
        try:
            response = requests.post(
                f"http://{get_docker_host()}:8080/invocations",
                data=request,
                headers={"Content-Type": content_type, "Accept": accept},
                timeout=600,
            )
            response.raise_for_status()
            return response.content
        except Exception as e:
            # Chain the underlying error so callers can see the root cause.
            raise Exception("Unable to send request to the local container server") from e

    def _tei_deep_ping(self, predictor: PredictorBase):
        """Check that the local TEI container answers a sample prediction.

        Args:
            predictor: Predictor wired to the local endpoint.

        Returns:
            Tuple ``(healthy, response)`` — ``(True, response)`` when the
            sample prediction succeeds, ``(False, None)`` when it fails for
            a non-422 reason.

        Raises:
            LocalModelInvocationException: When the server rejects the
                sample input with a 422 Unprocessable Entity error.
        """
        response = None
        try:
            response = predictor.predict(self.schema_builder.sample_input)
            return (True, response)
        # pylint: disable=broad-except
        except Exception as e:
            if "422 Client Error: Unprocessable Entity for url" in str(e):
                raise LocalModelInvocationException(str(e)) from e
            return (False, response)
        # NOTE: the original had an unreachable trailing `return (True, response)`
        # after this try/except (both paths already return); it has been removed.
97+
98+
99+
class SageMakerTeiServing:
    """Prepare TEI model artifacts for SageMaker-hosted deployment."""

    def _upload_tei_artifacts(
        self,
        model_path: str,
        sagemaker_session: Session,
        s3_model_data_url: str = None,
        image: str = None,
        env_vars: dict = None,
    ):
        """Upload the model's ``code`` directory to S3, uncompressed.

        Args:
            model_path: Local path containing a ``code`` subdirectory.
            sagemaker_session: Session used to resolve the default bucket
                and perform the upload.
            s3_model_data_url: Optional S3 URL whose bucket/prefix override
                the session defaults.
            image: Container image URI; used to derive the S3 key prefix.
            env_vars: Extra environment variables merged over the defaults.

        Returns:
            Tuple of (``model_data`` dict pointing at the uploaded S3
            prefix as an uncompressed ``S3Prefix`` data source, merged
            environment variables).
        """
        if s3_model_data_url:
            bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
        else:
            bucket, key_prefix = None, None

        code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)

        bucket, code_key_prefix = determine_bucket_and_prefix(
            bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
        )

        code_dir = Path(model_path).joinpath("code")

        s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code")

        # Fixed copy-paste from the TGI server: this module uploads TEI resources.
        logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location)

        model_data_url = S3Uploader.upload(
            str(code_dir),
            s3_location,
            None,
            sagemaker_session,
        )

        # Trailing "/" marks the URI as a prefix rather than a single object.
        model_data = {
            "S3DataSource": {
                "CompressionType": "None",
                "S3DataType": "S3Prefix",
                "S3Uri": model_data_url + "/",
            }
        }

        return (model_data, _update_env_vars(env_vars))
152+
153+
154+
def _update_env_vars(env_vars: dict) -> dict:
    """Merge caller-supplied environment variables over the TEI defaults.

    Caller values win on key collisions; neither input dict is mutated.
    """
    merged = dict(_DEFAULT_ENV_VARS)
    if env_vars:
        merged.update(env_vars)
    return merged
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""TEI ModelBuilder Utils"""

src/sagemaker/serve/utils/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __str__(self):
1818
DJL_SERVING = 4
1919
TRITON = 5
2020
TGI = 6
21+
TEI = 7
2122

2223

2324
class _DjlEngine(Enum):

0 commit comments

Comments
 (0)