Skip to content

Commit 221eb8d

Browse files
author
Dewen Qi
committed
change: Implement overriden solution for pipeline variables
1 parent dfc6eee commit 221eb8d

17 files changed

+721
-58
lines changed

src/sagemaker/parameter.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
from __future__ import absolute_import
1515

1616
import json
17-
from sagemaker.workflow.parameters import Parameter as PipelineParameter
18-
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
19-
from sagemaker.workflow.functions import Join as PipelineJoin
2017

2118

2219
class ParameterRange(object):
@@ -72,12 +69,8 @@ def as_tuning_range(self, name):
7269
"""
7370
return {
7471
"Name": name,
75-
"MinValue": str(self.min_value)
76-
if not isinstance(self.min_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
77-
else self.min_value,
78-
"MaxValue": str(self.max_value)
79-
if not isinstance(self.max_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
80-
else self.max_value,
72+
"MinValue": str(self.min_value),
73+
"MaxValue": str(self.max_value),
8174
"ScalingType": self.scaling_type,
8275
}
8376

@@ -111,10 +104,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called
111104
This input will be converted into a list of strings.
112105
"""
113106
values = values if isinstance(values, list) else [values]
114-
self.values = [
115-
str(v) if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin)) else v
116-
for v in values
117-
]
107+
self.values = [str(v) for v in values]
118108

119109
def as_tuning_range(self, name):
120110
"""Represent the parameter range as a dictionary.

src/sagemaker/tuner.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,12 +375,7 @@ def _prepare_static_hyperparameters(
375375
):
376376
"""Prepare static hyperparameters for one estimator before tuning."""
377377
# Remove any hyperparameter that will be tuned
378-
static_hyperparameters = {
379-
str(k): str(v)
380-
if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin))
381-
else v
382-
for (k, v) in estimator.hyperparameters().items()
383-
}
378+
static_hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}
384379
for hyperparameter_name in hyperparameter_ranges.keys():
385380
static_hyperparameters.pop(hyperparameter_name, None)
386381

src/sagemaker/workflow/entities.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17+
import json
1718

1819
from enum import EnumMeta
19-
from typing import Any, Dict, List, Union
20+
from typing import Any, Dict, List, Union, Optional
2021

2122
PrimitiveType = Union[str, int, bool, float, None]
2223
RequestType = Union[Dict[str, Any], List[Dict[str, Any]]]
@@ -57,3 +58,68 @@ class Expression(abc.ABC):
5758
@abc.abstractmethod
5859
def expr(self) -> RequestType:
5960
"""Get the expression structure for workflow service calls."""
61+
62+
63+
class PipelineVariable(Expression):
64+
"""Base object for pipeline variables
65+
66+
PipelineVariables must implement the expr property.
67+
"""
68+
69+
def __add__(self, other: Union[Expression, PrimitiveType]):
70+
"""Add function for PipelineVariable
71+
72+
Args:
73+
other (Union[Expression, PrimitiveType]): The other object to be added.
74+
75+
Return:
76+
`sagemaker.workflow.functions.Join`
77+
"""
78+
from sagemaker.workflow.functions import Join
79+
80+
if isinstance(self, Join):
81+
self.values.append(other) # pylint: disable=no-member
82+
return self
83+
84+
return Join(on="", values=[self, other])
85+
86+
def __str__(self):
87+
"""String function for PipelineVariable"""
88+
return json.dumps(self._expr_helper(True))
89+
90+
def _expr_helper(
91+
self, cast_to_string: bool = False # pylint: disable=unused-argument
92+
) -> RequestType:
93+
"""Get the expression structure for __str__ call.
94+
95+
Args:
96+
cast_to_string (bool): To indicate if the expression value
97+
is need to be casted to string in runtime.
98+
99+
Return:
100+
Union[Dict[str, Any], List[Dict[str, Any]]]
101+
"""
102+
return self.expr
103+
104+
@property
105+
@abc.abstractmethod
106+
def expr(self) -> RequestType:
107+
"""Get the expression structure for workflow service calls."""
108+
109+
def startswith(
110+
self,
111+
prefix: Union[str, tuple], # pylint: disable=unused-argument
112+
start: Optional[int] = None, # pylint: disable=unused-argument
113+
end: Optional[int] = None, # pylint: disable=unused-argument
114+
) -> bool:
115+
"""Simulate the Python string's built-in method: startswith
116+
117+
Args:
118+
prefix (str, tuple): The (tuple of) string to be checked.
119+
start (int): To set the start index of the matching boundary (default: None).
120+
end (int): To set the end index of the matching boundary (default: None).
121+
122+
Return:
123+
bool: Always return False as Pipeline variables are parsed during execution runtime
124+
"""
125+
return False

src/sagemaker/workflow/execution_variables.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
"""Pipeline parameters and conditions for workflow."""
1414
from __future__ import absolute_import
1515

16+
import json
17+
1618
from sagemaker.workflow.entities import (
17-
Expression,
1819
RequestType,
20+
PipelineVariable,
1921
)
2022

2123

22-
class ExecutionVariable(Expression):
24+
class ExecutionVariable(PipelineVariable):
2325
"""Pipeline execution variables for workflow."""
2426

2527
def __init__(self, name: str):
@@ -30,6 +32,10 @@ def __init__(self, name: str):
3032
"""
3133
self.name = name
3234

