21
21
22
22
from mock import Mock
23
23
24
+ from sagemaker import s3
24
25
from sagemaker .workflow .execution_variables import ExecutionVariables
25
26
from sagemaker .workflow .parameters import ParameterString
26
27
from sagemaker .workflow .pipeline import Pipeline
28
+ from sagemaker .workflow .parallelism_config import ParallelismConfiguration
27
29
from sagemaker .workflow .pipeline_experiment_config import (
28
30
PipelineExperimentConfig ,
29
31
PipelineExperimentConfigProperties ,
@@ -62,7 +64,9 @@ def role_arn():
62
64
63
65
@pytest .fixture
64
66
def sagemaker_session_mock ():
65
- return Mock ()
67
+ session_mock = Mock ()
68
+ session_mock .default_bucket = Mock (name = "default_bucket" , return_value = "s3_bucket" )
69
+ return session_mock
66
70
67
71
68
72
def test_pipeline_create (sagemaker_session_mock , role_arn ):
@@ -78,6 +82,47 @@ def test_pipeline_create(sagemaker_session_mock, role_arn):
78
82
)
79
83
80
84
85
+ def test_pipeline_create_with_parallelism_config (sagemaker_session_mock , role_arn ):
86
+ pipeline = Pipeline (
87
+ name = "MyPipeline" ,
88
+ parameters = [],
89
+ steps = [],
90
+ pipeline_experiment_config = ParallelismConfiguration (max_parallel_execution_steps = 10 ),
91
+ sagemaker_session = sagemaker_session_mock ,
92
+ )
93
+ pipeline .create (role_arn = role_arn )
94
+ assert sagemaker_session_mock .sagemaker_client .create_pipeline .called_with (
95
+ PipelineName = "MyPipeline" ,
96
+ PipelineDefinition = pipeline .definition (),
97
+ RoleArn = role_arn ,
98
+ ParallelismConfiguration = {"MaxParallelExecutionSteps" : 10 },
99
+ )
100
+
101
+
102
+ def test_large_pipeline_create (sagemaker_session_mock , role_arn ):
103
+ parameter = ParameterString ("MyStr" )
104
+ pipeline = Pipeline (
105
+ name = "MyPipeline" ,
106
+ parameters = [parameter ],
107
+ steps = [CustomStep (name = "MyStep" , input_data = parameter )] * 2000 ,
108
+ sagemaker_session = sagemaker_session_mock ,
109
+ )
110
+
111
+ s3 .S3Uploader .upload_string_as_file_body = Mock ()
112
+
113
+ pipeline .create (role_arn = role_arn )
114
+
115
+ assert s3 .S3Uploader .upload_string_as_file_body .called_with (
116
+ body = pipeline .definition (), s3_uri = "s3://s3_bucket/MyPipeline"
117
+ )
118
+
119
+ assert sagemaker_session_mock .sagemaker_client .create_pipeline .called_with (
120
+ PipelineName = "MyPipeline" ,
121
+ PipelineDefinitionS3Location = {"Bucket" : "s3_bucket" , "ObjectKey" : "MyPipeline" },
122
+ RoleArn = role_arn ,
123
+ )
124
+
125
+
81
126
def test_pipeline_update (sagemaker_session_mock , role_arn ):
82
127
pipeline = Pipeline (
83
128
name = "MyPipeline" ,
@@ -91,6 +136,47 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
91
136
)
92
137
93
138
139
+ def test_pipeline_update_with_parallelism_config (sagemaker_session_mock , role_arn ):
140
+ pipeline = Pipeline (
141
+ name = "MyPipeline" ,
142
+ parameters = [],
143
+ steps = [],
144
+ pipeline_experiment_config = ParallelismConfiguration (max_parallel_execution_steps = 10 ),
145
+ sagemaker_session = sagemaker_session_mock ,
146
+ )
147
+ pipeline .create (role_arn = role_arn )
148
+ assert sagemaker_session_mock .sagemaker_client .update_pipeline .called_with (
149
+ PipelineName = "MyPipeline" ,
150
+ PipelineDefinition = pipeline .definition (),
151
+ RoleArn = role_arn ,
152
+ ParallelismConfiguration = {"MaxParallelExecutionSteps" : 10 },
153
+ )
154
+
155
+
156
+ def test_large_pipeline_update (sagemaker_session_mock , role_arn ):
157
+ parameter = ParameterString ("MyStr" )
158
+ pipeline = Pipeline (
159
+ name = "MyPipeline" ,
160
+ parameters = [parameter ],
161
+ steps = [CustomStep (name = "MyStep" , input_data = parameter )] * 2000 ,
162
+ sagemaker_session = sagemaker_session_mock ,
163
+ )
164
+
165
+ s3 .S3Uploader .upload_string_as_file_body = Mock ()
166
+
167
+ pipeline .create (role_arn = role_arn )
168
+
169
+ assert s3 .S3Uploader .upload_string_as_file_body .called_with (
170
+ body = pipeline .definition (), s3_uri = "s3://s3_bucket/MyPipeline"
171
+ )
172
+
173
+ assert sagemaker_session_mock .sagemaker_client .update_pipeline .called_with (
174
+ PipelineName = "MyPipeline" ,
175
+ PipelineDefinitionS3Location = {"Bucket" : "s3_bucket" , "ObjectKey" : "MyPipeline" },
176
+ RoleArn = role_arn ,
177
+ )
178
+
179
+
94
180
def test_pipeline_upsert (sagemaker_session_mock , role_arn ):
95
181
sagemaker_session_mock .side_effect = [
96
182
ClientError (
0 commit comments