Skip to content

Commit 0a0ce3c

Browse files
committed
feature: Adding serial inference pipeline support to RegisterModel Step
1 parent c1aa201 commit 0a0ce3c

File tree

4 files changed

+79
-5
lines changed

4 files changed

+79
-5
lines changed

src/sagemaker/model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def _get_model_package_args(
195195
marketplace_cert=False,
196196
approval_status=None,
197197
description=None,
198+
container_def_list=None,
198199
):
199200
"""Get arguments for session.create_model_package method.
200201
@@ -219,6 +220,7 @@ def _get_model_package_args(
219220
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
220221
or "PendingManualApproval" (default: "PendingManualApproval").
221222
description (str): Model Package description (default: None).
223+
container_def_list (list): A list of container definitions (default: None).
222224
Returns:
223225
dict: A dictionary of method argument names and values.
224226
"""
@@ -229,8 +231,13 @@ def _get_model_package_args(
229231
"ModelDataUrl": self.model_data,
230232
}
231233

234+
if container_def_list is not None:
235+
containers = container_def_list
236+
else:
237+
containers = [container]
238+
232239
model_package_args = {
233-
"containers": [container],
240+
"containers": containers,
234241
"content_types": content_types,
235242
"response_types": response_types,
236243
"inference_instances": inference_instances,

src/sagemaker/workflow/_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,18 +205,19 @@ class _RegisterModelStep(Step):
205205
Estimator's training container image will be used (default: None).
206206
compile_model_family (str): Instance family for compiled model, if specified, a compiled
207207
model will be used (default: None).
208+
container_def_list (list): A list of container definitions (default: None).
208209
**kwargs: additional arguments to `create_model`.
209210
"""
210211

211212
def __init__(
212213
self,
213214
name: str,
214215
estimator: EstimatorBase,
215-
model_data,
216216
content_types,
217217
response_types,
218218
inference_instances,
219219
transform_instances,
220+
model_data=None,
220221
model_package_group_name=None,
221222
model_metrics=None,
222223
metadata_properties=None,
@@ -225,6 +226,7 @@ def __init__(
225226
compile_model_family=None,
226227
description=None,
227228
depends_on: List[str] = None,
229+
container_def_list=None,
228230
**kwargs,
229231
):
230232
"""Constructor of a register model step.
@@ -254,6 +256,7 @@ def __init__(
254256
description (str): Model Package description (default: None).
255257
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
256258
depends on
259+
container_def_list (list): A list of container definitions (default: None).
257260
**kwargs: additional arguments to `create_model`.
258261
"""
259262
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL, depends_on)
@@ -270,6 +273,7 @@ def __init__(
270273
self.image_uri = image_uri
271274
self.compile_model_family = compile_model_family
272275
self.description = description
276+
self.container_def_list = container_def_list
273277
self.kwargs = kwargs
274278

275279
self._properties = Properties(
@@ -301,7 +305,7 @@ def arguments(self) -> RequestType:
301305
self.estimator.output_path = output_path
302306

303307
# yeah, there is some framework stuff going on that we need to pull in here
304-
if model.image_uri is None:
308+
if model.image_uri is None and model.container_def_list is None:
305309
region_name = self.estimator.sagemaker_session.boto_session.region_name
306310
model.image_uri = image_uris.retrieve(
307311
model._framework_name,
@@ -324,6 +328,7 @@ def arguments(self) -> RequestType:
324328
metadata_properties=self.metadata_properties,
325329
approval_status=self.approval_status,
326330
description=self.description,
331+
container_def_list=self.container_def_list
327332
)
328333
request_dict = model.sagemaker_session._get_create_model_package_request(
329334
**model_package_args

src/sagemaker/workflow/step_collections.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,19 @@ def __init__(
5555
self,
5656
name: str,
5757
estimator: EstimatorBase,
58-
model_data,
5958
content_types,
6059
response_types,
6160
inference_instances,
6261
transform_instances,
62+
model_data=None,
6363
depends_on: List[str] = None,
6464
model_package_group_name=None,
6565
model_metrics=None,
6666
approval_status=None,
6767
image_uri=None,
6868
compile_model_family=None,
6969
description=None,
70+
container_def_list=None,
7071
**kwargs,
7172
):
7273
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -94,6 +95,7 @@ def __init__(
9495
compile_model_family (str): The instance family for the compiled model. If
9596
specified, a compiled model is used (default: None).
9697
description (str): Model Package description (default: None).
98+
container_def_list (list): A list of container definitions (default: None).
9799
**kwargs: additional arguments to `create_model`.
98100
"""
99101
steps: List[Step] = []
@@ -134,6 +136,7 @@ def __init__(
134136
image_uri=image_uri,
135137
compile_model_family=compile_model_family,
136138
description=description,
139+
container_def_list=container_def_list,
137140
**kwargs,
138141
)
139142
if not repack_model:

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,6 @@ def test_register_model(estimator, model_metrics):
215215
]
216216
)
217217

