Skip to content

Commit 7ae479f

Browse files
author
Dewen Qi
committed
change: Implement overide solution for pipeline varialbes
1 parent 7e2c7ab commit 7ae479f

16 files changed

+635
-66
lines changed

src/sagemaker/parameter.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
from __future__ import absolute_import
1515

1616
import json
17-
from sagemaker.workflow.parameters import Parameter as PipelineParameter
18-
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
19-
from sagemaker.workflow.functions import Join as PipelineJoin
17+
18+
from sagemaker.workflow.entities import PipelineVariable
2019

2120

2221
class ParameterRange(object):
@@ -73,11 +72,11 @@ def as_tuning_range(self, name):
7372
return {
7473
"Name": name,
7574
"MinValue": str(self.min_value)
76-
if not isinstance(self.min_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
77-
else self.min_value,
75+
if not isinstance(self.min_value, PipelineVariable)
76+
else self.min_value.to_string(),
7877
"MaxValue": str(self.max_value)
79-
if not isinstance(self.max_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
80-
else self.max_value,
78+
if not isinstance(self.max_value, PipelineVariable)
79+
else self.max_value.to_string(),
8180
"ScalingType": self.scaling_type,
8281
}
8382

@@ -112,8 +111,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called
112111
"""
113112
values = values if isinstance(values, list) else [values]
114113
self.values = [
115-
str(v) if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin)) else v
116-
for v in values
114+
str(v) if not isinstance(v, PipelineVariable) else v.to_string() for v in values
117115
]
118116

119117
def as_tuning_range(self, name):

src/sagemaker/tuner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
IntegerParameter,
3939
ParameterRange,
4040
)
41+
from sagemaker.workflow.entities import PipelineVariable
4142
from sagemaker.workflow.parameters import Parameter as PipelineParameter
4243
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
4344
from sagemaker.workflow.functions import Join as PipelineJoin
@@ -376,9 +377,7 @@ def _prepare_static_hyperparameters(
376377
"""Prepare static hyperparameters for one estimator before tuning."""
377378
# Remove any hyperparameter that will be tuned
378379
static_hyperparameters = {
379-
str(k): str(v)
380-
if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin))
381-
else v
380+
str(k): str(v) if not isinstance(v, PipelineVariable) else v.to_string()
382381
for (k, v) in estimator.hyperparameters().items()
383382
}
384383
for hyperparameter_name in hyperparameter_ranges.keys():

src/sagemaker/workflow/entities.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import abc
1717

1818
from enum import EnumMeta
19-
from typing import Any, Dict, List, Union
19+
from typing import Any, Dict, List, Union, Optional
2020

2121
PrimitiveType = Union[str, int, bool, float, None]
2222
RequestType = Union[Dict[str, Any], List[Dict[str, Any]]]
@@ -57,3 +57,62 @@ class Expression(abc.ABC):
5757
@abc.abstractmethod
5858
def expr(self) -> RequestType:
5959
"""Get the expression structure for workflow service calls."""
60+
61+
62+
class PipelineVariable(Expression):
63+
"""Base object for pipeline variables
64+
65+
PipelineVariables must implement the expr property.
66+
"""
67+
68+
def __add__(self, other: Union[Expression, PrimitiveType]):
69+
"""Add function for PipelineVariable
70+
71+
Args:
72+
other (Union[Expression, PrimitiveType]): The other object to be concatenated.
73+
74+
Always raise an error since pipeline variables do not support concatenation
75+
"""
76+
77+
raise TypeError("Pipeline variables do not support concatenation.")
78+
79+
def __str__(self):
80+
"""Override built-in String function for PipelineVariable"""
81+
raise TypeError("Pipeline variables do not support __str__ operation.")
82+
83+
def __int__(self):
84+
"""Override built-in Integer function for PipelineVariable"""
85+
raise TypeError("Pipeline variables do not support __int__ operation.")
86+
87+
def __float__(self):
88+
"""Override built-in Float function for PipelineVariable"""
89+
raise TypeError("Pipeline variables do not support __float__ operation.")
90+
91+
def to_string(self):
92+
"""Prompt the pipeline to convert the pipeline variable to String in runtime"""
93+
from sagemaker.workflow.functions import Join
94+
95+
return Join(on="", values=[self])
96+
97+
@property
98+
@abc.abstractmethod
99+
def expr(self) -> RequestType:
100+
"""Get the expression structure for workflow service calls."""
101+
102+
def startswith(
103+
self,
104+
prefix: Union[str, tuple], # pylint: disable=unused-argument
105+
start: Optional[int] = None, # pylint: disable=unused-argument
106+
end: Optional[int] = None, # pylint: disable=unused-argument
107+
) -> bool:
108+
"""Simulate the Python string's built-in method: startswith
109+
110+
Args:
111+
prefix (str, tuple): The (tuple of) string to be checked.
112+
start (int): To set the start index of the matching boundary (default: None).
113+
end (int): To set the end index of the matching boundary (default: None).
114+
115+
Return:
116+
bool: Always return False as Pipeline variables are parsed during execution runtime
117+
"""
118+
return False

src/sagemaker/workflow/execution_variables.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.workflow.entities import (
17-
Expression,
1817
RequestType,
18+
PipelineVariable,
1919
)
2020

2121

22-
class ExecutionVariable(Expression):
22+
class ExecutionVariable(PipelineVariable):
2323
"""Pipeline execution variables for workflow."""
2424

2525
def __init__(self, name: str):
@@ -30,6 +30,13 @@ def __init__(self, name: str):
3030
"""
3131
self.name = name
3232

33+
def to_string(self) -> PipelineVariable:
34+
"""Prompt the pipeline to convert the pipeline variable to String in runtime
35+
36+
As ExecutionVariable is treated as String in runtime, no extra actions are needed.
37+
"""
38+
return self
39+
3340
@property
3441
def expr(self) -> RequestType:
3542
"""The 'Get' expression dict for an `ExecutionVariable`."""

src/sagemaker/workflow/functions.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
import attr
1919

20-
from sagemaker.workflow.entities import Expression
20+
from sagemaker.workflow.entities import PipelineVariable
2121
from sagemaker.workflow.properties import PropertyFile
2222

2323

2424
@attr.s
25-
class Join(Expression):
25+
class Join(PipelineVariable):
2626
"""Join together properties.
2727
2828
Examples:
@@ -38,15 +38,23 @@ class Join(Expression):
3838
Attributes:
3939
values (List[Union[PrimitiveType, Parameter, Expression]]):
4040
The primitive type values, parameters, step properties, expressions to join.
41-
on_str (str): The string to join the values on (Defaults to "").
41+
on (str): The string to join the values on (Defaults to "").
4242
"""
4343

4444
on: str = attr.ib(factory=str)
4545
values: List = attr.ib(factory=list)
4646

47+
def to_string(self) -> PipelineVariable:
48+
"""Prompt the pipeline to convert the pipeline variable to String in runtime
49+
50+
As Join is treated as String in runtime, no extra actions are needed.
51+
"""
52+
return self
53+
4754
@property
4855
def expr(self):
4956
"""The expression dict for a `Join` function."""
57+
5058
return {
5159
"Std:Join": {
5260
"On": self.on,
@@ -58,7 +66,7 @@ def expr(self):
5866

5967

6068
@attr.s
61-
class JsonGet(Expression):
69+
class JsonGet(PipelineVariable):
6270
"""Get JSON properties from PropertyFiles.
6371
6472
Attributes:

src/sagemaker/workflow/parameters.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Entity,
2525
PrimitiveType,
2626
RequestType,
27+
PipelineVariable,
2728
)
2829

2930

@@ -48,7 +49,7 @@ def python_type(self) -> Type:
4849

4950

5051
@attr.s
51-
class Parameter(Entity):
52+
class Parameter(PipelineVariable, Entity):
5253
"""Pipeline parameter for workflow.
5354
5455
Attributes:
@@ -170,6 +171,13 @@ def __hash__(self):
170171
"""Hash function for parameter types"""
171172
return hash(tuple(self.to_request()))
172173

174+
def to_string(self) -> PipelineVariable:
175+
"""Prompt the pipeline to convert the pipeline variable to String in runtime
176+
177+
As ParameterString is treated as String in runtime, no extra actions are needed.
178+
"""
179+
return self
180+
173181
def to_request(self) -> RequestType:
174182
"""Get the request structure for workflow service calls."""
175183
request_dict = super(ParameterString, self).to_request()

src/sagemaker/workflow/properties.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
"""The properties definitions for workflow."""
1414
from __future__ import absolute_import
1515

16+
from abc import ABCMeta
1617
from typing import Dict, Union, List
1718

1819
import attr
1920

2021
import botocore.loaders
2122

22-
from sagemaker.workflow.entities import Expression
23+
from sagemaker.workflow.entities import Expression, PipelineVariable
2324

2425

25-
class PropertiesMeta(type):
26+
class PropertiesMeta(ABCMeta):
2627
"""Load an internal shapes attribute from the botocore service model
2728
2829
for sagemaker and emr service.
@@ -44,7 +45,7 @@ def __new__(mcs, *args, **kwargs):
4445
return super().__new__(mcs, *args, **kwargs)
4546

4647

47-
class Properties(metaclass=PropertiesMeta):
48+
class Properties(PipelineVariable, metaclass=PropertiesMeta):
4849
"""Properties for use in workflow expressions."""
4950

5051
def __init__(
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from botocore.exceptions import WaiterError
17+
18+
from sagemaker import get_execution_role, utils
19+
from sagemaker.workflow.condition_step import ConditionStep
20+
from sagemaker.workflow.conditions import ConditionGreaterThan
21+
from sagemaker.workflow.fail_step import FailStep
22+
from sagemaker.workflow.functions import Join
23+
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
24+
from sagemaker.workflow.pipeline import Pipeline
25+
26+
27+
@pytest.fixture
28+
def role(sagemaker_session):
29+
return get_execution_role(sagemaker_session)
30+
31+
32+
@pytest.fixture
33+
def pipeline_name():
34+
return utils.unique_name_from_base("my-pipeline-vars")
35+
36+
37+
def test_ppl_var_to_string_and_add(sagemaker_session, role, pipeline_name):
38+
param_str = ParameterString(name="MyString", default_value="1")
39+
param_int = ParameterInteger(name="MyInteger", default_value=3)
40+
41+
cond = ConditionGreaterThan(left=param_str, right=param_int.to_string())
42+
step_cond = ConditionStep(
43+
name="CondStep",
44+
conditions=[cond],
45+
if_steps=[],
46+
else_steps=[],
47+
)
48+
join_fn1 = Join(
49+
on=" ",
50+
values=[
51+
"condition greater than check return:",
52+
step_cond.properties.Outcome.to_string(),
53+
"and left side param str is",
54+
param_str,
55+
"and right side param int is",
56+
param_int,
57+
],
58+
)
59+
60+
step_fail = FailStep(
61+
name="FailStep",
62+
error_message=join_fn1,
63+
)
64+
pipeline = Pipeline(
65+
name=pipeline_name,
66+
parameters=[param_str, param_int],
67+
steps=[step_cond, step_fail],
68+
sagemaker_session=sagemaker_session,
69+
)
70+
71+
try:
72+
response = pipeline.create(role)
73+
pipeline_arn = response["PipelineArn"]
74+
execution = pipeline.start()
75+
response = execution.describe()
76+
assert response["PipelineArn"] == pipeline_arn
77+
78+
try:
79+
execution.wait(delay=30, max_attempts=60)
80+
except WaiterError:
81+
pass
82+
execution_steps = execution.list_steps()
83+
84+
assert len(execution_steps) == 2
85+
for execution_step in execution_steps:
86+
if execution_step["StepName"] == "CondStep":
87+
assert execution_step["StepStatus"] == "Succeeded"
88+
continue
89+
assert execution_step["StepName"] == "FailStep"
90+
assert execution_step["StepStatus"] == "Failed"
91+
assert (
92+
execution_step["FailureReason"] == "condition greater than check return: false "
93+
"and left side param str is 1 and right side param int is 3"
94+
)
95+
96+
# Update int param to update cond step outcome
97+
execution = pipeline.start(parameters={"MyInteger": 0})
98+
try:
99+
execution.wait(delay=30, max_attempts=60)
100+
except WaiterError:
101+
pass
102+
execution_steps = execution.list_steps()
103+
104+
assert len(execution_steps) == 2
105+
for execution_step in execution_steps:
106+
if execution_step["StepName"] == "CondStep":
107+
assert execution_step["StepStatus"] == "Succeeded"
108+
continue
109+
assert execution_step["StepName"] == "FailStep"
110+
assert execution_step["StepStatus"] == "Failed"
111+
assert (
112+
execution_step["FailureReason"] == "condition greater than check return: true "
113+
"and left side param str is 1 and right side param int is 0"
114+
)
115+
finally:
116+
try:
117+
pipeline.delete()
118+
except Exception:
119+
pass

0 commit comments

Comments
 (0)