Skip to content

Commit 61d6a3d

Browse files
committed
feat: enable EnableInterContainerTrafficEncryption for model monitoring (aws#3010)
1 parent d74befa commit 61d6a3d

File tree

5 files changed

+15
-112
lines changed

5 files changed

+15
-112
lines changed

src/sagemaker/model_monitor/clarify_model_monitoring.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,6 @@ def _build_create_job_definition_request(
397397

398398
if network_config is not None:
399399
network_config_dict = network_config._to_request_dict()
400-
self._validate_network_config(network_config_dict)
401400
request_dict["NetworkConfig"] = network_config_dict
402401
elif existing_network_config is not None:
403402
request_dict["NetworkConfig"] = existing_network_config

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,6 @@ def create_monitoring_schedule(
295295
network_config_dict = None
296296
if self.network_config is not None:
297297
network_config_dict = self.network_config._to_request_dict()
298-
self._validate_network_config(network_config_dict)
299298

300299
self.sagemaker_session.create_monitoring_schedule(
301300
monitoring_schedule_name=self.monitoring_schedule_name,
@@ -448,7 +447,6 @@ def update_monitoring_schedule(
448447
network_config_dict = None
449448
if self.network_config is not None:
450449
network_config_dict = self.network_config._to_request_dict()
451-
self._validate_network_config(network_config_dict)
452450

453451
self.sagemaker_session.update_monitoring_schedule(
454452
monitoring_schedule_name=self.monitoring_schedule_name,
@@ -708,6 +706,9 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
708706
if network_config_dict:
709707
network_config = NetworkConfig(
710708
enable_network_isolation=network_config_dict["EnableNetworkIsolation"],
709+
encrypt_inter_container_traffic=network_config_dict[
710+
"EnableInterContainerTrafficEncryption"
711+
],
711712
security_group_ids=security_group_ids,
712713
subnets=subnets,
713714
)
@@ -784,6 +785,9 @@ def _attach(clazz, sagemaker_session, schedule_desc, job_desc, tags):
784785
if network_config_dict:
785786
network_config = NetworkConfig(
786787
enable_network_isolation=network_config_dict["EnableNetworkIsolation"],
788+
encrypt_inter_container_traffic=network_config_dict[
789+
"EnableInterContainerTrafficEncryption"
790+
],
787791
security_group_ids=security_group_ids,
788792
subnets=subnets,
789793
)
@@ -1164,31 +1168,6 @@ def _wait_for_schedule_changes_to_apply(self):
11641168
if schedule_desc["MonitoringScheduleStatus"] != "Pending":
11651169
break
11661170

1167-
def _validate_network_config(self, network_config_dict):
1168-
"""Function to validate EnableInterContainerTrafficEncryption.
1169-
1170-
It validates EnableInterContainerTrafficEncryption is not set in the provided
1171-
NetworkConfig request dictionary.
1172-
1173-
Args:
1174-
network_config_dict (dict): NetworkConfig request dictionary.
1175-
Contains parameters from :class:`~sagemaker.network.NetworkConfig` object
1176-
that configures network isolation, encryption of
1177-
inter-container traffic, security group IDs, and subnets.
1178-
1179-
"""
1180-
if "EnableInterContainerTrafficEncryption" in network_config_dict:
1181-
message = (
1182-
"EnableInterContainerTrafficEncryption is not supported in Model Monitor. "
1183-
"Please ensure that encrypt_inter_container_traffic=None "
1184-
"when creating your NetworkConfig object. "
1185-
"Current encrypt_inter_container_traffic value: {}".format(
1186-
self.network_config.encrypt_inter_container_traffic
1187-
)
1188-
)
1189-
_LOGGER.info(message)
1190-
raise ValueError(message)
1191-
11921171
@classmethod
11931172
def monitoring_type(cls):
11941173
"""Type of the monitoring job."""
@@ -1781,7 +1760,6 @@ def update_monitoring_schedule(
17811760
network_config_dict = None
17821761
if self.network_config is not None:
17831762
network_config_dict = self.network_config._to_request_dict()
1784-
super(DefaultModelMonitor, self)._validate_network_config(network_config_dict)
17851763

17861764
if role is not None:
17871765
self.role = role
@@ -2034,6 +2012,9 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
20342012
subnets = vpc_config.get("Subnets")
20352013
network_config = NetworkConfig(
20362014
enable_network_isolation=network_config_dict["EnableNetworkIsolation"],
2015+
encrypt_inter_container_traffic=network_config_dict[
2016+
"EnableInterContainerTrafficEncryption"
2017+
],
20372018
security_group_ids=security_group_ids,
20382019
subnets=subnets,
20392020
)
@@ -2304,7 +2285,6 @@ def _build_create_data_quality_job_definition_request(
23042285

23052286
if network_config is not None:
23062287
network_config_dict = network_config._to_request_dict()
2307-
self._validate_network_config(network_config_dict)
23082288
request_dict["NetworkConfig"] = network_config_dict
23092289
elif existing_network_config is not None:
23102290
request_dict["NetworkConfig"] = existing_network_config
@@ -3007,7 +2987,6 @@ def _build_create_model_quality_job_definition_request(
30072987

30082988
if network_config is not None:
30092989
network_config_dict = network_config._to_request_dict()
3010-
self._validate_network_config(network_config_dict)
30112990
request_dict["NetworkConfig"] = network_config_dict
30122991
elif existing_network_config is not None:
30132992
request_dict["NetworkConfig"] = existing_network_config

tests/integ/test_model_monitor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@
5353
TAG_KEY_1 = "tag_key_1"
5454
TAG_VALUE_1 = "tag_value_1"
5555
TAGS = [{"Key": TAG_KEY_1, "Value": TAG_VALUE_1}]
56-
NETWORK_CONFIG = NetworkConfig(enable_network_isolation=True)
56+
NETWORK_CONFIG = NetworkConfig(
57+
enable_network_isolation=True,
58+
encrypt_inter_container_traffic=True,
59+
)
5760
ENABLE_CLOUDWATCH_METRICS = True
5861

5962
DEFAULT_BASELINING_MAX_RUNTIME_IN_SECONDS = 86400

tests/unit/sagemaker/monitor/test_clarify_model_monitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
SUBNETS = ["test_subnets"]
8282
NETWORK_CONFIG = NetworkConfig(
8383
enable_network_isolation=False,
84+
encrypt_inter_container_traffic=False,
8485
security_group_ids=SECURITY_GROUP_IDS,
8586
subnets=SUBNETS,
8687
)

tests/unit/sagemaker/monitor/test_model_monitoring.py

Lines changed: 1 addition & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@
2323
CronExpressionGenerator,
2424
DefaultModelMonitor,
2525
EndpointInput,
26-
ModelMonitor,
2726
ModelQualityMonitor,
28-
MonitoringOutput,
2927
Statistics,
3028
)
3129

@@ -52,7 +50,7 @@
5250
TAG_KEY_1 = "tag_key_1"
5351
TAG_VALUE_1 = "tag_value_1"
5452
TAGS = [{"Key": TAG_KEY_1, "Value": TAG_VALUE_1}]
55-
NETWORK_CONFIG = NetworkConfig(enable_network_isolation=False)
53+
NETWORK_CONFIG = NetworkConfig(enable_network_isolation=False, encrypt_inter_container_traffic=True)
5654
ENABLE_CLOUDWATCH_METRICS = True
5755
PROBLEM_TYPE = "Regression"
5856
GROUND_TRUTH_ATTRIBUTE = "TestAttribute"
@@ -429,53 +427,6 @@ def test_default_model_monitor_suggest_baseline(sagemaker_session):
429427
assert my_default_monitor.env[ENV_KEY_1] == ENV_VALUE_1
430428

431429

432-
def test_default_model_monitor_with_invalid_network_config(sagemaker_session):
433-
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
434-
my_default_monitor = DefaultModelMonitor(
435-
role=ROLE, sagemaker_session=sagemaker_session, network_config=invalid_network_config
436-
)
437-
with pytest.raises(ValueError) as exception:
438-
my_default_monitor.create_monitoring_schedule(endpoint_input="test_endpoint")
439-
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)
440-
441-
with pytest.raises(ValueError) as exception:
442-
my_default_monitor.update_monitoring_schedule()
443-
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)
444-
445-
446-
def test_model_monitor_without_network_config(sagemaker_session):
447-
my_model_monitor = ModelMonitor(
448-
role=ROLE,
449-
image_uri=CUSTOM_IMAGE_URI,
450-
sagemaker_session=sagemaker_session,
451-
)
452-
model_monitor_schedule_name = "model-monitoring-without-network-config"
453-
attached = my_model_monitor.attach(model_monitor_schedule_name, sagemaker_session)
454-
assert attached.network_config is None
455-
456-
457-
def test_model_monitor_with_invalid_network_config(sagemaker_session):
458-
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
459-
my_model_monitor = ModelMonitor(
460-
role=ROLE,
461-
image_uri=CUSTOM_IMAGE_URI,
462-
sagemaker_session=sagemaker_session,
463-
network_config=invalid_network_config,
464-
)
465-
with pytest.raises(ValueError) as exception:
466-
my_model_monitor.create_monitoring_schedule(
467-
endpoint_input="test_endpoint",
468-
output=MonitoringOutput(
469-
source="/opt/ml/processing/output", destination="/opt/ml/processing/output"
470-
),
471-
)
472-
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)
473-
474-
with pytest.raises(ValueError) as exception:
475-
my_model_monitor.update_monitoring_schedule()
476-
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)
477-
478-
479430
def test_data_quality_monitor_suggest_baseline(sagemaker_session, data_quality_monitor):
480431
data_quality_monitor.suggest_baseline(
481432
baseline_dataset=BASELINE_DATASET_PATH,
@@ -639,20 +590,6 @@ def test_data_quality_monitor_update_failure(data_quality_monitor, sagemaker_ses
639590
data_quality_monitor.update_monitoring_schedule()
640591

641592

642-
def test_data_quality_monitor_with_invalid_network_config(sagemaker_session):
643-
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
644-
data_quality_monitor = DefaultModelMonitor(
645-
role=ROLE,
646-
sagemaker_session=sagemaker_session,
647-
network_config=invalid_network_config,
648-
)
649-
with pytest.raises(ValueError) as exception:
650-
data_quality_monitor.create_monitoring_schedule(
651-
endpoint_input="test_endpoint",
652-
)
653-
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)
654-
655-
656593
def _test_data_quality_monitor_create_schedule(
657594
data_quality_monitor,
658595
sagemaker_session,
@@ -1053,22 +990,6 @@ def test_model_quality_monitor_update_failure(model_quality_monitor, sagemaker_s
1053990
model_quality_monitor.update_monitoring_schedule()
1054991

1055992

1056-
def test_model_quality_monitor_with_invalid_network_config(sagemaker_session):
1057-
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
1058-
model_quality_monitor = ModelQualityMonitor(
1059-
role=ROLE,
1060-
sagemaker_session=sagemaker_session,
1061-
network_config=invalid_network_config,
1062-
)
1063-
with pytest.raises(ValueError) as exception:
1064-
model_quality_monitor.create_monitoring_schedule(
1065-
endpoint_input="test_endpoint",
1066-
problem_type=PROBLEM_TYPE,
1067-
ground_truth_input=GROUND_TRUTH_S3_URI,
1068-
)
1069-
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)
1070-
1071-
1072993
def _test_model_quality_monitor_create_schedule(
1073994
model_quality_monitor,
1074995
sagemaker_session,

0 commit comments

Comments
 (0)