Skip to content

change: remove primitive_or_expr() from conditions #3212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 6 additions & 30 deletions src/sagemaker/workflow/condition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,17 @@

from typing import List, Union, Optional

import attr

from sagemaker.deprecations import deprecated_class
from sagemaker.workflow.conditions import Condition
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.functions import JsonGet as NewJsonGet
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rename NewJsonGet to JsonGetV2

from sagemaker.workflow.steps import (
Step,
StepTypeEnum,
)
from sagemaker.workflow.utilities import list_to_request
from sagemaker.workflow.entities import (
RequestType,
PipelineVariable,
)
from sagemaker.workflow.entities import RequestType
from sagemaker.workflow.properties import (
Properties,
PropertyFile,
Expand Down Expand Up @@ -93,16 +90,15 @@ def arguments(self) -> RequestType:
@property
def step_only_arguments(self):
"""Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`."""
return self.conditions
return [condition.to_request() for condition in self.conditions]

@property
def properties(self):
"""A simple Properties object with `Outcome` as the only property"""
return self._properties


@attr.s
class JsonGet(PipelineVariable): # pragma: no cover
class JsonGet(NewJsonGet): # pragma: no cover
"""Get JSON properties from PropertyFiles.
Attributes:
Expand All @@ -112,28 +108,8 @@ class JsonGet(PipelineVariable): # pragma: no cover
json_path (str): The JSON path expression to the requested value.
"""

step: Step = attr.ib()
property_file: Union[PropertyFile, str] = attr.ib()
json_path: str = attr.ib()

@property
def expr(self):
"""The expression dict for a `JsonGet` function."""
if isinstance(self.property_file, PropertyFile):
name = self.property_file.name
else:
name = self.property_file
return {
"Std:JsonGet": {
"PropertyFile": {"Get": f"Steps.{self.step.name}.PropertyFiles.{name}"},
"Path": self.json_path,
}
}

@property
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
return [self.step.name]
def __init__(self, step: Step, property_file: Union[PropertyFile, str], json_path: str):
super().__init__(step_name=step.name, property_file=property_file, json_path=json_path)


JsonGet = deprecated_class(JsonGet, "JsonGet")
28 changes: 5 additions & 23 deletions src/sagemaker/workflow/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@
import abc

from enum import Enum
from typing import Dict, List, Union
from typing import List, Union

import attr

from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import (
DefaultEnumMeta,
Entity,
Expression,
PrimitiveType,
RequestType,
)
Expand Down Expand Up @@ -88,8 +86,8 @@ def to_request(self) -> RequestType:
"""Get the request structure for workflow service calls."""
return {
"Type": self.condition_type.value,
"LeftValue": primitive_or_expr(self.left),
"RightValue": primitive_or_expr(self.right),
"LeftValue": self.left,
"RightValue": self.right,
}

@property
Expand Down Expand Up @@ -227,8 +225,8 @@ def to_request(self) -> RequestType:
"""Get the request structure for workflow service calls."""
return {
"Type": self.condition_type.value,
"QueryValue": self.value.expr,
"Values": [primitive_or_expr(in_value) for in_value in self.in_values],
"QueryValue": self.value,
"Values": self.in_values,
}

@property
Expand Down Expand Up @@ -291,19 +289,3 @@ def _referenced_steps(self) -> List[str]:
for condition in self.conditions:
steps.extend(condition._referenced_steps)
return steps


def primitive_or_expr(
value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties]
) -> Union[Dict[str, str], PrimitiveType]:
"""Provide the expression of the value or return value if it is a primitive.
Args:
value (Union[ConditionValueType, PrimitiveType]): The value to evaluate.
Returns:
Either the expression of the value or the primitive value.
"""
if is_pipeline_variable(value):
return value.expr
return value
159 changes: 156 additions & 3 deletions tests/unit/sagemaker/workflow/test_condition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,25 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import json

import pytest
from mock import Mock, MagicMock
from sagemaker.workflow.conditions import ConditionEquals
from sagemaker.workflow.parameters import ParameterInteger
from sagemaker.workflow.conditions import (
ConditionEquals,
ConditionGreaterThan,
ConditionGreaterThanOrEqualTo,
ConditionIn,
ConditionLessThan,
ConditionLessThanOrEqualTo,
ConditionNot,
ConditionOr,
)
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
from sagemaker.workflow.properties import Properties
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered


Expand Down Expand Up @@ -56,7 +68,7 @@ def test_condition_step():
"Conditions": [
{
"Type": "Equals",
"LeftValue": {"Get": "Parameters.MyInt"},
"LeftValue": param,
"RightValue": 1,
},
],
Expand All @@ -79,6 +91,147 @@ def test_condition_step():
assert cond_step.properties.Outcome.expr == {"Get": "Steps.MyConditionStep.Outcome"}


