
Commit af5fde2

Refactoring

Author: Jonathan Makunga
1 parent 2044a7e commit af5fde2

File tree: 3 files changed (+8, -6 lines)


src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
             in order for model builder to build the artifacts correctly (according
             to the model server). Possible values for this argument are
             ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
-            ``TRITON``, and ``TGI``.
+            ``TRITON``, ``TGI``, and ``TEI``.
             model_metadata (Optional[Dict[str, Any]): Dictionary used to override model metadata.
             Currently, ``HF_TASK`` is overridable for HuggingFace model. HF_TASK should be set for
             new models without task metadata in the Hub, adding unsupported task types will throw
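
The docstring change above adds ``TEI`` to the list of accepted ``model_server`` values. A minimal usage sketch follows; the import paths, the model id, and the sample payloads are assumptions for illustration, not part of this commit.

# Usage sketch only: selects the TEI model server through the model_server
# argument documented above. Import paths, the model id, and the sample
# payloads are assumptions, not taken from this commit.
from sagemaker.serve.builder.model_builder import ModelBuilder    # assumed import path
from sagemaker.serve.builder.schema_builder import SchemaBuilder  # assumed import path
from sagemaker.serve.utils.types import ModelServer               # assumed import path

schema_builder = SchemaBuilder(
    sample_input={"inputs": "What is the capital of France?"},
    sample_output=[[0.01, 0.02, 0.03]],  # embedding-style output for a TEI model
)

model_builder = ModelBuilder(
    model="BAAI/bge-base-en-v1.5",   # hypothetical Hugging Face model id
    schema_builder=schema_builder,
    model_server=ModelServer.TEI,    # one of the documented values above
)
model = model_builder.build()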

src/sagemaker/serve/mode/local_container_mode.py

Lines changed: 4 additions & 3 deletions
@@ -42,7 +42,6 @@ class LocalContainerMode(
     LocalTgiServing,
     LocalMultiModelServer,
     LocalTensorflowServing,
-    LocalTeiServing,
 ):
     """A class that holds methods to deploy model to a container in local environment"""
 
@@ -72,6 +71,8 @@ def __init__(
         self.secret_key = None
         self._ping_container = None
 
+        self._tei_serving = LocalTeiServing()
+
     def load(self, model_path: str = None):
         """Placeholder docstring"""
         path = Path(model_path if model_path else self.model_path)
@@ -159,14 +160,14 @@ def create_server(
             )
             self._ping_container = self._tensorflow_serving_deep_ping
         elif self.model_server == ModelServer.TEI:
-            self._start_tei_serving(
+            self._tei_serving._start_tei_serving(
                 client=self.client,
                 image=image,
                 model_path=model_path if model_path else self.model_path,
                 secret_key=secret_key,
                 env_vars=env_vars if env_vars else self.env_vars,
             )
-            self._ping_container = self._tei_deep_ping
+            self._ping_container = self._tei_serving._tei_deep_ping
 
         # allow some time for container to be ready
         time.sleep(10)
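
The hunks above replace mixin inheritance (``LocalTeiServing`` as a base class) with composition: the mode now owns a ``LocalTeiServing`` instance and delegates to it. A simplified sketch of that pattern, using stand-in class bodies rather than the real implementations:

# Simplified sketch of the refactor in this file: TEI serving moves from a
# mixin base class to a composed helper object. These classes are illustrative
# stand-ins, not the real LocalTeiServing / LocalContainerMode implementations.

class LocalTeiServing:  # stand-in for the real helper class
    def _start_tei_serving(self, client, image, model_path, secret_key, env_vars):
        print(f"starting TEI container from {image} with model at {model_path}")

    def _tei_deep_ping(self):
        return True  # pretend the container answered the health check


class LocalContainerMode:  # no longer lists LocalTeiServing as a base class
    def __init__(self):
        self._ping_container = None
        self._tei_serving = LocalTeiServing()  # composition instead of inheritance

    def create_server(self, client, image, model_path, secret_key, env_vars):
        # Delegate to the composed helper instead of calling inherited methods.
        self._tei_serving._start_tei_serving(
            client=client,
            image=image,
            model_path=model_path,
            secret_key=secret_key,
            env_vars=env_vars,
        )
        self._ping_container = self._tei_serving._tei_deep_ping


# Exercising the sketch:
mode = LocalContainerMode()
mode.create_server(client=None, image="tei-image:latest", model_path="/tmp/model",
                   secret_key="secret", env_vars={})
assert mode._ping_container() is True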

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py

Lines changed: 3 additions & 2 deletions
@@ -27,7 +27,6 @@ class SageMakerEndpointMode(
     SageMakerTgiServing,
     SageMakerMultiModelServer,
     SageMakerTensorflowServing,
-    SageMakerTeiServing,
 ):
     """Holds the required method to deploy a model to a SageMaker Endpoint"""
 
@@ -39,6 +38,8 @@ def __init__(self, inference_spec: Type[InferenceSpec], model_server: ModelServe
         self.inference_spec = inference_spec
         self.model_server = model_server
 
+        self._tei_serving = SageMakerTeiServing()
+
     def load(self, model_path: str):
         """Placeholder docstring"""
         path = Path(model_path)
@@ -122,7 +123,7 @@ def prepare(
             )
 
         if self.model_server == ModelServer.TEI:
-            upload_artifacts = self._upload_tei_artifacts(
+            upload_artifacts = self._tei_serving._upload_tei_artifacts(
                 model_path=model_path,
                 sagemaker_session=sagemaker_session,
                 s3_model_data_url=s3_model_data_url,
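
The same composition pattern is applied on the endpoint path: ``SageMakerEndpointMode`` drops ``SageMakerTeiServing`` from its bases and delegates artifact upload to a composed instance. A stand-in sketch with hypothetical simplified bodies and a made-up return value:

# Same composition pattern, sketched for the endpoint path. The class and
# argument names come from the diff above; the bodies and return value are
# hypothetical simplifications.

class SageMakerTeiServing:  # stand-in for the real helper class
    def _upload_tei_artifacts(self, model_path, sagemaker_session, s3_model_data_url):
        # Pretend to upload model artifacts and return an (s3_url, env_vars) pair.
        return s3_model_data_url, {"MODEL_PATH": model_path}


class SageMakerEndpointMode:  # no longer lists SageMakerTeiServing as a base class
    def __init__(self, inference_spec=None, model_server=None):
        self.inference_spec = inference_spec
        self.model_server = model_server
        self._tei_serving = SageMakerTeiServing()  # composed helper

    def prepare(self, model_path, sagemaker_session, s3_model_data_url):
        # Delegate artifact upload to the composed helper.
        return self._tei_serving._upload_tei_artifacts(
            model_path=model_path,
            sagemaker_session=sagemaker_session,
            s3_model_data_url=s3_model_data_url,
        )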
