@@ -136,10 +136,8 @@ def __init__(
136
136
elif model is not None :
137
137
if isinstance (model , PipelineModel ):
138
138
self .model_list = model .models
139
- self .container_def_list = model .pipeline_container_def (inference_instances [0 ])
140
139
elif isinstance (model , Model ):
141
140
self .model_list = [model ]
142
- self .container_def_list = [model .prepare_container_def (inference_instances [0 ])]
143
141
144
142
for model_entity in self .model_list :
145
143
if estimator is not None :
@@ -154,10 +152,10 @@ def __init__(
154
152
source_dir = model_entity .source_dir
155
153
dependencies = model_entity .dependencies
156
154
kwargs = dict (** kwargs , output_kms_key = model_entity .model_kms_key )
157
- name = model_entity .name or model_entity ._framework_name
155
+ model_name = model_entity .name or model_entity ._framework_name
158
156
159
157
repack_model_step = _RepackModelStep (
160
- name = f"{ name } RepackModel" ,
158
+ name = f"{ model_name } RepackModel" ,
161
159
depends_on = depends_on ,
162
160
sagemaker_session = sagemaker_session ,
163
161
role = role ,
@@ -171,10 +169,14 @@ def __init__(
171
169
model_entity .model_data = (
172
170
repack_model_step .properties .ModelArtifacts .S3ModelArtifacts
173
171
)
174
-
175
172
# remove kwargs consumed by model repacking step
176
173
kwargs .pop ("output_kms_key" , None )
177
174
175
+ if isinstance (model , PipelineModel ):
176
+ self .container_def_list = model .pipeline_container_def (inference_instances [0 ])
177
+ elif isinstance (model , Model ):
178
+ self .container_def_list = [model .prepare_container_def (inference_instances [0 ])]
179
+
178
180
register_model_step = _RegisterModelStep (
179
181
name = name ,
180
182
estimator = estimator ,
0 commit comments