Skip to content

Commit 405567e

Browse files
author
Ao Guo
committed
remove primitive_or_expr() from conditions
1 parent c70e30c commit 405567e

File tree

4 files changed

+178
-43
lines changed

4 files changed

+178
-43
lines changed

src/sagemaker/workflow/condition_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def arguments(self) -> RequestType:
9393
@property
9494
def step_only_arguments(self):
9595
"""Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`."""
96-
return self.conditions
96+
return [condition.to_request() for condition in self.conditions]
9797

9898
@property
9999
def properties(self):

src/sagemaker/workflow/conditions.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
import abc
2121

2222
from enum import Enum
23-
from typing import Dict, List, Union
23+
from typing import List, Union
2424

2525
import attr
2626

27-
from sagemaker.workflow import is_pipeline_variable
2827
from sagemaker.workflow.entities import (
2928
DefaultEnumMeta,
3029
Entity,
31-
Expression,
3230
PrimitiveType,
3331
RequestType,
3432
)
@@ -88,8 +86,8 @@ def to_request(self) -> RequestType:
8886
"""Get the request structure for workflow service calls."""
8987
return {
9088
"Type": self.condition_type.value,
91-
"LeftValue": primitive_or_expr(self.left),
92-
"RightValue": primitive_or_expr(self.right),
89+
"LeftValue": self.left,
90+
"RightValue": self.right,
9391
}
9492

9593
@property
@@ -227,8 +225,8 @@ def to_request(self) -> RequestType:
227225
"""Get the request structure for workflow service calls."""
228226
return {
229227
"Type": self.condition_type.value,
230-
"QueryValue": self.value.expr,
231-
"Values": [primitive_or_expr(in_value) for in_value in self.in_values],
228+
"QueryValue": self.value,
229+
"Values": self.in_values,
232230
}
233231

234232
@property
@@ -291,19 +289,3 @@ def _referenced_steps(self) -> List[str]:
291289
for condition in self.conditions:
292290
steps.extend(condition._referenced_steps)
293291
return steps
294-
295-
296-
def primitive_or_expr(
297-
value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties]
298-
) -> Union[Dict[str, str], PrimitiveType]:
299-
"""Provide the expression of the value or return value if it is a primitive.
300-
301-
Args:
302-
value (Union[ConditionValueType, PrimitiveType]): The value to evaluate.
303-
304-
Returns:
305-
Either the expression of the value or the primitive value.
306-
"""
307-
if is_pipeline_variable(value):
308-
return value.expr
309-
return value

tests/unit/sagemaker/workflow/test_condition_step.py

Lines changed: 156 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,25 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import json
1415

1516
import pytest
1617
from mock import Mock, MagicMock
17-
from sagemaker.workflow.conditions import ConditionEquals
18-
from sagemaker.workflow.parameters import ParameterInteger
18+
from sagemaker.workflow.conditions import (
19+
ConditionEquals,
20+
ConditionGreaterThan,
21+
ConditionGreaterThanOrEqualTo,
22+
ConditionIn,
23+
ConditionLessThan,
24+
ConditionLessThanOrEqualTo,
25+
ConditionNot,
26+
ConditionOr,
27+
)
28+
from sagemaker.workflow.execution_variables import ExecutionVariables
29+
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
1930
from sagemaker.workflow.condition_step import ConditionStep
2031
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
32+
from sagemaker.workflow.properties import Properties
2133
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered
2234

2335

@@ -56,7 +68,7 @@ def test_condition_step():
5668
"Conditions": [
5769
{
5870
"Type": "Equals",
59-
"LeftValue": {"Get": "Parameters.MyInt"},
71+
"LeftValue": param,
6072
"RightValue": 1,
6173
},
6274
],
@@ -79,6 +91,147 @@ def test_condition_step():
7991
assert cond_step.properties.Outcome.expr == {"Get": "Steps.MyConditionStep.Outcome"}
8092

8193