def test_pipeline_condition_step_interpolated(sagemaker_session):
param1 = ParameterInteger(name="MyInt1")
param2 = ParameterInteger(name="MyInt2")
param3 = ParameterString(name="MyStr")
var = ExecutionVariables.START_DATETIME
prop = Properties("foo")

cond_eq = ConditionEquals(left=param1, right=param2)
cond_gt = ConditionGreaterThan(left=var, right="2020-12-01")
cond_gte = ConditionGreaterThanOrEqualTo(left=var, right=param3)
cond_lt = ConditionLessThan(left=var, right="2020-12-01")
cond_lte = ConditionLessThanOrEqualTo(left=var, right=param3)
cond_in = ConditionIn(value=param3, in_values=["abc", "def"])
cond_in_mixed = ConditionIn(value=param3, in_values=["abc", prop, var])
cond_not_eq = ConditionNot(expression=cond_eq)
cond_not_in = ConditionNot(expression=cond_in)
cond_or = ConditionOr(conditions=[cond_gt, cond_in])

step1 = CustomStep(name="MyStep1")
step2 = CustomStep(name="MyStep2")
cond_step = ConditionStep(
name="MyConditionStep",
conditions=[
cond_eq,
cond_gt,
cond_gte,
cond_lt,
cond_lte,
cond_in,
cond_in_mixed,
cond_not_eq,
cond_not_in,
cond_or,
],
if_steps=[step1],
else_steps=[step2],
)

pipeline = Pipeline(
name="MyPipeline",
parameters=[param1, param2, param3],
steps=[cond_step],
sagemaker_session=sagemaker_session,
)
assert json.loads(pipeline.definition()) == {
"Version": "2020-12-01",
"Metadata": {},
"Parameters": [
{"Name": "MyInt1", "Type": "Integer"},
{"Name": "MyInt2", "Type": "Integer"},
{"Name": "MyStr", "Type": "String"},
],
"PipelineExperimentConfig": {
"ExperimentName": {"Get": "Execution.PipelineName"},
"TrialName": {"Get": "Execution.PipelineExecutionId"},
},
"Steps": [
{
"Name": "MyConditionStep",
"Type": "Condition",
"Arguments": {
"Conditions": [
{
"Type": "Equals",
"LeftValue": {"Get": "Parameters.MyInt1"},
"RightValue": {"Get": "Parameters.MyInt2"},
},
{
"Type": "GreaterThan",
"LeftValue": {"Get": "Execution.StartDateTime"},
"RightValue": "2020-12-01",
},
{
"Type": "GreaterThanOrEqualTo",
"LeftValue": {"Get": "Execution.StartDateTime"},
"RightValue": {"Get": "Parameters.MyStr"},
},
{
"Type": "LessThan",
"LeftValue": {"Get": "Execution.StartDateTime"},
"RightValue": "2020-12-01",
},
{
"Type": "LessThanOrEqualTo",
"LeftValue": {"Get": "Execution.StartDateTime"},
"RightValue": {"Get": "Parameters.MyStr"},
},
{
"Type": "In",
"QueryValue": {"Get": "Parameters.MyStr"},
"Values": ["abc", "def"],
},
{
"Type": "In",
"QueryValue": {"Get": "Parameters.MyStr"},
"Values": [
"abc",
{"Get": "Steps.foo"},
{"Get": "Execution.StartDateTime"},
],
},
{
"Type": "Not",
"Expression": {
"Type": "Equals",
"LeftValue": {"Get": "Parameters.MyInt1"},
"RightValue": {"Get": "Parameters.MyInt2"},
},
},
{
"Type": "Not",
"Expression": {
"Type": "In",
"QueryValue": {"Get": "Parameters.MyStr"},
"Values": ["abc", "def"],
},
},
{
"Type": "Or",
"Conditions": [
{
"Type": "GreaterThan",
"LeftValue": {"Get": "Execution.StartDateTime"},
"RightValue": "2020-12-01",
},
{
"Type": "In",
"QueryValue": {"Get": "Parameters.MyStr"},
"Values": ["abc", "def"],
},
],
},
],
"IfSteps": [{"Name": "MyStep1", "Type": "Training", "Arguments": {}}],
"ElseSteps": [{"Name": "MyStep2", "Type": "Training", "Arguments": {}}],
},
}
],
}


def test_pipeline(sagemaker_session):
param = ParameterInteger(name="MyInt", default_value=2)
cond = ConditionEquals(left=param, right=1)
Expand Down
Loading