Skip to content

Commit 6561856

Browse files
committed
feature: Adding serial inference pipeline support to RegisterModel Step
1 parent c1aa201 commit 6561856

File tree

4 files changed

+89
-4
lines changed

4 files changed

+89
-4
lines changed

src/sagemaker/model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def _get_model_package_args(
195195
marketplace_cert=False,
196196
approval_status=None,
197197
description=None,
198+
container_def_list=None,
198199
):
199200
"""Get arguments for session.create_model_package method.
200201
@@ -219,6 +220,7 @@ def _get_model_package_args(
219220
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
220221
or "PendingManualApproval" (default: "PendingManualApproval").
221222
description (str): Model Package description (default: None).
223+
container_def_list (list): A list of container defintiions.
222224
Returns:
223225
dict: A dictionary of method argument names and values.
224226
"""
@@ -229,8 +231,13 @@ def _get_model_package_args(
229231
"ModelDataUrl": self.model_data,
230232
}
231233

234+
if container_def_list is not None:
235+
containers = container_def_list
236+
else:
237+
containers = [container]
238+
232239
model_package_args = {
233-
"containers": [container],
240+
"containers": containers,
234241
"content_types": content_types,
235242
"response_types": response_types,
236243
"inference_instances": inference_instances,

src/sagemaker/workflow/_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,18 +205,19 @@ class _RegisterModelStep(Step):
205205
Estimator's training container image will be used (default: None).
206206
compile_model_family (str): Instance family for compiled model, if specified, a compiled
207207
model will be used (default: None).
208+
container_def_list (list): A list of container defintiions.
208209
**kwargs: additional arguments to `create_model`.
209210
"""
210211

211212
def __init__(
212213
self,
213214
name: str,
214215
estimator: EstimatorBase,
215-
model_data,
216216
content_types,
217217
response_types,
218218
inference_instances,
219219
transform_instances,
220+
model_data=None,
220221
model_package_group_name=None,
221222
model_metrics=None,
222223
metadata_properties=None,
@@ -225,6 +226,7 @@ def __init__(
225226
compile_model_family=None,
226227
description=None,
227228
depends_on: List[str] = None,
229+
container_def_list=None,
228230
**kwargs,
229231
):
230232
"""Constructor of a register model step.
@@ -254,6 +256,7 @@ def __init__(
254256
description (str): Model Package description (default: None).
255257
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
256258
depends on
259+
container_def_list (list): A list of container defintiions.
257260
**kwargs: additional arguments to `create_model`.
258261
"""
259262
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL, depends_on)
@@ -270,6 +273,7 @@ def __init__(
270273
self.image_uri = image_uri
271274
self.compile_model_family = compile_model_family
272275
self.description = description
276+
self.container_def_list = container_def_list
273277
self.kwargs = kwargs
274278

275279
self._properties = Properties(
@@ -301,7 +305,7 @@ def arguments(self) -> RequestType:
301305
self.estimator.output_path = output_path
302306

303307
# yeah, there is some framework stuff going on that we need to pull in here
304-
if model.image_uri is None:
308+
if model.image_uri is None and self.container_def_list is None:
305309
region_name = self.estimator.sagemaker_session.boto_session.region_name
306310
model.image_uri = image_uris.retrieve(
307311
model._framework_name,
@@ -324,6 +328,7 @@ def arguments(self) -> RequestType:
324328
metadata_properties=self.metadata_properties,
325329
approval_status=self.approval_status,
326330
description=self.description,
331+
container_def_list=self.container_def_list,
327332
)
328333
request_dict = model.sagemaker_session._get_create_model_package_request(
329334
**model_package_args

src/sagemaker/workflow/step_collections.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,19 @@ def __init__(
5555
self,
5656
name: str,
5757
estimator: EstimatorBase,
58-
model_data,
5958
content_types,
6059
response_types,
6160
inference_instances,
6261
transform_instances,
62+
model_data=None,
6363
depends_on: List[str] = None,
6464
model_package_group_name=None,
6565
model_metrics=None,
6666
approval_status=None,
6767
image_uri=None,
6868
compile_model_family=None,
6969
description=None,
70+
container_def_list=None,
7071
**kwargs,
7172
):
7273
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -94,6 +95,7 @@ def __init__(
9495
compile_model_family (str): The instance family for the compiled model. If
9596
specified, a compiled model is used (default: None).
9697
description (str): Model Package description (default: None).
98+
container_def_list (list): A list of container defintiions.
9799
**kwargs: additional arguments to `create_model`.
98100
"""
99101
steps: List[Step] = []
@@ -134,6 +136,7 @@ def __init__(
134136
image_uri=image_uri,
135137
compile_model_family=compile_model_family,
136138
description=description,
139+
container_def_list=container_def_list,
137140
**kwargs,
138141
)
139142
if not repack_model:

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,76 @@ def test_register_model_tf(estimator_tf, model_metrics):
266266
)
267267

268268

269+
def test_register_model_sip(estimator, model_metrics):
270+
container_def_list = [
271+
{
272+
"Image": "fakeimage1",
273+
"ModelDataUrl": "Url1",
274+
"Environment": [{"k1": "v1"}, {"k2": "v2"}],
275+
},
276+
{
277+
"Image": "fakeimage2",
278+
"ModelDataUrl": "Url2",
279+
"Environment": [{"k3": "v3"}, {"k4": "v4"}],
280+
},
281+
]
282+
283+
register_model = RegisterModel(
284+
name="RegisterModelStep",
285+
estimator=estimator,
286+
content_types=["content_type"],
287+
response_types=["response_type"],
288+
inference_instances=["inference_instance"],
289+
transform_instances=["transform_instance"],
290+
model_package_group_name="mpg",
291+
model_metrics=model_metrics,
292+
approval_status="Approved",
293+
description="description",
294+
container_def_list=container_def_list,
295+
depends_on=["TestStep"],
296+
)
297+
assert ordered(register_model.request_dicts()) == ordered(
298+
[
299+
{
300+
"Name": "RegisterModelStep",
301+
"Type": "RegisterModel",
302+
"DependsOn": ["TestStep"],
303+
"Arguments": {
304+
"InferenceSpecification": {
305+
"Containers": [
306+
{
307+
"Image": "fakeimage1",
308+
"ModelDataUrl": "Url1",
309+
"Environment": [{"k1": "v1"}, {"k2": "v2"}],
310+
},
311+
{
312+
"Image": "fakeimage2",
313+
"ModelDataUrl": "Url2",
314+
"Environment": [{"k3": "v3"}, {"k4": "v4"}],
315+
},
316+
],
317+
"SupportedContentTypes": ["content_type"],
318+
"SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
319+
"SupportedResponseMIMETypes": ["response_type"],
320+
"SupportedTransformInstanceTypes": ["transform_instance"],
321+
},
322+
"ModelApprovalStatus": "Approved",
323+
"ModelMetrics": {
324+
"ModelQuality": {
325+
"Statistics": {
326+
"ContentType": "text/csv",
327+
"S3Uri": f"s3://{BUCKET}/metrics.csv",
328+
},
329+
},
330+
},
331+
"ModelPackageDescription": "description",
332+
"ModelPackageGroupName": "mpg",
333+
},
334+
},
335+
]
336+
)
337+
338+
269339
def test_register_model_with_model_repack(estimator, model_metrics):
270340
model_data = f"s3://{BUCKET}/model.tar.gz"
271341
register_model = RegisterModel(

0 commit comments

Comments
 (0)