Skip to content

Commit c2d5e61

Browse files
committed
Updating the input to model or pipeline model
1 parent 6e12e41 commit c2d5e61

File tree

3 files changed

+25
-19
lines changed

3 files changed

+25
-19
lines changed

src/sagemaker/workflow/step_collections.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from sagemaker.estimator import EstimatorBase
2121
from sagemaker.model import Model
22+
from sagemaker import PipelineModel
2223
from sagemaker.predictor import Predictor
2324
from sagemaker.transformer import Transformer
2425
from sagemaker.workflow.entities import RequestType
@@ -68,7 +69,7 @@ def __init__(
6869
compile_model_family=None,
6970
description=None,
7071
tags=None,
71-
pipeline_model=None,
72+
model=None,
7273
**kwargs,
7374
):
7475
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -100,8 +101,8 @@ def __init__(
100101
that tags will only be applied to newly created model package groups; if the
101102
name of an existing group is passed to "model_package_group_name",
102103
tags will not be applied.
103-
pipeline_model (object): A PipelineModel object that comprises a list of models
104-
which gets executed as a serial inference pipeline.
104+
model (object or Model): A PipelineModel object that comprises a list of models
105+
which gets executed as a serial inference pipeline or a Model object.
105106
**kwargs: additional arguments to `create_model`.
106107
"""
107108
steps: List[Step] = []
@@ -134,33 +135,39 @@ def __init__(
134135
kwargs.pop("dependencies", None)
135136
kwargs.pop("output_kms_key", None)
136137

137-
if pipeline_model is not None:
138-
self.model_list = pipeline_model.models
139-
for model in pipeline_model.models:
138+
if model is not None:
139+
if isinstance(model, PipelineModel):
140+
self.model_list = model.models
141+
elif isinstance(model, Model):
142+
self.model_list = [model]
143+
144+
for model_entity in self.model_list:
140145
if estimator is not None:
141146
sagemaker_session = estimator.sagemaker_session
142147
role = estimator.role
143148
else:
144-
sagemaker_session = pipeline_model.sagemaker_session or model.sagemaker_session
145-
role = pipeline_model.role
146-
if hasattr(model, "entry_point"):
149+
sagemaker_session = model_entity.sagemaker_session
150+
role = model_entity.role
151+
if hasattr(model_entity, "entry_point"):
147152
repack_model = True
148-
entry_point = model.entry_point
149-
source_dir = model.source_dir
150-
dependencies = model.dependencies
151-
name = model.name or model._framework_name
153+
entry_point = model_entity.entry_point
154+
source_dir = model_entity.source_dir
155+
dependencies = model_entity.dependencies
156+
name = model_entity.name or model_entity._framework_name
152157
repack_model_step = _RepackModelStep(
153158
name=f"{name}RepackModel",
154159
depends_on=depends_on,
155160
sagemaker_session=sagemaker_session,
156161
role=role,
157-
model_data=model.model_data,
162+
model_data=model_entity.model_data,
158163
entry_point=entry_point,
159164
source_dir=source_dir,
160165
dependencies=dependencies,
161166
)
162167
steps.append(repack_model_step)
163-
model.model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
168+
model_entity.model_data = (
169+
repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
170+
)
164171

165172
register_model_step = _RegisterModelStep(
166173
name=name,

tests/integ/test_workflow.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,11 +1019,10 @@ def test_mxnet_model_registration(
10191019
py_version="py3",
10201020
sagemaker_session=sagemaker_session,
10211021
)
1022-
pipeline_model = PipelineModel([model], role, sagemaker_session)
10231022

10241023
step_register = RegisterModel(
10251024
name="mxnet-register-model",
1026-
pipeline_model=pipeline_model,
1025+
model=model,
10271026
content_types=["*"],
10281027
response_types=["*"],
10291028
inference_instances=["ml.m5.xlarge"],
@@ -1211,7 +1210,7 @@ def test_sklearn_xgboost_sip_model_registration(
12111210

12121211
step_register = RegisterModel(
12131212
name="AbaloneRegisterModel",
1214-
pipeline_model=pipeline_model,
1213+
model=pipeline_model,
12151214
content_types=["application/json"],
12161215
response_types=["application/json"],
12171216
inference_instances=["ml.t2.medium", "ml.m5.xlarge"],

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def test_register_model_sip(estimator, model_metrics):
294294
model_metrics=model_metrics,
295295
approval_status="Approved",
296296
description="description",
297-
pipeline_model=pipeline_model,
297+
model=pipeline_model,
298298
depends_on=["TestStep"],
299299
)
300300
assert ordered(register_model.request_dicts()) == ordered(

0 commit comments

Comments
 (0)