35+
def __str__(self):
36+
"""String function for ExecutionVariable"""
37+
return json.dumps(self.expr)
38+
3339
@property
3440
def expr(self) -> RequestType:
3541
"""The 'Get' expression dict for an `ExecutionVariable`."""

src/sagemaker/workflow/functions.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16+
import json
1617
from typing import List, Union
1718

1819
import attr
1920

20-
from sagemaker.workflow.entities import Expression
21+
from sagemaker.workflow.entities import PipelineVariable
2122
from sagemaker.workflow.properties import PropertyFile
2223

2324

2425
@attr.s
25-
class Join(Expression):
26+
class Join(PipelineVariable):
2627
"""Join together properties.
2728
2829
Examples:
@@ -38,27 +39,39 @@ class Join(Expression):
3839
Attributes:
3940
values (List[Union[PrimitiveType, Parameter, Expression]]):
4041
The primitive type values, parameters, step properties, expressions to join.
41-
on_str (str): The string to join the values on (Defaults to "").
42+
on (str): The string to join the values on (Defaults to "").
4243
"""
4344

4445
on: str = attr.ib(factory=str)
4546
values: List = attr.ib(factory=list)
4647

48+
def __str__(self):
49+
"""String function for Join"""
50+
return json.dumps(self.expr)
51+
4752
@property
4853
def expr(self):
4954
"""The expression dict for a `Join` function."""
55+
from sagemaker.workflow.utilities import is_pipeline_variable_expr_string
56+
57+
values = list()
58+
for value in self.values:
59+
if hasattr(value, "expr"):
60+
values.append(value.expr)
61+
elif is_pipeline_variable_expr_string(value):
62+
values.append(json.loads(value))
63+
else:
64+
values.append(value)
5065
return {
5166
"Std:Join": {
5267
"On": self.on,
53-
"Values": [
54-
value.expr if hasattr(value, "expr") else value for value in self.values
55-
],
68+
"Values": values,
5669
},
5770
}
5871

5972

6073
@attr.s
61-
class JsonGet(Expression):
74+
class JsonGet(PipelineVariable):
6275
"""Get JSON properties from PropertyFiles.
6376
6477
Attributes:
@@ -72,6 +85,21 @@ class JsonGet(Expression):
7285
property_file: Union[PropertyFile, str] = attr.ib()
7386
json_path: str = attr.ib()
7487

88+
def _expr_helper(self, cast_to_string: bool = False):
89+
"""Get the expression structure for __str__ call.
90+
91+
Args:
92+
cast_to_string (bool): To indicate if the expression value
93+
is need to be casted to string in runtime.
94+
95+
Return:
96+
Union[Dict[str, Any], List[Dict[str, Any]]]
97+
"""
98+
expression = self.expr
99+
if cast_to_string:
100+
expression["Std:JsonGet"]["CastToString"] = cast_to_string
101+
return expression
102+
75103
@property
76104
def expr(self):
77105
"""The expression dict for a `JsonGet` function."""
@@ -82,6 +110,7 @@ def expr(self):
82110
name = self.property_file.name
83111
else:
84112
name = self.property_file
113+
85114
return {
86115
"Std:JsonGet": {
87116
"PropertyFile": {"Get": f"Steps.{self.step_name}.PropertyFiles.{name}"},

src/sagemaker/workflow/parameters.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Pipeline parameters and conditions for workflow."""
1414
from __future__ import absolute_import
1515

16+
import json
1617
from enum import Enum
1718
from functools import partial
1819
from typing import Dict, List, Type
@@ -24,6 +25,7 @@
2425
Entity,
2526
PrimitiveType,
2627
RequestType,
28+
PipelineVariable,
2729
)
2830

2931

@@ -48,7 +50,7 @@ def python_type(self) -> Type:
4850