94+
def test_pipeline_condition_step_interpolated(sagemaker_session):
95+
param1 = ParameterInteger(name="MyInt1")
96+
param2 = ParameterInteger(name="MyInt2")
97+
param3 = ParameterString(name="MyStr")
98+
var = ExecutionVariables.START_DATETIME
99+
prop = Properties("foo")
100+
101+
cond_eq = ConditionEquals(left=param1, right=param2)
102+
cond_gt = ConditionGreaterThan(left=var, right="2020-12-01")
103+
cond_gte = ConditionGreaterThanOrEqualTo(left=var, right=param3)
104+
cond_lt = ConditionLessThan(left=var, right="2020-12-01")
105+
cond_lte = ConditionLessThanOrEqualTo(left=var, right=param3)
106+
cond_in = ConditionIn(value=param3, in_values=["abc", "def"])
107+
cond_in_mixed = ConditionIn(value=param3, in_values=["abc", prop, var])
108+
cond_not_eq = ConditionNot(expression=cond_eq)
109+
cond_not_in = ConditionNot(expression=cond_in)
110+
cond_or = ConditionOr(conditions=[cond_gt, cond_in])
111+
112+
step1 = CustomStep(name="MyStep1")
113+
step2 = CustomStep(name="MyStep2")
114+
cond_step = ConditionStep(
115+
name="MyConditionStep",
116+
conditions=[
117+
cond_eq,
118+
cond_gt,
119+
cond_gte,
120+
cond_lt,
121+
cond_lte,
122+
cond_in,
123+
cond_in_mixed,
124+
cond_not_eq,
125+
cond_not_in,
126+
cond_or,
127+
],
128+
if_steps=[step1],
129+
else_steps=[step2],
130+
)
131+
132+
pipeline = Pipeline(
133+
name="MyPipeline",
134+
parameters=[param1, param2, param3],
135+
steps=[cond_step],
136+
sagemaker_session=sagemaker_session,
137+
)
138+
assert json.loads(pipeline.definition()) == {
139+
"Version": "2020-12-01",
140+
"Metadata": {},
141+
"Parameters": [
142+
{"Name": "MyInt1", "Type": "Integer"},
143+
{"Name": "MyInt2", "Type": "Integer"},
144+
{"Name": "MyStr", "Type": "String"},
145+
],
146+
"PipelineExperimentConfig": {
147+
"ExperimentName": {"Get": "Execution.PipelineName"},
148+
"TrialName": {"Get": "Execution.PipelineExecutionId"},
149+
},
150+
"Steps": [
151+
{
152+
"Name": "MyConditionStep",
153+
"Type": "Condition",
154+
"Arguments": {
155+
"Conditions": [
156+
{
157+
"Type": "Equals",
158+
"LeftValue": {"Get": "Parameters.MyInt1"},
159+
"RightValue": {"Get": "Parameters.MyInt2"},
160+
},
161+
{
162+
"Type": "GreaterThan",
163+
"LeftValue": {"Get": "Execution.StartDateTime"},
164+
"RightValue": "2020-12-01",
165+
},
166+
{
167+
"Type": "GreaterThanOrEqualTo",
168+
"LeftValue": {"Get": "Execution.StartDateTime"},
169+
"RightValue": {"Get": "Parameters.MyStr"},
170+
},
171+
{
172+
"Type": "LessThan",
173+
"LeftValue": {"Get": "Execution.StartDateTime"},
174+
"RightValue": "2020-12-01",
175+
},
176+
{
177+
"Type": "LessThanOrEqualTo",
178+
"LeftValue": {"Get": "Execution.StartDateTime"},
179+
"RightValue": {"Get": "Parameters.MyStr"},
180+
},
181+
{
182+
"Type": "In",
183+
"QueryValue": {"Get": "Parameters.MyStr"},
184+
"Values": ["abc", "def"],
185+
},
186+
{
187+
"Type": "In",
188+
"QueryValue": {"Get": "Parameters.MyStr"},
189+
"Values": [
190+
"abc",
191+
{"Get": "Steps.foo"},
192+
{"Get": "Execution.StartDateTime"},
193+
],
194+
},
195+
{
196+
"Type": "Not",
197+
"Expression": {
198+
"Type": "Equals",
199+
"LeftValue": {"Get": "Parameters.MyInt1"},
200+
"RightValue": {"Get": "Parameters.MyInt2"},
201+
},
202+
},
203+
{
204+
"Type": "Not",
205+
"Expression": {
206+
"Type": "In",
207+
"QueryValue": {"Get": "Parameters.MyStr"},
208+
"Values": ["abc", "def"],
209+
},
210+
},
211+
{
212+
"Type": "Or",
213+
"Conditions": [
214+
{
215+
"Type": "GreaterThan",
216+
"LeftValue": {"Get": "Execution.StartDateTime"},
217+
"RightValue": "2020-12-01",
218+
},
219+
{
220+
"Type": "In",
221+
"QueryValue": {"Get": "Parameters.MyStr"},
222+
"Values": ["abc", "def"],
223+
},
224+
],
225+
},
226+
],
227+
"IfSteps": [{"Name": "MyStep1", "Type": "Training", "Arguments": {}}],
228+
"ElseSteps": [{"Name": "MyStep2", "Type": "Training", "Arguments": {}}],
229+
},
230+
}
231+
],
232+
}
233+
234+
82235
def test_pipeline(sagemaker_session):
83236
param = ParameterInteger(name="MyInt", default_value=2)
84237
cond = ConditionEquals(left=param, right=1)

