Skip to content

Commit 5f6282e

Browse files
authored
Merge branch 'master' into serverless-support-cr
2 parents 965bbad + b0f2b4a commit 5f6282e

32 files changed

+832
-204
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
from glob import glob
1818

19-
from setuptools import setup, find_packages
19+
from setuptools import find_packages, setup
2020

2121

2222
def read(fname):
@@ -81,6 +81,7 @@ def read_version():
8181
"fabric==2.6.0",
8282
"requests==2.27.1",
8383
"sagemaker-experiments==0.1.35",
84+
"Jinja2==3.0.3",
8485
],
8586
)
8687

src/sagemaker/estimator.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@
7474
get_config_value,
7575
name_from_base,
7676
)
77-
from sagemaker.workflow.entities import Expression
78-
from sagemaker.workflow.parameters import Parameter
79-
from sagemaker.workflow.properties import Properties
77+
from sagemaker.workflow.entities import PipelineVariable
8078

8179
logger = logging.getLogger(__name__)
8280

@@ -602,7 +600,7 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A
602600
current_hyperparameters = hyperparameters
603601
if current_hyperparameters is not None:
604602
hyperparameters = {
605-
str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else json.dumps(v))
603+
str(k): (v.to_string() if isinstance(v, PipelineVariable) else json.dumps(v))
606604
for (k, v) in current_hyperparameters.items()
607605
}
608606
return hyperparameters
@@ -1813,7 +1811,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
18131811
current_hyperparameters = estimator.hyperparameters()
18141812
if current_hyperparameters is not None:
18151813
hyperparameters = {
1816-
str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v))
1814+
str(k): (v.to_string() if isinstance(v, PipelineVariable) else str(v))
18171815
for (k, v) in current_hyperparameters.items()
18181816
}
18191817

