11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
+ import json
14
15
15
16
import pytest
16
17
from mock import Mock , MagicMock
17
- from sagemaker .workflow .conditions import ConditionEquals
18
- from sagemaker .workflow .parameters import ParameterInteger
18
+ from sagemaker .workflow .conditions import (
19
+ ConditionEquals ,
20
+ ConditionGreaterThan ,
21
+ ConditionGreaterThanOrEqualTo ,
22
+ ConditionIn ,
23
+ ConditionLessThan ,
24
+ ConditionLessThanOrEqualTo ,
25
+ ConditionNot ,
26
+ ConditionOr ,
27
+ )
28
+ from sagemaker .workflow .execution_variables import ExecutionVariables
29
+ from sagemaker .workflow .parameters import ParameterInteger , ParameterString
19
30
from sagemaker .workflow .condition_step import ConditionStep
20
31
from sagemaker .workflow .pipeline import Pipeline , PipelineGraph
32
+ from sagemaker .workflow .properties import Properties
21
33
from tests .unit .sagemaker .workflow .helpers import CustomStep , ordered
22
34
23
35
@@ -56,7 +68,7 @@ def test_condition_step():
56
68
"Conditions" : [
57
69
{
58
70
"Type" : "Equals" ,
59
- "LeftValue" : { "Get" : "Parameters.MyInt" } ,
71
+ "LeftValue" : param ,
60
72
"RightValue" : 1 ,
61
73
},
62
74
],
@@ -79,6 +91,147 @@ def test_condition_step():
79
91
assert cond_step .properties .Outcome .expr == {"Get" : "Steps.MyConditionStep.Outcome" }
80
92
81
93
94
+ def test_pipeline_condition_step_interpolated (sagemaker_session ):
95
+ param1 = ParameterInteger (name = "MyInt1" )
96
+ param2 = ParameterInteger (name = "MyInt2" )
97
+ param3 = ParameterString (name = "MyStr" )
98
+ var = ExecutionVariables .START_DATETIME
99
+ prop = Properties ("foo" )
100
+
101
+ cond_eq = ConditionEquals (left = param1 , right = param2 )
102
+ cond_gt = ConditionGreaterThan (left = var , right = "2020-12-01" )
103
+ cond_gte = ConditionGreaterThanOrEqualTo (left = var , right = param3 )
104
+ cond_lt = ConditionLessThan (left = var , right = "2020-12-01" )
105
+ cond_lte = ConditionLessThanOrEqualTo (left = var , right = param3 )
106
+ cond_in = ConditionIn (value = param3 , in_values = ["abc" , "def" ])
107
+ cond_in_mixed = ConditionIn (value = param3 , in_values = ["abc" , prop , var ])
108
+ cond_not_eq = ConditionNot (expression = cond_eq )
109
+ cond_not_in = ConditionNot (expression = cond_in )
110
+ cond_or = ConditionOr (conditions = [cond_gt , cond_in ])
111
+
112
+ step1 = CustomStep (name = "MyStep1" )
113
+ step2 = CustomStep (name = "MyStep2" )
114
+ cond_step = ConditionStep (
115
+ name = "MyConditionStep" ,
116
+ conditions = [
117
+ cond_eq ,
118
+ cond_gt ,
119
+ cond_gte ,
120
+ cond_lt ,
121
+ cond_lte ,
122
+ cond_in ,
123
+ cond_in_mixed ,
124
+ cond_not_eq ,
125
+ cond_not_in ,
126
+ cond_or ,
127
+ ],
128
+ if_steps = [step1 ],
129
+ else_steps = [step2 ],
130
+ )
131
+
132
+ pipeline = Pipeline (
133
+ name = "MyPipeline" ,
134
+ parameters = [param1 , param2 , param3 ],
135
+ steps = [cond_step ],
136
+ sagemaker_session = sagemaker_session ,
137
+ )
138
+ assert json .loads (pipeline .definition ()) == {
139
+ "Version" : "2020-12-01" ,
140
+ "Metadata" : {},
141
+ "Parameters" : [
142
+ {"Name" : "MyInt1" , "Type" : "Integer" },
143
+ {"Name" : "MyInt2" , "Type" : "Integer" },
144
+ {"Name" : "MyStr" , "Type" : "String" },
145
+ ],
146
+ "PipelineExperimentConfig" : {
147
+ "ExperimentName" : {"Get" : "Execution.PipelineName" },
148
+ "TrialName" : {"Get" : "Execution.PipelineExecutionId" },
149
+ },
150
+ "Steps" : [
151
+ {
152
+ "Name" : "MyConditionStep" ,
153
+ "Type" : "Condition" ,
154
+ "Arguments" : {
155
+ "Conditions" : [
156
+ {
157
+ "Type" : "Equals" ,
158
+ "LeftValue" : {"Get" : "Parameters.MyInt1" },
159
+ "RightValue" : {"Get" : "Parameters.MyInt2" },
160
+ },
161
+ {
162
+ "Type" : "GreaterThan" ,
163
+ "LeftValue" : {"Get" : "Execution.StartDateTime" },
164
+ "RightValue" : "2020-12-01" ,
165
+ },
166
+ {
167
+ "Type" : "GreaterThanOrEqualTo" ,
168
+ "LeftValue" : {"Get" : "Execution.StartDateTime" },
169
+ "RightValue" : {"Get" : "Parameters.MyStr" },
170
+ },
171
+ {
172
+ "Type" : "LessThan" ,
173
+ "LeftValue" : {"Get" : "Execution.StartDateTime" },
174
+ "RightValue" : "2020-12-01" ,
175
+ },
176
+ {
177
+ "Type" : "LessThanOrEqualTo" ,
178
+ "LeftValue" : {"Get" : "Execution.StartDateTime" },
179
+ "RightValue" : {"Get" : "Parameters.MyStr" },
180
+ },
181
+ {
182
+ "Type" : "In" ,
183
+ "QueryValue" : {"Get" : "Parameters.MyStr" },
184
+ "Values" : ["abc" , "def" ],
185
+ },
186
+ {
187
+ "Type" : "In" ,
188
+ "QueryValue" : {"Get" : "Parameters.MyStr" },
189
+ "Values" : [
190
+ "abc" ,
191
+ {"Get" : "Steps.foo" },
192
+ {"Get" : "Execution.StartDateTime" },
193
+ ],
194
+ },
195
+ {
196
+ "Type" : "Not" ,
197
+ "Expression" : {
198
+ "Type" : "Equals" ,
199
+ "LeftValue" : {"Get" : "Parameters.MyInt1" },
200
+ "RightValue" : {"Get" : "Parameters.MyInt2" },
201
+ },
202
+ },
203
+ {
204
+ "Type" : "Not" ,
205
+ "Expression" : {
206
+ "Type" : "In" ,
207
+ "QueryValue" : {"Get" : "Parameters.MyStr" },
208
+ "Values" : ["abc" , "def" ],
209
+ },
210
+ },
211
+ {
212
+ "Type" : "Or" ,
213
+ "Conditions" : [
214
+ {
215
+ "Type" : "GreaterThan" ,
216
+ "LeftValue" : {"Get" : "Execution.StartDateTime" },
217
+ "RightValue" : "2020-12-01" ,
218
+ },
219
+ {
220
+ "Type" : "In" ,
221
+ "QueryValue" : {"Get" : "Parameters.MyStr" },
222
+ "Values" : ["abc" , "def" ],
223
+ },
224
+ ],
225
+ },
226
+ ],
227
+ "IfSteps" : [{"Name" : "MyStep1" , "Type" : "Training" , "Arguments" : {}}],
228
+ "ElseSteps" : [{"Name" : "MyStep2" , "Type" : "Training" , "Arguments" : {}}],
229
+ },
230
+ }
231
+ ],
232
+ }
233
+
234
+
82
235
def test_pipeline (sagemaker_session ):
83
236
param = ParameterInteger (name = "MyInt" , default_value = 2 )
84
237
cond = ConditionEquals (left = param , right = 1 )
0 commit comments