11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
+ import json
14
15
15
16
import pytest
16
17
from mock import Mock , MagicMock
17
- from sagemaker .workflow .conditions import ConditionEquals
18
- from sagemaker .workflow .parameters import ParameterInteger
18
+ from sagemaker .workflow .conditions import (
19
+ ConditionEquals ,
20
+ ConditionGreaterThan ,
21
+ ConditionGreaterThanOrEqualTo ,
22
+ ConditionIn , ConditionLessThan ,
23
+ ConditionLessThanOrEqualTo ,
24
+ ConditionNot , ConditionOr )
25
+ from sagemaker .workflow .execution_variables import ExecutionVariables
26
+ from sagemaker .workflow .parameters import ParameterInteger , ParameterString
19
27
from sagemaker .workflow .condition_step import ConditionStep
20
28
from sagemaker .workflow .pipeline import Pipeline , PipelineGraph
29
+ from sagemaker .workflow .properties import Properties
21
30
from tests .unit .sagemaker .workflow .helpers import CustomStep , ordered
22
31
23
32
@@ -35,7 +44,7 @@ def sagemaker_session():
35
44
return session_mock
36
45
37
46
38
- def test_condition_step ():
47
+ def test_condition_step (sagemaker_session ):
39
48
param = ParameterInteger (name = "MyInt" )
40
49
cond = ConditionEquals (left = param , right = 1 )
41
50
step1 = CustomStep (name = "MyStep1" )
@@ -56,7 +65,7 @@ def test_condition_step():
56
65
"Conditions" : [
57
66
{
58
67
"Type" : "Equals" ,
59
- "LeftValue" : { "Get" : "Parameters.MyInt" } ,
68
+ "LeftValue" : param ,
60
69
"RightValue" : 1 ,
61
70
},
62
71
],
@@ -79,6 +88,169 @@ def test_condition_step():
79
88
assert cond_step .properties .Outcome .expr == {"Get" : "Steps.MyConditionStep.Outcome" }
80
89
81
90
91
+ def test_pipeline_condition_step_interpolated (sagemaker_session ):
92
+ param1 = ParameterInteger (name = "MyInt1" )
93
+ param2 = ParameterInteger (name = "MyInt2" )
94
+ param3 = ParameterString (name = "MyStr" )
95
+ var = ExecutionVariables .START_DATETIME
96
+ prop = Properties ("foo" )
97
+
98
+ cond_eq = ConditionEquals (left = param1 , right = param2 )
99
+ cond_gt = ConditionGreaterThan (left = var , right = "2020-12-01" )
100
+ cond_gte = ConditionGreaterThanOrEqualTo (left = var , right = param3 )
101
+ cond_lt = ConditionLessThan (left = var , right = "2020-12-01" )
102
+ cond_lte = ConditionLessThanOrEqualTo (left = var , right = param3 )
103
+ cond_in = ConditionIn (value = param3 , in_values = ["abc" , "def" ])
104
+ cond_in_mixed = ConditionIn (value = param3 , in_values = ["abc" , prop , var ])
105
+ cond_not_eq = ConditionNot (expression = cond_eq )
106
+ cond_not_in = ConditionNot (expression = cond_in )
107
+ cond_or = ConditionOr (conditions = [cond_gt , cond_in ])
108
+
109
+ step1 = CustomStep (name = "MyStep1" )
110
+ step2 = CustomStep (name = "MyStep2" )
111
+ cond_step = ConditionStep (
112
+ name = "MyConditionStep" ,
113
+ conditions = [
114
+ cond_eq , cond_gt , cond_gte , cond_lt , cond_lte , cond_in , cond_in_mixed , cond_not_eq , cond_not_in , cond_or
115
+ ],
116
+ if_steps = [step1 ],
117
+ else_steps = [step2 ],
118
+ )
119
+
120
+ pipeline = Pipeline (
121
+ name = "MyPipeline" ,
122
+ parameters = [param1 , param2 , param3 ],
123
+ steps = [cond_step ],
124
+ sagemaker_session = sagemaker_session ,
125
+ )
126
+ assert json .loads (pipeline .definition ()) == {
127
+ "Version" : "2020-12-01" ,
128
+ "Metadata" : {},
129
+ "Parameters" : [{
130
+ "Name" : "MyInt1" ,
131
+ "Type" : "Integer"
132
+ }, {
133
+ "Name" : "MyInt2" ,
134
+ "Type" : "Integer"
135
+ }, {
136
+ "Name" : "MyStr" ,
137
+ "Type" : "String"
138
+ }],
139
+ "PipelineExperimentConfig" : {
140
+ "ExperimentName" : {
141
+ "Get" : "Execution.PipelineName"
142
+ },
143
+ "TrialName" : {
144
+ "Get" : "Execution.PipelineExecutionId"
145
+ }
146
+ },
147
+ "Steps" : [{
148
+ "Name" : "MyConditionStep" ,
149
+ "Type" : "Condition" ,
150
+ "Arguments" : {
151
+ "Conditions" : [{
152
+ "Type" : "Equals" ,
153
+ "LeftValue" : {
154
+ "Get" : "Parameters.MyInt1"
155
+ },
156
+ "RightValue" : {
157
+ "Get" : "Parameters.MyInt2"
158
+ }
159
+ }, {
160
+ "Type" : "GreaterThan" ,
161
+ "LeftValue" : {
162
+ "Get" : "Execution.StartDateTime"
163
+ },
164
+ "RightValue" : "2020-12-01"
165
+ }, {
166
+ "Type" : "GreaterThanOrEqualTo" ,
167
+ "LeftValue" : {
168
+ "Get" : "Execution.StartDateTime"
169
+ },
170
+ "RightValue" : {
171
+ "Get" : "Parameters.MyStr"
172
+ }
173
+ }, {
174
+ "Type" : "LessThan" ,
175
+ "LeftValue" : {
176
+ "Get" : "Execution.StartDateTime"
177
+ },
178
+ "RightValue" : "2020-12-01"
179
+ }, {
180
+ "Type" : "LessThanOrEqualTo" ,
181
+ "LeftValue" : {
182
+ "Get" : "Execution.StartDateTime"
183
+ },
184
+ "RightValue" : {
185
+ "Get" : "Parameters.MyStr"
186
+ }
187
+ }, {
188
+ "Type" : "In" ,
189
+ "QueryValue" : {
190
+ "Get" : "Parameters.MyStr"
191
+ },
192
+ "Values" : ["abc" , "def" ]
193
+ }, {
194
+ "Type" : "In" ,
195
+ "QueryValue" : {
196
+ "Get" : "Parameters.MyStr"
197
+ },
198
+ "Values" : ["abc" , {
199
+ "Get" : "Steps.foo"
200
+ }, {
201
+ "Get" : "Execution.StartDateTime"
202
+ }]
203
+ }, {
204
+ "Type" : "Not" ,
205
+ "Expression" : {
206
+ "Type" : "Equals" ,
207
+ "LeftValue" : {
208
+ "Get" : "Parameters.MyInt1"
209
+ },
210
+ "RightValue" : {
211
+ "Get" : "Parameters.MyInt2"
212
+ }
213
+ }
214
+ }, {
215
+ "Type" : "Not" ,
216
+ "Expression" : {
217
+ "Type" : "In" ,
218
+ "QueryValue" : {
219
+ "Get" : "Parameters.MyStr"
220
+ },
221
+ "Values" : ["abc" , "def" ]
222
+ }
223
+ }, {
224
+ "Type" : "Or" ,
225
+ "Conditions" : [{
226
+ "Type" : "GreaterThan" ,
227
+ "LeftValue" : {
228
+ "Get" : "Execution.StartDateTime"
229
+ },
230
+ "RightValue" : "2020-12-01"
231
+ }, {
232
+ "Type" : "In" ,
233
+ "QueryValue" : {
234
+ "Get" : "Parameters.MyStr"
235
+ },
236
+ "Values" : ["abc" , "def" ]
237
+ }]
238
+ }],
239
+ "IfSteps" : [{
240
+ "Name" : "MyStep1" ,
241
+ "Type" : "Training" ,
242
+ "Arguments" : {}
243
+ }],
244
+ "ElseSteps" : [{
245
+ "Name" : "MyStep2" ,
246
+ "Type" : "Training" ,
247
+ "Arguments" : {}
248
+ }]
249
+ }
250
+ }]
251
+ }
252
+
253
+
82
254
def test_pipeline (sagemaker_session ):
83
255
param = ParameterInteger (name = "MyInt" , default_value = 2 )
84
256
cond = ConditionEquals (left = param , right = 1 )
0 commit comments