src/sagemaker/model_monitor/clarify_model_monitoring.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,6 @@ def _build_create_job_definition_request(
397397

398398
if network_config is not None:
399399
network_config_dict = network_config._to_request_dict()
400-
self._validate_network_config(network_config_dict)
401400
request_dict["NetworkConfig"] = network_config_dict
402401
elif existing_network_config is not None:
403402
request_dict["NetworkConfig"] = existing_network_config

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,6 @@ def create_monitoring_schedule(
295295
network_config_dict = None
296296
if self.network_config is not None:
297297
network_config_dict = self.network_config._to_request_dict()
298-
self._validate_network_config(network_config_dict)
299298

300299
self.sagemaker_session.create_monitoring_schedule(
301300
monitoring_schedule_name=self.monitoring_schedule_name,
@@ -448,7 +447,6 @@ def update_monitoring_schedule(
448447
network_config_dict = None
449448
if self.network_config is not None:
450449
network_config_dict = self.network_config._to_request_dict()
451-
self._validate_network_config(network_config_dict)
452450

453451
self.sagemaker_session.update_monitoring_schedule(
454452
monitoring_schedule_name=self.monitoring_schedule_name,
@@ -708,6 +706,9 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
708706
if network_config_dict:
709707
network_config = NetworkConfig(
710708
enable_network_isolation=network_config_dict["EnableNetworkIsolation"],
709+
encrypt_inter_container_traffic=network_config_dict[
710+
"EnableInterContainerTrafficEncryption"
711+
],
711712
security_group_ids=security_group_ids,
712713
subnets=subnets,
713714
)
@@ -784,6 +785,9 @@ def _attach(clazz, sagemaker_session, schedule_desc, job_desc, tags):
784785
if network_config_dict:
785786
network_config = NetworkConfig(
786787
enable_network_isolation=network_config_dict["EnableNetworkIsolation"],
788+
encrypt_inter_container_traffic=network_config_dict[
789+
"EnableInterContainerTrafficEncryption"
790+
],
787791
security_group_ids=security_group_ids,
788792
subnets=subnets,
789793
)
@@ -1164,31 +1168,6 @@ def _wait_for_schedule_changes_to_apply(self):
11641168
if schedule_desc["MonitoringScheduleStatus"] != "Pending":
11651169
break
11661170

1167-
def _validate_network_config(self, network_config_dict):
1168-
"""Function to validate EnableInterContainerTrafficEncryption.
1169-
1170-
It validates EnableInterContainerTrafficEncryption is not set in the provided
1171-
NetworkConfig request dictionary.
1172-
1173-
Args:
1174-
network_config_dict (dict): NetworkConfig request dictionary.
1175-
Contains parameters from :class:`~sagemaker.network.NetworkConfig` object
1176-
that configures network isolation, encryption of
1177-
inter-container traffic, security group IDs, and subnets.
1178-
1179-
"""
1180-
if "EnableInterContainerTrafficEncryption" in network_config_dict:
1181-
message = (
1182-
"EnableInterContainerTrafficEncryption is not supported in Model Monitor. "
1183-
"Please ensure that encrypt_inter_container_traffic=None "
1184-
"when creating your NetworkConfig object. "
1185-
"Current encrypt_inter_container_traffic value: {}".format(
1186-
self.network_config.encrypt_inter_container_traffic
1187-
)
1188-
)
1189-
_LOGGER.info(message)
1190-
raise ValueError(message)
1191-
11921171
@classmethod
11931172
def monitoring_type(cls):
11941173
"""Type of the monitoring job."""
@@ -1781,7 +1760,6 @@ def update_monitoring_schedule(
17811760
network_config_dict = None
17821761
if self.network_config is not None:
17831762
network_config_dict = self.network_config._to_request_dict()
1784-
super(DefaultModelMonitor, self)._validate_network_config(network_config_dict)
17851763

17861764
if role is not None:
17871765
self.role = role
@@ -2034,6 +2012,9 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
20342012
subnets = vpc_config.get("Subnets")
20352013
network_config = NetworkConfig(
20362014
enable_network_isolation=network_config_dict["EnableNetworkIsolation"],
2015+
encrypt_inter_container_traffic=network_config_dict[
2016+
"EnableInterContainerTrafficEncryption"
2017+
],
20372018
security_group_ids=security_group_ids,
20382019
subnets=subnets,
20392020
)
@@ -2304,7 +2285,6 @@ def _build_create_data_quality_job_definition_request(
23042285

23052286
if network_config is not None:
23062287
network_config_dict = network_config._to_request_dict()
2307-
self._validate_network_config(network_config_dict)
23082288
request_dict["NetworkConfig"] = network_config_dict
23092289
elif existing_network_config is not None:
23102290
request_dict["NetworkConfig"] = existing_network_config
@@ -3007,7 +2987,6 @@ def _build_create_model_quality_job_definition_request(
30072987

30082988
if network_config is not None:
30092989
network_config_dict = network_config._to_request_dict()
3010-
self._validate_network_config(network_config_dict)
30112990
request_dict["NetworkConfig"] = network_config_dict
30122991
elif existing_network_config is not None:
30132992
request_dict["NetworkConfig"] = existing_network_config

src/sagemaker/parameter.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
from __future__ import absolute_import
1515

1616
import json
17-
from sagemaker.workflow.parameters import Parameter as PipelineParameter
18-
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
19-
from sagemaker.workflow.functions import Join as PipelineJoin
17+
18+
from sagemaker.workflow.entities import PipelineVariable
2019

2120

2221
class ParameterRange(object):
@@ -73,11 +72,11 @@ def as_tuning_range(self, name):
7372
return {
7473
"Name": name,
7574
"MinValue": str(self.min_value)
76-
if not isinstance(self.min_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
77-
else self.min_value,
75+
if not isinstance(self.min_value, PipelineVariable)
76+
else self.min_value.to_string(),
7877
"MaxValue": str(self.max_value)
79-
if not isinstance(self.max_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
80-
else self.max_value,
78+
if not isinstance(self.max_value, PipelineVariable)
79+
else self.max_value.to_string(),
8180
"ScalingType": self.scaling_type,
8281
}
8382

@@ -112,8 +111,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called
112111
"""
113112
values = values if isinstance(values, list) else [values]
114113
self.values = [
115-
str(v) if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin)) else v
116-
for v in values
114+
str(v) if not isinstance(v, PipelineVariable) else v.to_string() for v in values
117115
]
118116

119117
def as_tuning_range(self, name):

src/sagemaker/tuner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
IntegerParameter,
3939
ParameterRange,
4040
)
41+
from sagemaker.workflow.entities import PipelineVariable
4142
from sagemaker.workflow.parameters import Parameter as PipelineParameter
4243
from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
4344
from sagemaker.workflow.functions import Join as PipelineJoin
@@ -376,9 +377,7 @@ def _prepare_static_hyperparameters(
376377
"""Prepare static hyperparameters for one estimator before tuning."""
377378
# Remove any hyperparameter that will be tuned
378379
static_hyperparameters = {
379-
str(k): str(v)
380-
if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin))
381-
else v
380+
str(k): str(v) if not isinstance(v, PipelineVariable) else v.to_string()
382381
for (k, v) in estimator.hyperparameters().items()
383382
}
384383
for hyperparameter_name in hyperparameter_ranges.keys():

src/sagemaker/workflow/entities.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import abc
1717

1818
from enum import EnumMeta
19-
from typing import Any, Dict, List, Union
19+
from typing import Any, Dict, List, Union, Optional
2020

2121
PrimitiveType = Union[str, int, bool, float, None]
2222
RequestType = Union[Dict[str, Any], List[Dict[str, Any]]]
@@ -57,3 +57,80 @@ class Expression(abc.ABC):
5757
@abc.abstractmethod
5858
def expr(self) -> RequestType:
5959
"""Get the expression structure for workflow service calls."""
60+
61+
62+
class PipelineVariable(Expression):
63+
"""Base object for pipeline variables
64+
65+
PipelineVariables must implement the expr property.
66+
"""
67+
68+
def __add__(self, other: Union[Expression, PrimitiveType]):
69+
"""Add function for PipelineVariable
70+
71+
Args:
72+
other (Union[Expression, PrimitiveType]): The other object to be concatenated.
73+
74+
Always raise an error since pipeline variables do not support concatenation
75+
"""
76+
77+
raise TypeError("Pipeline variables do not support concatenation.")
78+
79+
def __str__(self):
80+
"""Override built-in String function for PipelineVariable"""
81+
raise TypeError("Pipeline variables do not support __str__ operation.")
82+
83+
def __int__(self):
84+
"""Override built-in Integer function for PipelineVariable"""
85+
raise TypeError("Pipeline variables do not support __int__ operation.")
86+
87+
def __float__(self):
88+
"""Override built-in Float function for PipelineVariable"""
89+
raise TypeError("Pipeline variables do not support __float__ operation.")
90+
91+
def to_string(self):
92+
"""Prompt the pipeline to convert the pipeline variable to String in runtime"""
93+
from sagemaker.workflow.functions import Join
94+
95+
return Join(on="", values=[self])
96+
97+
@property
98+
@abc.abstractmethod
99+
def expr(self) -> RequestType:
100+
"""Get the expression structure for workflow service calls."""
101+
102+
def startswith(
103+
self,
104+
prefix: Union[str, tuple], # pylint: disable=unused-argument
105+
start: Optional[int] = None, # pylint: disable=unused-argument
106+
end: Optional[int] = None, # pylint: disable=unused-argument
107+
) -> bool:
108+
"""Simulate the Python string's built-in method: startswith
109+
110+
Args:
111+
prefix (str, tuple): The (tuple of) string to be checked.
112+
start (int): To set the start index of the matching boundary (default: None).
113+
end (int): To set the end index of the matching boundary (default: None).
114+
115+
Return:
116+
bool: Always return False as Pipeline variables are parsed during execution runtime
117+
"""
118+
return False
119+
120+
def endswith(
121+
self,
122+
suffix: Union[str, tuple], # pylint: disable=unused-argument
123+
start: Optional[int] = None, # pylint: disable=unused-argument
124+
end: Optional[int] = None, # pylint: disable=unused-argument
125+
) -> bool:
126+
"""Simulate the Python string's built-in method: endswith
127+
128+
Args:
129+
suffix (str, tuple): The (tuple of) string to be checked.
130+
start (int): To set the start index of the matching boundary (default: None).
131+
end (int): To set the end index of the matching boundary (default: None).
132+
133+
Return:
134+
bool: Always return False as Pipeline variables are parsed during execution runtime
135+
"""
136+
return False

src/sagemaker/workflow/execution_variables.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.workflow.entities import (
17-
Expression,
1817
RequestType,
18+
PipelineVariable,
1919
)
2020

2121

22-
class ExecutionVariable(Expression):
22+
class ExecutionVariable(PipelineVariable):
2323
"""Pipeline execution variables for workflow."""
2424

2525
def __init__(self, name: str):
@@ -30,6 +30,13 @@ def __init__(self, name: str):
3030
"""
3131
self.name = name
3232

33+
def to_string(self) -> PipelineVariable:
34+
"""Prompt the pipeline to convert the pipeline variable to String in runtime
35+
36+
As ExecutionVariable is treated as String in runtime, no extra actions are needed.
37+
"""
38+
return self
39+
3340
@property
3441
def expr(self) -> RequestType:
3542
"""The 'Get' expression dict for an `ExecutionVariable`."""

src/sagemaker/workflow/functions.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
import attr
1919

20-
from sagemaker.workflow.entities import Expression
20+
from sagemaker.workflow.entities import PipelineVariable
2121
from sagemaker.workflow.properties import PropertyFile
2222

2323

2424
@attr.s
25-
class Join(Expression):
25+
class Join(PipelineVariable):
2626
"""Join together properties.
2727
2828
Examples:
@@ -38,15 +38,23 @@ class Join(Expression):
3838
Attributes:
3939
values (List[Union[PrimitiveType, Parameter, Expression]]):
4040
The primitive type values, parameters, step properties, expressions to join.
41-
on_str (str): The string to join the values on (Defaults to "").
41+
on (str): The string to join the values on (Defaults to "").
4242
"""
4343

4444
on: str = attr.ib(factory=str)
4545
values: List = attr.ib(factory=list)
4646

47+
def to_string(self) -> PipelineVariable:
48+
"""Prompt the pipeline to convert the pipeline variable to String in runtime
49+
50+
As Join is treated as String in runtime, no extra actions are needed.
51+
"""
52+
return self
53+
4754
@property
4855
def expr(self):
4956
"""The expression dict for a `Join` function."""
57+
5058
return {
5159
"Std:Join": {
5260
"On": self.on,
@@ -58,7 +66,7 @@ def expr(self):
5866

5967

6068
@attr.s
61-
class JsonGet(Expression):
69+
class JsonGet(PipelineVariable):
6270
"""Get JSON properties from PropertyFiles.
6371
6472
Attributes:

0 commit comments

Comments
 (0)