Skip to content

Commit 25a6a7c

Browse files
author
Ao Guo
committed
remove primitive_or_expr() from conditions
1 parent c70e30c commit 25a6a7c

File tree

4 files changed

+196
-40
lines changed

4 files changed

+196
-40
lines changed

src/sagemaker/workflow/condition_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def arguments(self) -> RequestType:
9393
@property
9494
def step_only_arguments(self):
9595
"""Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`."""
96-
return self.conditions
96+
return [condition.to_request() for condition in self.conditions]
9797

9898
@property
9999
def properties(self):

src/sagemaker/workflow/conditions.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def to_request(self) -> RequestType:
8888
"""Get the request structure for workflow service calls."""
8989
return {
9090
"Type": self.condition_type.value,
91-
"LeftValue": primitive_or_expr(self.left),
92-
"RightValue": primitive_or_expr(self.right),
91+
"LeftValue": self.left,
92+
"RightValue": self.right,
9393
}
9494

9595
@property
@@ -227,8 +227,8 @@ def to_request(self) -> RequestType:
227227
"""Get the request structure for workflow service calls."""
228228
return {
229229
"Type": self.condition_type.value,
230-
"QueryValue": self.value.expr,
231-
"Values": [primitive_or_expr(in_value) for in_value in self.in_values],
230+
"QueryValue": self.value,
231+
"Values": self.in_values,
232232
}
233233

234234
@property
@@ -291,19 +291,3 @@ def _referenced_steps(self) -> List[str]:
291291
for condition in self.conditions:
292292
steps.extend(condition._referenced_steps)
293293
return steps
294-
295-
296-
def primitive_or_expr(
297-
value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties]
298-
) -> Union[Dict[str, str], PrimitiveType]:
299-
"""Provide the expression of the value or return value if it is a primitive.
300-
301-
Args:
302-
value (Union[ConditionValueType, PrimitiveType]): The value to evaluate.
303-
304-
Returns:
305-
Either the expression of the value or the primitive value.
306-
"""
307-
if is_pipeline_variable(value):
308-
return value.expr
309-
return value

tests/unit/sagemaker/workflow/test_condition_step.py

Lines changed: 175 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,22 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import json
1415

1516
import pytest
1617
from mock import Mock, MagicMock
17-
from sagemaker.workflow.conditions import ConditionEquals
18-
from sagemaker.workflow.parameters import ParameterInteger
18+
from sagemaker.workflow.conditions import (
19+
ConditionEquals,
20+
ConditionGreaterThan,
21+
ConditionGreaterThanOrEqualTo,
22+
ConditionIn, ConditionLessThan,
23+
ConditionLessThanOrEqualTo,
24+
ConditionNot, ConditionOr)
25+
from sagemaker.workflow.execution_variables import ExecutionVariables
26+
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
1927
from sagemaker.workflow.condition_step import ConditionStep
2028
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
29+
from sagemaker.workflow.properties import Properties
2130
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered
2231

2332

@@ -56,7 +65,7 @@ def test_condition_step():
5665
"Conditions": [
5766
{
5867
"Type": "Equals",
59-
"LeftValue": {"Get": "Parameters.MyInt"},
68+
"LeftValue": param,
6069
"RightValue": 1,
6170
},
6271
],
@@ -79,6 +88,169 @@ def test_condition_step():
7988
assert cond_step.properties.Outcome.expr == {"Get": "Steps.MyConditionStep.Outcome"}
8089

8190