218-
219218
def test_register_model_tf(estimator_tf, model_metrics):
220219
model_data = f"s3://{BUCKET}/model.tar.gz"
221220
register_model = RegisterModel(
@@ -265,6 +264,66 @@ def test_register_model_tf(estimator_tf, model_metrics):
265264
]
266265
)
267266

267+
def test_register_model_sip(estimator, model_metrics):
268+
container_def_list = [
269+
{
270+
"Image":"fakeimage1", "ModelDataUrl":"Url1",
271+
"Environment": [{"k1": "v1"}, {"k2": "v2"}]
272+
},
273+
{
274+
"Image":"fakeimage2", "ModelDataUrl":"Url2",
275+
"Environment": [{"k3": "v3"}, {"k4": "v4"}]
276+
}
277+
]
278+
279+
register_model = RegisterModel(
280+
name="RegisterModelStep",
281+
estimator=estimator,
282+
content_types=["content_type"],
283+
response_types=["response_type"],
284+
inference_instances=["inference_instance"],
285+
transform_instances=["transform_instance"],
286+
model_package_group_name="mpg",
287+
model_metrics=model_metrics,
288+
approval_status="Approved",
289+
description="description",
290+
container_def_list=container_def_list,
291+
depends_on=["TestStep"],
292+
)
293+
assert ordered(register_model.request_dicts()) == ordered(
294+
[
295+
{
296+
"Name": "RegisterModelStep",
297+
"Type": "RegisterModel",
298+
"DependsOn": ["TestStep"],
299+
"Arguments": {
300+
"InferenceSpecification": {
301+
"Containers": [
302+
{"Image": "fakeimage1", "ModelDataUrl": "Url1",
303+
"Environment":[{"k1":"v1"},{"k2": "v2"}] },
304+
{"Image": "fakeimage2", "ModelDataUrl": "Url2",
305+
"Environment":[{"k3":"v3"},{"k4": "v4"}] }
306+
],
307+
"SupportedContentTypes": ["content_type"],
308+
"SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
309+
"SupportedResponseMIMETypes": ["response_type"],
310+
"SupportedTransformInstanceTypes": ["transform_instance"],
311+
},
312+
"ModelApprovalStatus": "Approved",
313+
"ModelMetrics": {
314+
"ModelQuality": {
315+
"Statistics": {
316+
"ContentType": "text/csv",
317+
"S3Uri": f"s3://{BUCKET}/metrics.csv",
318+
},
319+
},
320+
},
321+
"ModelPackageDescription": "description",
322+
"ModelPackageGroupName": "mpg",
323+
},
324+
},
325+
]
326+
)
268327

269328
def test_register_model_with_model_repack(estimator, model_metrics):
270329
model_data = f"s3://{BUCKET}/model.tar.gz"

0 commit comments

Comments
 (0)