Skip to content

Commit d360659

Browse files
committed
Add missing unit tests for s3 upload
1 parent eac9d22 commit d360659

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed

src/sagemaker/workflow/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _create_args(
144144
# If pipeline definition is large, upload to S3 bucket and
145145
# provide PipelineDefinitionS3Location to request instead.
146146
if len(pipeline_definition.encode("utf-8")) < 1024 * 100:
147-
kwargs["PipelineDefinition"] = self.definition()
147+
kwargs["PipelineDefinition"] = pipeline_definition
148148
else:
149149
desired_s3_uri = s3.s3_path_join(
150150
"s3://", self.sagemaker_session.default_bucket(), self.name

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121

2222
from mock import Mock
2323

24+
from sagemaker import s3
2425
from sagemaker.workflow.execution_variables import ExecutionVariables
2526
from sagemaker.workflow.parameters import ParameterString
2627
from sagemaker.workflow.pipeline import Pipeline
28+
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
2729
from sagemaker.workflow.pipeline_experiment_config import (
2830
PipelineExperimentConfig,
2931
PipelineExperimentConfigProperties,
@@ -62,7 +64,9 @@ def role_arn():
6264

6365
@pytest.fixture
6466
def sagemaker_session_mock():
65-
return Mock()
67+
session_mock = Mock()
68+
session_mock.default_bucket = Mock(name="default_bucket", return_value="s3_bucket")
69+
return session_mock
6670

6771

6872
def test_pipeline_create(sagemaker_session_mock, role_arn):
@@ -78,6 +82,47 @@ def test_pipeline_create(sagemaker_session_mock, role_arn):
7882
)
7983

8084

85+
def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_arn):
86+
pipeline = Pipeline(
87+
name="MyPipeline",
88+
parameters=[],
89+
steps=[],
90+
pipeline_experiment_config=ParallelismConfiguration(max_parallel_execution_steps=10),
91+
sagemaker_session=sagemaker_session_mock,
92+
)
93+
pipeline.create(role_arn=role_arn)
94+
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
95+
PipelineName="MyPipeline",
96+
PipelineDefinition=pipeline.definition(),
97+
RoleArn=role_arn,
98+
ParallelismConfiguration={"MaxParallelExecutionSteps": 10},
99+
)
100+
101+
102+
def test_large_pipeline_create(sagemaker_session_mock, role_arn):
103+
parameter = ParameterString("MyStr")
104+
pipeline = Pipeline(
105+
name="MyPipeline",
106+
parameters=[parameter],
107+
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
108+
sagemaker_session=sagemaker_session_mock,
109+
)
110+
111+
s3.S3Uploader.upload_string_as_file_body = Mock()
112+
113+
pipeline.create(role_arn=role_arn)
114+
115+
assert s3.S3Uploader.upload_string_as_file_body.called_with(
116+
body=pipeline.definition(), s3_uri="s3://s3_bucket/MyPipeline"
117+
)
118+
119+
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
120+
PipelineName="MyPipeline",
121+
PipelineDefinitionS3Location={"Bucket": "s3_bucket", "ObjectKey": "MyPipeline"},
122+
RoleArn=role_arn,
123+
)
124+
125+
81126
def test_pipeline_update(sagemaker_session_mock, role_arn):
82127
pipeline = Pipeline(
83128
name="MyPipeline",
@@ -91,6 +136,47 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
91136
)
92137

93138

139+
def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_arn):
140+
pipeline = Pipeline(
141+
name="MyPipeline",
142+
parameters=[],
143+
steps=[],
144+
pipeline_experiment_config=ParallelismConfiguration(max_parallel_execution_steps=10),
145+
sagemaker_session=sagemaker_session_mock,
146+
)
147+
pipeline.create(role_arn=role_arn)
148+
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
149+
PipelineName="MyPipeline",
150+
PipelineDefinition=pipeline.definition(),
151+
RoleArn=role_arn,
152+
ParallelismConfiguration={"MaxParallelExecutionSteps": 10},
153+
)
154+
155+
156+
def test_large_pipeline_update(sagemaker_session_mock, role_arn):
157+
parameter = ParameterString("MyStr")
158+
pipeline = Pipeline(
159+
name="MyPipeline",
160+
parameters=[parameter],
161+
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
162+
sagemaker_session=sagemaker_session_mock,
163+
)
164+
165+
s3.S3Uploader.upload_string_as_file_body = Mock()
166+
167+
pipeline.create(role_arn=role_arn)
168+
169+
assert s3.S3Uploader.upload_string_as_file_body.called_with(
170+
body=pipeline.definition(), s3_uri="s3://s3_bucket/MyPipeline"
171+
)
172+
173+
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
174+
PipelineName="MyPipeline",
175+
PipelineDefinitionS3Location={"Bucket": "s3_bucket", "ObjectKey": "MyPipeline"},
176+
RoleArn=role_arn,
177+
)
178+
179+
94180
def test_pipeline_upsert(sagemaker_session_mock, role_arn):
95181
sagemaker_session_mock.side_effect = [
96182
ClientError(

0 commit comments

Comments
 (0)