Skip to content

Commit 7dd245e

Browse files
author
Ao Guo
committed
remove primitive_or_expr() from conditions
1 parent c70e30c commit 7dd245e

File tree

4 files changed

+182
-67
lines changed

4 files changed

+182
-67
lines changed

src/sagemaker/workflow/condition_step.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.deprecations import deprecated_class
2121
from sagemaker.workflow.conditions import Condition
2222
from sagemaker.workflow.step_collections import StepCollection
23+
from sagemaker.workflow.functions import JsonGet as NewJsonGet
2324
from sagemaker.workflow.steps import (
2425
Step,
2526
StepTypeEnum,
@@ -93,16 +94,15 @@ def arguments(self) -> RequestType:
9394
@property
9495
def step_only_arguments(self):
9596
"""Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`."""
96-
return self.conditions
97+
return [condition.to_request() for condition in self.conditions]
9798

9899
@property
99100
def properties(self):
100101
"""A simple Properties object with `Outcome` as the only property"""
101102
return self._properties
102103

103104

104-
@attr.s
105-
class JsonGet(PipelineVariable): # pragma: no cover
105+
class JsonGet(NewJsonGet): # pragma: no cover
106106
"""Get JSON properties from PropertyFiles.
107107
108108
Attributes:
@@ -112,28 +112,8 @@ class JsonGet(PipelineVariable): # pragma: no cover
112112
json_path (str): The JSON path expression to the requested value.
113113
"""
114114

115-
step: Step = attr.ib()
116-
property_file: Union[PropertyFile, str] = attr.ib()
117-
json_path: str = attr.ib()
118-
119-
@property
120-
def expr(self):
121-
"""The expression dict for a `JsonGet` function."""
122-
if isinstance(self.property_file, PropertyFile):
123-
name = self.property_file.name
124-
else:
125-
name = self.property_file
126-
return {
127-
"Std:JsonGet": {
128-
"PropertyFile": {"Get": f"Steps.{self.step.name}.PropertyFiles.{name}"},
129-
"Path": self.json_path,
130-
}
131-
}
132-
133-
@property
134-
def _referenced_steps(self) -> List[str]:
135-
"""List of step names that this function depends on."""
136-
return [self.step.name]
115+
def __init__(self, step: Step, property_file: Union[PropertyFile, str], json_path: str):
116+
super().__init__(step_name=step.name, property_file=property_file, json_path=json_path)
137117

138118

139119
JsonGet = deprecated_class(JsonGet, "JsonGet")

src/sagemaker/workflow/conditions.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
import abc
2121

2222
from enum import Enum
23-
from typing import Dict, List, Union
23+
from typing import List, Union
2424

2525
import attr
2626

27-
from sagemaker.workflow import is_pipeline_variable
2827
from sagemaker.workflow.entities import (
2928
DefaultEnumMeta,
3029
Entity,
31-
Expression,
3230
PrimitiveType,
3331
RequestType,
3432
)
@@ -88,8 +86,8 @@ def to_request(self) -> RequestType:
8886
"""Get the request structure for workflow service calls."""
8987
return {
9088
"Type": self.condition_type.value,
91-
"LeftValue": primitive_or_expr(self.left),
92-
"RightValue": primitive_or_expr(self.right),
89+
"LeftValue": self.left,
90+
"RightValue": self.right,
9391
}
9492

9593
@property
@@ -227,8 +225,8 @@ def to_request(self) -> RequestType:
227225
"""Get the request structure for workflow service calls."""
228226
return {
229227
"Type": self.condition_type.value,
230-
"QueryValue": self.value.expr,
231-
"Values": [primitive_or_expr(in_value) for in_value in self.in_values],
228+
"QueryValue": self.value,
229+
"Values": self.in_values,
232230
}
233231

234232
@property
@@ -291,19 +289,3 @@ def _referenced_steps(self) -> List[str]:
291289
for condition in self.conditions:
292290
steps.extend(condition._referenced_steps)
293291
return steps
294-
295-
296-
def primitive_or_expr(
297-
value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties]
298-
) -> Union[Dict[str, str], PrimitiveType]:
299-
"""Provide the expression of the value or return value if it is a primitive.
300-
301-
Args:
302-
value (Union[ConditionValueType, PrimitiveType]): The value to evaluate.
303-
304-
Returns:
305-
Either the expression of the value or the primitive value.
306-
"""
307-
if is_pipeline_variable(value):
308-
return value.expr
309-
return value

tests/unit/sagemaker/workflow/test_condition_step.py

Lines changed: 156 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,25 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import json
1415

1516
import pytest
1617
from mock import Mock, MagicMock
17-
from sagemaker.workflow.conditions import ConditionEquals
18-
from sagemaker.workflow.parameters import ParameterInteger
18+
from sagemaker.workflow.conditions import (
19+
ConditionEquals,
20+
ConditionGreaterThan,
21+
ConditionGreaterThanOrEqualTo,
22+
ConditionIn,
23+
ConditionLessThan,
24+
ConditionLessThanOrEqualTo,
25+
ConditionNot,
26+
ConditionOr,
27+
)
28+
from sagemaker.workflow.execution_variables import ExecutionVariables
29+
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
1930
from sagemaker.workflow.condition_step import ConditionStep
2031
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
32+
from sagemaker.workflow.properties import Properties
2133
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered
2234

2335

@@ -56,7 +68,7 @@ def test_condition_step():
5668
"Conditions": [
5769
{
5870
"Type": "Equals",
59-
"LeftValue": {"Get": "Parameters.MyInt"},
71+
"LeftValue": param,
6072
"RightValue": 1,
6173
},
6274
],
@@ -79,6 +91,147 @@ def test_condition_step():
7991
assert cond_step.properties.Outcome.expr == {"Get": "Steps.MyConditionStep.Outcome"}
8092