tests/unit/sagemaker/workflow/test_conditions.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_condition_equals():
3636
cond = ConditionEquals(left=param, right=1)
3737
assert cond.to_request() == {
3838
"Type": "Equals",
39-
"LeftValue": {"Get": "Parameters.MyInt"},
39+
"LeftValue": param,
4040
"RightValue": 1,
4141
}
4242

@@ -47,8 +47,8 @@ def test_condition_equals_parameter():
4747
cond = ConditionEquals(left=param1, right=param2)
4848
assert cond.to_request() == {
4949
"Type": "Equals",
50-
"LeftValue": {"Get": "Parameters.MyInt1"},
51-
"RightValue": {"Get": "Parameters.MyInt2"},
50+
"LeftValue": param1,
51+
"RightValue": param2,
5252
}
5353

5454

@@ -57,7 +57,7 @@ def test_condition_greater_than():
5757
cond = ConditionGreaterThan(left=var, right="2020-12-01")
5858
assert cond.to_request() == {
5959
"Type": "GreaterThan",
60-
"LeftValue": {"Get": "Execution.StartDateTime"},
60+
"LeftValue": var,
6161
"RightValue": "2020-12-01",
6262
}
6363

@@ -68,8 +68,8 @@ def test_condition_greater_than_or_equal_to():
6868
cond = ConditionGreaterThanOrEqualTo(left=var, right=param)
6969
assert cond.to_request() == {
7070
"Type": "GreaterThanOrEqualTo",
71-
"LeftValue": {"Get": "Execution.StartDateTime"},
72-
"RightValue": {"Get": "Parameters.StartDateTime"},
71+
"LeftValue": var,
72+
"RightValue": param,
7373
}
7474

7575

@@ -78,7 +78,7 @@ def test_condition_less_than():
7878
cond = ConditionLessThan(left=var, right="2020-12-01")
7979
assert cond.to_request() == {
8080
"Type": "LessThan",
81-
"LeftValue": {"Get": "Execution.StartDateTime"},
81+
"LeftValue": var,
8282
"RightValue": "2020-12-01",
8383
}
8484

@@ -89,8 +89,8 @@ def test_condition_less_than_or_equal_to():
8989
cond = ConditionLessThanOrEqualTo(left=var, right=param)
9090
assert cond.to_request() == {
9191
"Type": "LessThanOrEqualTo",
92-
"LeftValue": {"Get": "Execution.StartDateTime"},
93-
"RightValue": {"Get": "Parameters.StartDateTime"},
92+
"LeftValue": var,
93+
"RightValue": param,
9494
}
9595

9696

@@ -99,7 +99,7 @@ def test_condition_in():
9999
cond_in = ConditionIn(value=param, in_values=["abc", "def"])
100100
assert cond_in.to_request() == {
101101
"Type": "In",
102-
"QueryValue": {"Get": "Parameters.MyStr"},
102+
"QueryValue": param,
103103
"Values": ["abc", "def"],
104104
}
105105

@@ -111,8 +111,8 @@ def test_condition_in_mixed():
111111
cond_in = ConditionIn(value=param, in_values=["abc", prop, var])
112112
assert cond_in.to_request() == {
113113
"Type": "In",
114-
"QueryValue": {"Get": "Parameters.MyStr"},
115-
"Values": ["abc", {"Get": "Steps.foo"}, {"Get": "Execution.StartDateTime"}],
114+
"QueryValue": param,
115+
"Values": ["abc", prop, var],
116116
}
117117

118118

@@ -124,7 +124,7 @@ def test_condition_not():
124124
"Type": "Not",
125125
"Expression": {
126126
"Type": "Equals",
127-
"LeftValue": {"Get": "Parameters.MyStr"},
127+
"LeftValue": param,
128128
"RightValue": "foo",
129129
},
130130
}
@@ -138,7 +138,7 @@ def test_condition_not_in():
138138
"Type": "Not",
139139
"Expression": {
140140
"Type": "In",
141-
"QueryValue": {"Get": "Parameters.MyStr"},
141+
"QueryValue": param,
142142
"Values": ["abc", "def"],
143143
},
144144
}
@@ -155,12 +155,12 @@ def test_condition_or():
155155
"Conditions": [
156156
{
157157
"Type": "GreaterThan",
158-
"LeftValue": {"Get": "Execution.StartDateTime"},
158+
"LeftValue": var,
159159
"RightValue": "2020-12-01",
160160
},
161161
{
162162
"Type": "In",
163-
"QueryValue": {"Get": "Parameters.MyStr"},
163+
"QueryValue": param,
164164
"Values": ["abc", "def"],
165165
},
166166
],

0 commit comments

Comments
 (0)