91+
def test_pipeline_condition_step_interpolated(sagemaker_session):
92+
param1 = ParameterInteger(name="MyInt1")
93+
param2 = ParameterInteger(name="MyInt2")
94+
param3 = ParameterString(name="MyStr")
95+
var = ExecutionVariables.START_DATETIME
96+
prop = Properties("foo")
97+
98+
cond_eq = ConditionEquals(left=param1, right=param2)
99+
cond_gt = ConditionGreaterThan(left=var, right="2020-12-01")
100+
cond_gte = ConditionGreaterThanOrEqualTo(left=var, right=param3)
101+
cond_lt = ConditionLessThan(left=var, right="2020-12-01")
102+
cond_lte = ConditionLessThanOrEqualTo(left=var, right=param3)
103+
cond_in = ConditionIn(value=param3, in_values=["abc", "def"])
104+
cond_in_mixed = ConditionIn(value=param3, in_values=["abc", prop, var])
105+
cond_not_eq = ConditionNot(expression=cond_eq)
106+
cond_not_in = ConditionNot(expression=cond_in)
107+
cond_or = ConditionOr(conditions=[cond_gt, cond_in])
108+
109+
step1 = CustomStep(name="MyStep1")
110+
step2 = CustomStep(name="MyStep2")
111+
cond_step = ConditionStep(
112+
name="MyConditionStep",
113+
conditions=[
114+
cond_eq, cond_gt, cond_gte, cond_lt, cond_lte, cond_in, cond_in_mixed, cond_not_eq, cond_not_in, cond_or
115+
],
116+
if_steps=[step1],
117+
else_steps=[step2],
118+
)
119+
120+
pipeline = Pipeline(
121+
name="MyPipeline",
122+
parameters=[param1, param2, param3],
123+
steps=[cond_step],
124+
sagemaker_session=sagemaker_session,
125+
)
126+
assert json.loads(pipeline.definition()) == {
127+
"Version": "2020-12-01",
128+
"Metadata": {},
129+
"Parameters": [{
130+
"Name": "MyInt1",
131+
"Type": "Integer"
132+
}, {
133+
"Name": "MyInt2",
134+
"Type": "Integer"
135+
}, {
136+
"Name": "MyStr",
137+
"Type": "String"
138+
}],
139+
"PipelineExperimentConfig": {
140+
"ExperimentName": {
141+
"Get": "Execution.PipelineName"
142+
},
143+
"TrialName": {
144+
"Get": "Execution.PipelineExecutionId"
145+
}
146+
},
147+
"Steps": [{
148+
"Name": "MyConditionStep",
149+
"Type": "Condition",
150+
"Arguments": {
151+
"Conditions": [{
152+
"Type": "Equals",
153+
"LeftValue": {
154+
"Get": "Parameters.MyInt1"
155+
},
156+
"RightValue": {
157+
"Get": "Parameters.MyInt2"
158+
}
159+
}, {
160+
"Type": "GreaterThan",
161+
"LeftValue": {
162+
"Get": "Execution.StartDateTime"
163+
},
164+
"RightValue": "2020-12-01"
165+
}, {
166+
"Type": "GreaterThanOrEqualTo",
167+
"LeftValue": {
168+
"Get": "Execution.StartDateTime"
169+
},
170+
"RightValue": {
171+
"Get": "Parameters.MyStr"
172+
}
173+
}, {
174+
"Type": "LessThan",
175+
"LeftValue": {
176+
"Get": "Execution.StartDateTime"
177+
},
178+
"RightValue": "2020-12-01"
179+
}, {
180+
"Type": "LessThanOrEqualTo",
181+
"LeftValue": {
182+
"Get": "Execution.StartDateTime"
183+
},
184+
"RightValue": {
185+
"Get": "Parameters.MyStr"
186+
}
187+
}, {
188+
"Type": "In",
189+
"QueryValue": {
190+
"Get": "Parameters.MyStr"
191+
},
192+
"Values": ["abc", "def"]
193+
}, {
194+
"Type": "In",
195+
"QueryValue": {
196+
"Get": "Parameters.MyStr"
197+
},
198+
"Values": ["abc", {
199+
"Get": "Steps.foo"
200+
}, {
201+
"Get": "Execution.StartDateTime"
202+
}]
203+
}, {
204+
"Type": "Not",
205+
"Expression": {
206+
"Type": "Equals",
207+
"LeftValue": {
208+
"Get": "Parameters.MyInt1"
209+
},
210+
"RightValue": {
211+
"Get": "Parameters.MyInt2"
212+
}
213+
}
214+
}, {
215+
"Type": "Not",
216+
"Expression": {
217+
"Type": "In",
218+
"QueryValue": {
219+
"Get": "Parameters.MyStr"
220+
},
221+
"Values": ["abc", "def"]
222+
}
223+
}, {
224+
"Type": "Or",
225+
"Conditions": [{
226+
"Type": "GreaterThan",
227+
"LeftValue": {
228+
"Get": "Execution.StartDateTime"
229+
},
230+
"RightValue": "2020-12-01"
231+
}, {
232+
"Type": "In",
233+
"QueryValue": {
234+
"Get": "Parameters.MyStr"
235+
},
236+
"Values": ["abc", "def"]
237+
}]
238+
}],
239+
"IfSteps": [{
240+
"Name": "MyStep1",
241+
"Type": "Training",
242+
"Arguments": {}
243+
}],
244+
"ElseSteps": [{
245+
"Name": "MyStep2",
246+
"Type": "Training",
247+
"Arguments": {}
248+
}]
249+
}
250+
}]
251+
}
252+
253+
82254
def test_pipeline(sagemaker_session):
83255
param = ParameterInteger(name="MyInt", default_value=2)
84256
cond = ConditionEquals(left=param, right=1)

