Skip to content

Commit e5b551b

Browse files
committed
addressed PR comments for UT's and added default value None to missing parameter
1 parent bf054de commit e5b551b

File tree

2 files changed

+37
-21
lines changed

2 files changed

+37
-21
lines changed

src/sagemaker/pipeline.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(
8484
self.enable_network_isolation = enable_network_isolation
8585
self.endpoint_name = None
8686

87-
def pipeline_container_def(self, instance_type):
87+
def pipeline_container_def(self, instance_type=None):
8888
"""The pipeline definition for deploying this model.
8989
9090
This is the dict created by ``sagemaker.pipeline_container_def()``.
@@ -321,7 +321,10 @@ def register(
321321
)
322322
else:
323323
container_def = [
324-
{"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
324+
{
325+
"Image": image_uri or model.image_uri,
326+
"ModelDataUrl": model.model_data,
327+
}
325328
for model in self.models
326329
]
327330

tests/unit/sagemaker/workflow/test_pipeline_session.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock):
124124
assert len(register_step_args.need_runtime_repack) == 0
125125

126126

127-
def test_pipeline_session_context_for_model_step_without_instance_types(pipeline_session_mock):
127+
def test_pipeline_session_context_for_model_step_without_instance_types(
128+
pipeline_session_mock,
129+
):
128130
model = Model(
129131
name="MyModel",
130132
image_uri="fakeimage",
@@ -134,27 +136,38 @@ def test_pipeline_session_context_for_model_step_without_instance_types(pipeline
134136
source_dir=f"{DATA_DIR}",
135137
role=_ROLE,
136138
)
137-
# CreateModelStep requires runtime repack
138-
create_step_args = model.create(
139-
instance_type="c4.4xlarge",
140-
accelerator_type="ml.eia1.medium",
141-
)
142-
# The context should be cleaned up before return
143-
assert pipeline_session_mock.context is None
144-
assert create_step_args.create_model_request
145-
assert not create_step_args.create_model_package_request
146-
assert len(create_step_args.need_runtime_repack) == 1
147139

148-
# _RegisterModelStep does not require runtime repack
149-
model.entry_point = None
150-
model.source_dir = None
151140
register_step_args = model.register(
152141
content_types=["text/csv"],
153142
response_types=["text/csv"],
154143
model_package_group_name="MyModelPackageGroup",
155144
)
156-
# The context should be cleaned up before return
157-
assert not pipeline_session_mock.context
158-
assert not register_step_args.create_model_request
159-
assert register_step_args.create_model_package_request
160-
assert len(register_step_args.need_runtime_repack) == 0
145+
146+
expected_output = {
147+
"ModelPackageGroupName": "MyModelPackageGroup",
148+
"InferenceSpecification": {
149+
"Containers": [
150+
{
151+
"Image": "fakeimage",
152+
"Environment": {
153+
"SAGEMAKER_PROGRAM": "dummy_script.py",
154+
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
155+
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
156+
"SAGEMAKER_REGION": "us-west-2",
157+
},
158+
"ModelDataUrl": ParameterString(
159+
name="ModelData",
160+
default_value="s3://my-bucket/file",
161+
),
162+
}
163+
],
164+
"SupportedContentTypes": ["text/csv"],
165+
"SupportedResponseMIMETypes": ["text/csv"],
166+
"SupportedRealtimeInferenceInstanceTypes": None,
167+
"SupportedTransformInstanceTypes": None,
168+
},
169+
"CertifyForMarketplace": False,
170+
"ModelApprovalStatus": "PendingManualApproval",
171+
}
172+
173+
assert register_step_args.create_model_package_request == expected_output

0 commit comments

Comments
 (0)