Skip to content

Commit 70c9d4f

Browse files
author
Xiong Zeng
committed
Add model_metadata field to ModelBuilder
1 parent 86b6da6 commit 70c9d4f

File tree

3 files changed

+29
-18
lines changed

3 files changed

+29
-18
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
118118
into a stream. All translations between the server and the client are handled
119119
automatically with the specified input and output.
120120
model (Optional[Union[object, str]): Model object (with ``predict`` method to perform
121-
inference) or a HuggingFace/JumpStart Model ID (followed by ``:task`` if you need
122-
to override the task, e.g. bert-base-uncased:fill-mask). Either ``model`` or
123-
``inference_spec`` is required for the model builder to build the artifact.
121+
inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec``
122+
is required for the model builder to build the artifact.
124123
inference_spec (InferenceSpec): The inference spec file with your customized
125124
``invoke`` and ``load`` functions.
126125
image_uri (Optional[str]): The container image uri (which is derived from a
@@ -140,6 +139,8 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
140139
to the model server). Possible values for this argument are
141140
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
142141
``TRITON``, and``TGI``.
142+
model_metadata (Optional[Dict[str, str]): Dictionary used to override the HuggingFace
143+
model metadata.
143144
"""
144145

145146
model_path: Optional[str] = field(
@@ -206,7 +207,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
206207
"help": (
207208
'Model object with "predict" method to perform inference '
208209
"or HuggingFace/JumpStart Model ID"
209-
"or if you need to override task, provide input as ModelID:Task"
210210
)
211211
},
212212
)
@@ -237,6 +237,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
237237
model_server: Optional[ModelServer] = field(
238238
default=None, metadata={"help": "Define the model server to deploy to."}
239239
)
240+
model_metadata: Optional[Dict[str, str]] = field(
241+
default=None, metadata={"help": "Define the model metadata to override"}
242+
)
240243

241244
def _build_validations(self):
242245
"""Placeholder docstring"""
@@ -613,9 +616,8 @@ def build(
613616

614617
if isinstance(self.model, str):
615618
model_task = None
616-
if ":" in self.model:
617-
model_task = self.model.split(":")[1]
618-
self.model = self.model.split(":")[0]
619+
if self.model_metadata:
620+
model_task = self.model_metadata.get("HF_TASK")
619621
if self._is_jumpstart_model_id():
620622
return self._build_for_jumpstart()
621623
if self._is_djl(): # pylint: disable=R1705

tests/integ/sagemaker/serve/test_schema_builder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_model_builder_negative_path(sagemaker_session):
115115
def test_model_builder_happy_path_with_task_provided(
116116
model_id, task_provided, sagemaker_session, gpu_instance_type
117117
):
118-
model_builder = ModelBuilder(model=f"{model_id}:{task_provided}")
118+
model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": task_provided})
119119

120120
model = model_builder.build(sagemaker_session=sagemaker_session)
121121

@@ -156,7 +156,9 @@ def test_model_builder_happy_path_with_task_provided(
156156

157157

158158
def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
159-
model_builder = ModelBuilder(model="bert-base-uncased:invalid-task")
159+
model_builder = ModelBuilder(
160+
model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"}
161+
)
160162

161163
with pytest.raises(
162164
TaskNotFoundException,

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,7 +1114,9 @@ def test_build_happy_path_override_with_task_provided(
11141114

11151115
mock_image_uris_retrieve.return_value = "https://some-image-uri"
11161116

1117-
model_builder = ModelBuilder(model="bert-base-uncased:text-generation")
1117+
model_builder = ModelBuilder(
1118+
model="bert-base-uncased", model_metadata={"HF_TASK": "text-generation"}
1119+
)
11181120
model_builder.build(sagemaker_session=mock_session)
11191121

11201122
self.assertIsNotNone(model_builder.schema_builder)
@@ -1157,11 +1159,14 @@ def test_build_task_override_with_invalid_task_provided(
11571159
mock_model_urllib.request.Request.side_effect = Mock()
11581160

11591161
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1160-
model_ids_with_invalid_task = ["bert-base-uncased:invalid-task", "bert-base-uncased:"]
1162+
model_ids_with_invalid_task = {
1163+
"bert-base-uncased": "invalid-task",
1164+
"bert-large-uncased-whole-word-masking-finetuned-squad": "",
1165+
}
11611166
for model_id in model_ids_with_invalid_task:
1162-
model_builder = ModelBuilder(model=model_id)
1167+
provided_task = model_ids_with_invalid_task[model_id]
1168+
model_builder = ModelBuilder(model=model_id, model_metadata={"HF_TASK": provided_task})
11631169

1164-
provided_task = model_id.split(":")[1]
11651170
self.assertRaisesRegex(
11661171
TaskNotFoundException,
11671172
f"Error Message: Schema builder for {provided_task} could not be found.",
@@ -1187,9 +1192,11 @@ def test_build_task_override_with_invalid_model_provided(
11871192
mock_model_uris_retrieve.side_effect = KeyError
11881193

11891194
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1190-
invalid_model_ids_with_task = [":fill-mask", "bert-base-uncased;fill-mask"]
1195+
invalid_model_id = ""
1196+
provided_task = "fill-mask"
11911197

1192-
for model_id in invalid_model_ids_with_task:
1193-
model_builder = ModelBuilder(model=model_id)
1194-
with self.assertRaises(Exception):
1195-
model_builder.build(sagemaker_session=mock_session)
1198+
model_builder = ModelBuilder(
1199+
model=invalid_model_id, model_metadata={"HF_TASK": provided_task}
1200+
)
1201+
with self.assertRaises(Exception):
1202+
model_builder.build(sagemaker_session=mock_session)

0 commit comments

Comments
 (0)