Skip to content

Commit 87ccb28

Browse files
authored
Merge branch 'master' into doc
2 parents d46aee9 + 7d56242 commit 87ccb28

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

src/sagemaker/workflow/step_collections.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def __init__(
112112
if "entry_point" in kwargs:
113113
repack_model = True
114114
entry_point = kwargs.pop("entry_point", None)
115-
source_dir = kwargs.get("source_dir")
116-
dependencies = kwargs.get("dependencies")
115+
source_dir = kwargs.pop("source_dir", None)
116+
dependencies = kwargs.pop("dependencies", None)
117117
kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None))
118118

119119
repack_model_step = _RepackModelStep(
@@ -130,13 +130,10 @@ def __init__(
130130
steps.append(repack_model_step)
131131
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
132132

133-
# remove kwargs consumed by model repacking step
134-
kwargs.pop("entry_point", None)
135-
kwargs.pop("source_dir", None)
136-
kwargs.pop("dependencies", None)
137-
kwargs.pop("output_kms_key", None)
133+
# remove kwargs consumed by model repacking step
134+
kwargs.pop("output_kms_key", None)
138135

139-
if model is not None:
136+
elif model is not None:
140137
if isinstance(model, PipelineModel):
141138
self.model_list = model.models
142139
self.container_def_list = model.pipeline_container_def(inference_instances[0])
@@ -156,7 +153,9 @@ def __init__(
156153
entry_point = model_entity.entry_point
157154
source_dir = model_entity.source_dir
158155
dependencies = model_entity.dependencies
156+
kwargs = dict(**kwargs, output_kms_key=model_entity.model_kms_key)
159157
name = model_entity.name or model_entity._framework_name
158+
160159
repack_model_step = _RepackModelStep(
161160
name=f"{name}RepackModel",
162161
depends_on=depends_on,
@@ -166,12 +165,16 @@ def __init__(
166165
entry_point=entry_point,
167166
source_dir=source_dir,
168167
dependencies=dependencies,
168+
**kwargs,
169169
)
170170
steps.append(repack_model_step)
171171
model_entity.model_data = (
172172
repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
173173
)
174174

175+
# remove kwargs consumed by model repacking step
176+
kwargs.pop("output_kms_key", None)
177+
175178
register_model_step = _RegisterModelStep(
176179
name=name,
177180
estimator=estimator,

0 commit comments

Comments
 (0)