Skip to content

Commit d95235f

Browse files
committed
Updating the input to RegisterModelStep to Pipeline Model
1 parent 5261a2f commit d95235f

File tree

4 files changed

+20
-10
lines changed

4 files changed

+20
-10
lines changed

src/sagemaker/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2522,7 +2522,8 @@ def _create_model_request(
25222522
if isinstance(container_definition, list):
25232523
request["Containers"] = container_definition
25242524
else:
2525-
request["PrimaryContainer"] = container_definition
2525+
request["Containers"] = [container_definition]
2526+
25262527
if tags:
25272528
request["Tags"] = tags
25282529

src/sagemaker/workflow/step_collections.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
compile_model_family=None,
6969
description=None,
7070
tags=None,
71-
models=None,
71+
pipeline_model=None,
7272
**kwargs,
7373
):
7474
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -100,7 +100,8 @@ def __init__(
100100
that tags will only be applied to newly created model package groups; if the
101101
name of an existing group is passed to "model_package_group_name",
102102
tags will not be applied.
103-
models (list): A list of models.
103+
pipeline_model (object): A PipelineModel object that comprises a list of models
104+
which gets executed as a serial inference pipeline.
104105
**kwargs: additional arguments to `create_model`.
105106
"""
106107
steps: List[Step] = []
@@ -132,14 +133,14 @@ def __init__(
132133
kwargs.pop("dependencies", None)
133134
kwargs.pop("output_kms_key", None)
134135

135-
if models is not None:
136-
for model in models:
136+
if pipeline_model is not None:
137+
for model in pipeline_model.models:
137138
if estimator is not None:
138139
sagemaker_session = estimator.sagemaker_session
139140
role = estimator.role
140141
else:
141142
sagemaker_session = model.sagemaker_session
142-
role = model.role
143+
role = pipeline_model.role
143144
if hasattr(model, "entry_point"):
144145
repack_model = True
145146
entry_point = model.entry_point
@@ -174,7 +175,7 @@ def __init__(
174175
compile_model_family=compile_model_family,
175176
description=description,
176177
tags=tags,
177-
model_list=models,
178+
model_list=pipeline_model.models,
178179
**kwargs,
179180
)
180181
if not repack_model:

tests/integ/test_workflow.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from sagemaker import image_uris
3737
from sagemaker.inputs import CreateModelInput, TrainingInput
3838
from sagemaker.model import Model
39+
from sagemaker import PipelineModel
3940
from sagemaker.processing import ProcessingInput, ProcessingOutput, FeatureStoreOutput
4041
from sagemaker.pytorch.estimator import PyTorch
4142
from sagemaker.tuner import HyperparameterTuner, IntegerParameter
@@ -1018,10 +1019,12 @@ def test_mxnet_model_registration(
10181019
py_version="py3",
10191020
sagemaker_session=sagemaker_session,
10201021
)
1022+
1023+
pipeline_model = PipelineModel([model],role)
10211024

10221025
step_register = RegisterModel(
10231026
name="mxnet-register-model",
1024-
models=[model],
1027+
pipeline_model=pipeline_model,
10251028
content_types=["*"],
10261029
response_types=["*"],
10271030
inference_instances=["ml.m5.xlarge"],
@@ -1202,10 +1205,12 @@ def test_sklearn_xgboost_sip_model_registration(
12021205
role=role,
12031206
sagemaker_session=sagemaker_session,
12041207
)
1208+
1209+
pipeline_model = PipelineModel([xgboost_model, sklearn_model],role)
12051210

12061211
step_register = RegisterModel(
12071212
name="AbaloneRegisterModel",
1208-
models=[xgboost_model, sklearn_model],
1213+
pipeline_model=pipeline_model,
12091214
content_types=["application/json"],
12101215
response_types=["application/json"],
12111216
inference_instances=["ml.t2.medium", "ml.m5.xlarge"],

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
PropertyMock,
2828
)
2929

30+
from sagemaker import PipelineModel
3031
from sagemaker.estimator import Estimator
3132
from sagemaker.model import Model
3233
from sagemaker.tensorflow import TensorFlow
@@ -279,6 +280,8 @@ def test_register_model_sip(estimator, model_metrics):
279280
Model(image_uri="fakeimage1", model_data="Url1", env=[{"k1": "v1"}, {"k2": "v2"}]),
280281
Model(image_uri="fakeimage2", model_data="Url2", env=[{"k3": "v3"}, {"k4": "v4"}]),
281282
]
283+
284+
pipeline_model = PipelineModel(model_list, ROLE)
282285

283286
register_model = RegisterModel(
284287
name="RegisterModelStep",
@@ -291,7 +294,7 @@ def test_register_model_sip(estimator, model_metrics):
291294
model_metrics=model_metrics,
292295
approval_status="Approved",
293296
description="description",
294-
models=model_list,
297+
pipeline_model=pipeline_model,
295298
depends_on=["TestStep"],
296299
)
297300
assert ordered(register_model.request_dicts()) == ordered(

0 commit comments

Comments
 (0)