tests/unit/sagemaker/workflow/test_conditions.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_condition_equals():
3636
cond = ConditionEquals(left=param, right=1)
3737
assert cond.to_request() == {
3838
"Type": "Equals",
39-
"LeftValue": {"Get": "Parameters.MyInt"},
39+
"LeftValue": param,
4040
"RightValue": 1,
4141
}
4242

@@ -47,8 +47,8 @@ def test_condition_equals_parameter():
4747
cond = ConditionEquals(left=param1, right=param2)
4848
assert cond.to_request() == {
4949
"Type": "Equals",
50-
"LeftValue": {"Get": "Parameters.MyInt1"},
51-
"RightValue": {"Get": "Parameters.MyInt2"},
50+
"LeftValue": param1,
51+
"RightValue": param2,
5252
}
5353

5454

@@ -57,7 +57,7 @@ def test_condition_greater_than():
5757
cond = ConditionGreaterThan(left=var, right="2020-12-01")
5858
assert cond.to_request() == {
5959
"Type": "GreaterThan",
60-
"LeftValue": {"Get": "Execution.StartDateTime"},
60+
"LeftValue": var,
6161
"RightValue": "2020-12-01",
6262
}
6363

@@ -68,8 +68,8 @@ def test_condition_greater_than_or_equal_to():
6868
cond = ConditionGreaterThanOrEqualTo(left=var, right=param)
6969
assert cond.to_request() == {
7070
"Type": "GreaterThanOrEqualTo",
71-
"LeftValue": {"Get": "Execution.StartDateTime"},
72-
"RightValue": {"Get": "Parameters.StartDateTime"},
71+
"LeftValue": var,
72+
"RightValue": param,
7373
}
7474

7575

@@ -78,7 +78,7 @@ def test_condition_less_than():
7878
cond = ConditionLessThan(left=var, right="2020-12-01")
7979
assert cond.to_request() == {
8080
"Type": "LessThan",
81-
"LeftValue": {"Get": "Execution.StartDateTime"},
81+
"LeftValue": var,
8282
"RightValue": "2020-12-01",
8383
}
8484

@@ -89,8 +89,8 @@ def test_condition_less_than_or_equal_to():
8989
cond = ConditionLessThanOrEqualTo(left=var, right=param)
9090
assert cond.to_request() == {
9191
"Type": "LessThanOrEqualTo",
92-
"LeftValue": {"Get": "Execution.StartDateTime"},
93-
"RightValue": {"Get": "Parameters.StartDateTime"},
92+
"LeftValue": var,
93+
"RightValue": param,
9494
}
9595

9696

@@ -99,7 +99,7 @@ def test_condition_in():
9999
cond_in = ConditionIn(value=param, in_values=["abc", "def"])
100100
assert cond_in.to_request() == {
101101
"Type": "In",
102-
"QueryValue": {"Get": "Parameters.MyStr"},
102+
"QueryValue": param,
103103
"Values": ["abc", "def"],
104104
}
105105

@@ -111,8 +111,8 @@ def test_condition_in_mixed():
111111
cond_in = ConditionIn(value=param, in_values=["abc", prop, var])
112112
assert cond_in.to_request() == {
113113
"Type": "In",
114-
"QueryValue": {"Get": "Parameters.MyStr"},
115-
"Values": ["abc", {"Get": "Steps.foo"}, {"Get": "Execution.StartDateTime"}],
114+
"QueryValue": param,
115+
"Values": ["abc", prop, var],
116116
}
117117

118118

@@ -124,7 +124,7 @@ def test_condition_not():
124124
"Type": "Not",
125125
"Expression": {
126126
"Type": "Equals",
127-
"LeftValue": {"Get": "Parameters.MyStr"},
127+
"LeftValue": param,
128128
"RightValue": "foo",
129129
},
130130
}
@@ -138,7 +138,7 @@ def test_condition_not_in():
138138
"Type": "Not",
139139
"Expression": {
140140
"Type": "In",
141-
"QueryValue": {"Get": "Parameters.MyStr"},
141+
"QueryValue": param,
142142
"Values": ["abc", "def"],
143143
},
144144
}
@@ -155,12 +155,12 @@ def test_condition_or():
155155
"Conditions": [
156156
{
157157
"Type": "GreaterThan",
158-
"LeftValue": {"Get": "Execution.StartDateTime"},
158+
"LeftValue": var,
159159
"RightValue": "2020-12-01",
160160
},
161161
{
162162
"Type": "In",
163-
"QueryValue": {"Get": "Parameters.MyStr"},
163+
"QueryValue": param,
164164
"Values": ["abc", "def"],
165165
},
166166
],

0 commit comments

Comments
 (0)