Skip to content

change: Implement override solution for pipeline variables #2995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@
get_config_value,
name_from_base,
)
from sagemaker.workflow.entities import Expression
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -602,7 +600,7 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A
current_hyperparameters = hyperparameters
if current_hyperparameters is not None:
hyperparameters = {
str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else json.dumps(v))
str(k): (v.to_string() if isinstance(v, PipelineVariable) else json.dumps(v))
for (k, v) in current_hyperparameters.items()
}
return hyperparameters
Expand Down Expand Up @@ -1813,7 +1811,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
current_hyperparameters = estimator.hyperparameters()
if current_hyperparameters is not None:
hyperparameters = {
str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v))
str(k): (v.to_string() if isinstance(v, PipelineVariable) else str(v))
for (k, v) in current_hyperparameters.items()
}

Expand Down
16 changes: 7 additions & 9 deletions src/sagemaker/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
from __future__ import absolute_import

import json
from sagemaker.workflow.parameters import Parameter as PipelineParameter
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
from sagemaker.workflow.functions import Join as PipelineJoin

from sagemaker.workflow.entities import PipelineVariable


class ParameterRange(object):
Expand Down Expand Up @@ -73,11 +72,11 @@ def as_tuning_range(self, name):
return {
"Name": name,
"MinValue": str(self.min_value)
if not isinstance(self.min_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
else self.min_value,
if not isinstance(self.min_value, PipelineVariable)
else self.min_value.to_string(),
"MaxValue": str(self.max_value)
if not isinstance(self.max_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
else self.max_value,
if not isinstance(self.max_value, PipelineVariable)
else self.max_value.to_string(),
"ScalingType": self.scaling_type,
}

Expand Down Expand Up @@ -112,8 +111,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called
"""
values = values if isinstance(values, list) else [values]
self.values = [
str(v) if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin)) else v
for v in values
str(v) if not isinstance(v, PipelineVariable) else v.to_string() for v in values
]

def as_tuning_range(self, name):
Expand Down
5 changes: 2 additions & 3 deletions src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
IntegerParameter,
ParameterRange,
)
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.parameters import Parameter as PipelineParameter
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
from sagemaker.workflow.functions import Join as PipelineJoin
Expand Down Expand Up @@ -376,9 +377,7 @@ def _prepare_static_hyperparameters(
"""Prepare static hyperparameters for one estimator before tuning."""
# Remove any hyperparameter that will be tuned
static_hyperparameters = {
str(k): str(v)
if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin))
else v
str(k): str(v) if not isinstance(v, PipelineVariable) else v.to_string()
for (k, v) in estimator.hyperparameters().items()
}
for hyperparameter_name in hyperparameter_ranges.keys():
Expand Down
79 changes: 78 additions & 1 deletion src/sagemaker/workflow/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import abc

from enum import EnumMeta
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Optional

PrimitiveType = Union[str, int, bool, float, None]
RequestType = Union[Dict[str, Any], List[Dict[str, Any]]]
Expand Down Expand Up @@ -57,3 +57,80 @@ class Expression(abc.ABC):
@abc.abstractmethod
def expr(self) -> RequestType:
"""Get the expression structure for workflow service calls."""


class PipelineVariable(Expression):
    """Base class for values that are only resolved while a pipeline executes.

    Subclasses must implement the ``expr`` property. Python-side concatenation
    and ``str``/``int``/``float`` conversion are rejected with ``TypeError``;
    call ``to_string`` to request a string conversion at execution time.
    """

    # --- workflow expression contract -------------------------------------

    @property
    @abc.abstractmethod
    def expr(self) -> RequestType:
        """Return the expression structure used in workflow service calls."""

    def to_string(self):
        """Ask the pipeline to render this variable as a String at runtime.

        Returns:
            A ``Join`` with an empty separator whose only value is this variable.
        """
        # Imported here rather than at module level: sagemaker.workflow.functions
        # itself imports from this module, so a top-level import would be circular.
        from sagemaker.workflow.functions import Join

        joined = Join(on="", values=[self])
        return joined

    # --- operations that are rejected -------------------------------------

    def __add__(self, other: Union[Expression, PrimitiveType]):
        """Reject ``+`` on a pipeline variable.

        Args:
            other (Union[Expression, PrimitiveType]): The right-hand operand (unused).

        Raises:
            TypeError: Always, since pipeline variables do not support concatenation.
        """
        raise TypeError("Pipeline variables do not support concatenation.")

    def __str__(self):
        """Reject ``str()`` conversion; always raises ``TypeError``."""
        raise TypeError("Pipeline variables do not support __str__ operation.")

    def __int__(self):
        """Reject ``int()`` conversion; always raises ``TypeError``."""
        raise TypeError("Pipeline variables do not support __int__ operation.")

    def __float__(self):
        """Reject ``float()`` conversion; always raises ``TypeError``."""
        raise TypeError("Pipeline variables do not support __float__ operation.")

    # --- str-like helpers --------------------------------------------------

    def startswith(
        self,
        prefix: Union[str, tuple],  # pylint: disable=unused-argument
        start: Optional[int] = None,  # pylint: disable=unused-argument
        end: Optional[int] = None,  # pylint: disable=unused-argument
    ) -> bool:
        """Mimic the signature of ``str.startswith`` for pipeline variables.

        Args:
            prefix (str, tuple): The string (or tuple of strings) to look for.
            start (int): Start index of the matching boundary (default: None).
            end (int): End index of the matching boundary (default: None).

        Returns:
            bool: Always False, because pipeline variables are parsed during
            execution runtime. All arguments are ignored.
        """
        return False

    def endswith(
        self,
        suffix: Union[str, tuple],  # pylint: disable=unused-argument
        start: Optional[int] = None,  # pylint: disable=unused-argument
        end: Optional[int] = None,  # pylint: disable=unused-argument
    ) -> bool:
        """Mimic the signature of ``str.endswith`` for pipeline variables.

        Args:
            suffix (str, tuple): The string (or tuple of strings) to look for.
            start (int): Start index of the matching boundary (default: None).
            end (int): End index of the matching boundary (default: None).

        Returns:
            bool: Always False, because pipeline variables are parsed during
            execution runtime. All arguments are ignored.
        """
        return False
11 changes: 9 additions & 2 deletions src/sagemaker/workflow/execution_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from __future__ import absolute_import

from sagemaker.workflow.entities import (
Expression,
RequestType,
PipelineVariable,
)


class ExecutionVariable(Expression):
class ExecutionVariable(PipelineVariable):
"""Pipeline execution variables for workflow."""

def __init__(self, name: str):
Expand All @@ -30,6 +30,13 @@ def __init__(self, name: str):
"""
self.name = name

def to_string(self) -> PipelineVariable:
    """Return this execution variable as its own runtime String form.

    An ExecutionVariable is treated as a String at runtime, so it is
    returned unchanged and no conversion is added.
    """
    return self

@property
def expr(self) -> RequestType:
"""The 'Get' expression dict for an `ExecutionVariable`."""
Expand Down
16 changes: 12 additions & 4 deletions src/sagemaker/workflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

import attr

from sagemaker.workflow.entities import Expression
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.properties import PropertyFile


@attr.s
class Join(Expression):
class Join(PipelineVariable):
"""Join together properties.

Examples:
Expand All @@ -38,15 +38,23 @@ class Join(Expression):
Attributes:
values (List[Union[PrimitiveType, Parameter, Expression]]):
The primitive type values, parameters, step properties, expressions to join.
on_str (str): The string to join the values on (Defaults to "").
on (str): The string to join the values on (Defaults to "").
"""

on: str = attr.ib(factory=str)
values: List = attr.ib(factory=list)

def to_string(self) -> PipelineVariable:
    """Return this Join as its own runtime String form.

    A Join is treated as a String at runtime, so it is returned unchanged
    and no conversion is added.
    """
    return self

@property
def expr(self):
"""The expression dict for a `Join` function."""

return {
"Std:Join": {
"On": self.on,
Expand All @@ -58,7 +66,7 @@ def expr(self):


@attr.s
class JsonGet(Expression):
class JsonGet(PipelineVariable):
"""Get JSON properties from PropertyFiles.

Attributes:
Expand Down
10 changes: 9 additions & 1 deletion src/sagemaker/workflow/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Entity,
PrimitiveType,
RequestType,
PipelineVariable,
)


Expand All @@ -48,7 +49,7 @@ def python_type(self) -> Type:


@attr.s
class Parameter(Entity):
class Parameter(PipelineVariable, Entity):
"""Pipeline parameter for workflow.

Attributes:
Expand Down Expand Up @@ -170,6 +171,13 @@ def __hash__(self):
"""Hash function for parameter types"""
return hash(tuple(self.to_request()))

def to_string(self) -> PipelineVariable:
    """Return this parameter as its own runtime String form.

    A ParameterString is treated as a String at runtime, so it is returned
    unchanged and no conversion is added.
    """
    return self

def to_request(self) -> RequestType:
"""Get the request structure for workflow service calls."""
request_dict = super(ParameterString, self).to_request()
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/workflow/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@
"""The properties definitions for workflow."""
from __future__ import absolute_import

from abc import ABCMeta
from typing import Dict, Union, List

import attr

import botocore.loaders

from sagemaker.workflow.entities import Expression
from sagemaker.workflow.entities import Expression, PipelineVariable


class PropertiesMeta(type):
class PropertiesMeta(ABCMeta):
"""Load an internal shapes attribute from the botocore service model

for sagemaker and emr service.
Expand All @@ -44,7 +45,7 @@ def __new__(mcs, *args, **kwargs):
return super().__new__(mcs, *args, **kwargs)


class Properties(metaclass=PropertiesMeta):
class Properties(PipelineVariable, metaclass=PropertiesMeta):
"""Properties for use in workflow expressions."""

def __init__(
Expand Down
Loading