4951

5052
@attr.s
51-
class Parameter(Entity):
53+
class Parameter(PipelineVariable, Entity):
5254
"""Pipeline parameter for workflow.
5355
5456
Attributes:
@@ -84,6 +86,21 @@ def to_request(self) -> RequestType:
8486
value["DefaultValue"] = self.default_value
8587
return value
8688

89+
def _expr_helper(self, cast_to_string: bool = False) -> RequestType:
90+
"""Get the expression structure for __str__ call.
91+
92+
Args:
93+
cast_to_string (bool): To indicate if the expression value
94+
is need to be casted to string in runtime.
95+
96+
Return:
97+
Union[Dict[str, Any], List[Dict[str, Any]]]
98+
"""
99+
expression = self.expr
100+
if cast_to_string:
101+
expression["Get"] += "._CastToString"
102+
return expression
103+
87104
@property
88105
def expr(self) -> Dict[str, str]:
89106
"""The 'Get' expression dict for a `Parameter`."""
@@ -170,6 +187,10 @@ def __hash__(self):
170187
"""Hash function for parameter types"""
171188
return hash(tuple(self.to_request()))
172189

190+
def __str__(self):
191+
"""String function for ParameterString"""
192+
return json.dumps(self.expr)
193+
173194
def to_request(self) -> RequestType:
174195
"""Get the request structure for workflow service calls."""
175196
request_dict = super(ParameterString, self).to_request()

src/sagemaker/workflow/pipeline.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from sagemaker.workflow.properties import Properties
4040
from sagemaker.workflow.steps import Step
4141
from sagemaker.workflow.step_collections import StepCollection
42-
from sagemaker.workflow.utilities import list_to_request
42+
from sagemaker.workflow.utilities import list_to_request, is_pipeline_variable_expr_string
4343

4444

4545
@attr.s
@@ -309,7 +309,6 @@ def definition(self) -> str:
309309
callback_output_to_step_map=callback_output_to_step_map,
310310
lambda_output_to_step_map=lambda_output_to_step_name,
311311
)
312-
313312
return json.dumps(request_dict)
314313

315314

@@ -380,6 +379,8 @@ def _interpolate(
380379
interpolate(value, callback_output_to_step_map, lambda_output_to_step_map)
381380
for value in obj
382381
)
382+
elif is_pipeline_variable_expr_string(obj):
383+
return json.loads(obj)
383384
else:
384385
return obj
385386
return new

src/sagemaker/workflow/properties.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313
"""The properties definitions for workflow."""
1414
from __future__ import absolute_import
1515

16+
from abc import ABCMeta
1617
from typing import Dict, Union, List
1718

1819
import attr
1920

2021
import botocore.loaders
2122

22-
from sagemaker.workflow.entities import Expression
23+
from sagemaker.workflow.entities import Expression, PipelineVariable, RequestType
2324

2425

25-
class PropertiesMeta(type):
26+
class PropertiesMeta(ABCMeta):
2627
"""Load an internal shapes attribute from the botocore service model
2728
2829
for sagemaker and emr service.
@@ -44,7 +45,7 @@ def __new__(mcs, *args, **kwargs):
4445
return super().__new__(mcs, *args, **kwargs)
4546

4647

47-
class Properties(metaclass=PropertiesMeta):
48+
class Properties(PipelineVariable, metaclass=PropertiesMeta):
4849
"""Properties for use in workflow expressions."""
4950

5051
def __init__(
@@ -61,6 +62,7 @@ def __init__(
6162
shape_name (str): The botocore service model shape name.
6263
shape_names (str): A List of the botocore service model shape name.
6364
"""
65+
# super().__init__()
6466
self._path = path
6567
shape_names = [] if shape_names is None else shape_names
6668
self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names
@@ -88,6 +90,21 @@ def __init__(
8890
f"{path}.{key}", info["shape"], service_name=service_name
8991
)
9092

93+
def _expr_helper(self, cast_to_string: bool = False) -> RequestType:
94+
"""Get the expression structure for __str__ call.
95+
96+
Args:
97+
cast_to_string (bool): To indicate if the expression value
98+
is need to be casted to string in runtime.
99+
100+
Return:
101+
Union[Dict[str, Any], List[Dict[str, Any]]]
102+
"""
103+
expression = self.expr
104+
if cast_to_string:
105+
expression["Get"] += "._CastToString"
106+
return expression
107+
91108
@property
92109
def expr(self):
93110
"""The 'Get' expression dict for a `Properties`."""

0 commit comments

Comments
 (0)