8193

94+
def test_pipeline_condition_step_interpolated(sagemaker_session):
95+
param1 = ParameterInteger(name="MyInt1")
96+
param2 = ParameterInteger(name="MyInt2")
97+
param3 = ParameterString(name="MyStr")
98+
var = ExecutionVariables.START_DATETIME
99+
prop = Properties("foo")
100+
101+
cond_eq = ConditionEquals(left=param1, right=param2)
102+
cond_gt = ConditionGreaterThan(left=var, right="2020-12-01")
103+
cond_gte = ConditionGreaterThanOrEqualTo(left=var, right=param3)
104+
cond_lt = ConditionLessThan(left=var, right="2020-12-01")
105+
cond_lte = ConditionLessThanOrEqualTo(left=var, right=param3)
106+
cond_in = ConditionIn(value=param3, in_values=["abc", "def"])
107+
cond_in_mixed = ConditionIn(value=param3, in_values=["abc", prop, var])
108+
cond_not_eq = ConditionNot(expression=cond_eq)
109+
cond_not_in = ConditionNot(expression=cond_in)
110+
cond_or = ConditionOr(conditions=[cond_gt, cond_in])
111+
112+
step1 = CustomStep(name="MyStep1")
113+
step2 = CustomStep(name="MyStep2")
114+
cond_step = ConditionStep(
115+
name="MyConditionStep",
116+
conditions=[
117+
cond_eq,
118+
cond_gt,
119+
cond_gte,
120+
cond_lt,
121+
cond_lte,
122+
cond_in,
123+
cond_in_mixed,
124+
cond_not_eq,
125+
cond_not_in,
126+
cond_or,
127+
],
128+
if_steps=[step1],
129+
else_steps=[step2],
130+
)
131+
132+
pipeline = Pipeline(
133+
name="MyPipeline",
134+
parameters=[param1, param2, param3],
135+
steps=[cond_step],
136+
sagemaker_session=sagemaker_session,
137+
)
138+
assert json.loads(pipeline.definition()) == {
139+
"Version": "2020-12-01",
140+
"Metadata": {},
141+
"Parameters": [
142+
{"Name": "MyInt1", "Type": "Integer"},
143+
{"Name": "MyInt2", "Type": "Integer"},
144+
{"Name": "MyStr", "Type": "String"},
145+
],
146+
"PipelineExperimentConfig": {
147+
"ExperimentName": {"Get": "Execution.PipelineName"},
148+
"TrialName": {"Get": "Execution.PipelineExecutionId"},
149+
},
150+
"Steps": [
151+
{
152+
"Name": "MyConditionStep",
153+
"Type": "Condition",
154+
"Arguments": {
155+
"Conditions": [
156+
{
157+
"Type": "Equals",
158+
"LeftValue": {"Get": "Parameters.MyInt1"},
159+
"RightValue": {"Get": "Parameters.MyInt2"},
160+
},
161+
{
162+
"Type": "GreaterThan",
163+
"LeftValue": {"Get": "Execution.StartDateTime"},
164+
"RightValue": "2020-12-01",
165+
},
166+
{
167+
"Type": "GreaterThanOrEqualTo",
168+
"LeftValue": {"Get": "Execution.StartDateTime"},
169+
"RightValue": {"Get": "Parameters.MyStr"},
170+
},
171+
{
172+
"Type": "LessThan",
173+
"LeftValue": {"Get": "Execution.StartDateTime"},
174+
"RightValue": "2020-12-01",
175+
},
176+
{
177+
"Type": "LessThanOrEqualTo",
178+
"LeftValue": {"Get": "Execution.StartDateTime"},
179+
"RightValue": {"Get": "Parameters.MyStr"},
180+
},
181+
{
182+
"Type": "In",
183+
"QueryValue": {"Get": "Parameters.MyStr"},
184+
"Values": ["abc", "def"],
185+
},
186+
{
187+
"Type": "In",
188+
"QueryValue": {"Get": "Parameters.MyStr"},
189+
"Values": [
190+
"abc",
191+
{"Get": "Steps.foo"},
192+
{"Get": "Execution.StartDateTime"},
193+
],
194+
},
195+
{
196+
"Type": "Not",
197+
"Expression": {
198+
"Type": "Equals",
199+
"LeftValue": {"Get": "Parameters.MyInt1"},
200+
"RightValue": {"Get": "Parameters.MyInt2"},
201+
},
202+
},
203+
{
204+
"Type": "Not",
205+
"Expression": {
206+
"Type": "In",
207+
"QueryValue": {"Get": "Parameters.MyStr"},
208+
"Values": ["abc", "def"],
209+
},
210+
},
211+
{
212+
"Type": "Or",
213+
"Conditions": [
214+
{
215+
"Type": "GreaterThan",
216+
"LeftValue": {"Get": "Execution.StartDateTime"},
217+
"RightValue": "2020-12-01",
218+
},
219+
{
220+
"Type": "In",
221+
"QueryValue": {"Get": "Parameters.MyStr"},
222+
"Values": ["abc", "def"],
223+
},
224+
],
225+
},
226+
],
227+
"IfSteps": [{"Name": "MyStep1", "Type": "Training", "Arguments": {}}],
228+
"ElseSteps": [{"Name": "MyStep2", "Type": "Training", "Arguments": {}}],
229+
},
230+
}
231+
],
232+
}
233+
234+
82235
def test_pipeline(sagemaker_session):
83236
param = ParameterInteger(name="MyInt", default_value=2)
84237
cond = ConditionEquals(left=param, right=1)

0 commit comments

Comments
 (0)