Skip to content

Commit 80e391e

Browse files
committed
Fix for local tests
1 parent d95235f commit 80e391e

File tree

4 files changed

+10
-7
lines changed

4 files changed

+10
-7
lines changed

src/sagemaker/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2521,8 +2521,10 @@ def _create_model_request(
25212521
request = {"ModelName": name, "ExecutionRoleArn": role}
25222522
if isinstance(container_definition, list):
25232523
request["Containers"] = container_definition
2524-
else:
2524+
elif container_definition["ModelPackageName"] is not None:
25252525
request["Containers"] = [container_definition]
2526+
else:
2527+
request["PrimaryContainer"] = container_definition
25262528

25272529
if tags:
25282530
request["Tags"] = tags

src/sagemaker/workflow/step_collections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def __init__(
139139
sagemaker_session = estimator.sagemaker_session
140140
role = estimator.role
141141
else:
142-
sagemaker_session = model.sagemaker_session
142+
sagemaker_session = pipeline_model.sagemaker_session or model.sagemaker_session
143143
role = pipeline_model.role
144144
if hasattr(model, "entry_point"):
145145
repack_model = True

tests/integ/test_workflow.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,8 +1019,7 @@ def test_mxnet_model_registration(
10191019
py_version="py3",
10201020
sagemaker_session=sagemaker_session,
10211021
)
1022-
1023-
pipeline_model = PipelineModel([model],role)
1022+
pipeline_model = PipelineModel([model], role, sagemaker_session)
10241023

10251024
step_register = RegisterModel(
10261025
name="mxnet-register-model",
@@ -1205,8 +1204,10 @@ def test_sklearn_xgboost_sip_model_registration(
12051204
role=role,
12061205
sagemaker_session=sagemaker_session,
12071206
)
1208-
1209-
pipeline_model = PipelineModel([xgboost_model, sklearn_model],role)
1207+
1208+
pipeline_model = PipelineModel(
1209+
[xgboost_model, sklearn_model], role, sagemaker_session=sagemaker_session
1210+
)
12101211

12111212
step_register = RegisterModel(
12121213
name="AbaloneRegisterModel",

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_register_model_sip(estimator, model_metrics):
280280
Model(image_uri="fakeimage1", model_data="Url1", env=[{"k1": "v1"}, {"k2": "v2"}]),
281281
Model(image_uri="fakeimage2", model_data="Url2", env=[{"k3": "v3"}, {"k4": "v4"}]),
282282
]
283-
283+
284284
pipeline_model = PipelineModel(model_list, ROLE)
285285

286286
register_model = RegisterModel(

0 commit comments

Comments
 (0)