Skip to content

Commit 89ab163

Browse files
authored
Merge branch 'master' into doc
2 parents 975e1e5 + 9fc4e3e commit 89ab163

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/sagemaker/workflow/step_collections.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,8 @@ def __init__(
136136
elif model is not None:
137137
if isinstance(model, PipelineModel):
138138
self.model_list = model.models
139-
self.container_def_list = model.pipeline_container_def(inference_instances[0])
140139
elif isinstance(model, Model):
141140
self.model_list = [model]
142-
self.container_def_list = [model.prepare_container_def(inference_instances[0])]
143141

144142
for model_entity in self.model_list:
145143
if estimator is not None:
@@ -154,10 +152,10 @@ def __init__(
154152
source_dir = model_entity.source_dir
155153
dependencies = model_entity.dependencies
156154
kwargs = dict(**kwargs, output_kms_key=model_entity.model_kms_key)
157-
name = model_entity.name or model_entity._framework_name
155+
model_name = model_entity.name or model_entity._framework_name
158156

159157
repack_model_step = _RepackModelStep(
160-
name=f"{name}RepackModel",
158+
name=f"{model_name}RepackModel",
161159
depends_on=depends_on,
162160
sagemaker_session=sagemaker_session,
163161
role=role,
@@ -171,10 +169,14 @@ def __init__(
171169
model_entity.model_data = (
172170
repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
173171
)
174-
175172
# remove kwargs consumed by model repacking step
176173
kwargs.pop("output_kms_key", None)
177174

175+
if isinstance(model, PipelineModel):
176+
self.container_def_list = model.pipeline_container_def(inference_instances[0])
177+
elif isinstance(model, Model):
178+
self.container_def_list = [model.prepare_container_def(inference_instances[0])]
179+
178180
register_model_step = _RegisterModelStep(
179181
name=name,
180182
estimator=estimator,

0 commit comments

Comments
 (0)