Skip to content

feat: enable EnableInterContainerTrafficEncryption for model monitoring #3010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/sagemaker/model_monitor/clarify_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,6 @@ def _build_create_job_definition_request(

if network_config is not None:
network_config_dict = network_config._to_request_dict()
self._validate_network_config(network_config_dict)
request_dict["NetworkConfig"] = network_config_dict
elif existing_network_config is not None:
request_dict["NetworkConfig"] = existing_network_config
Expand Down
39 changes: 9 additions & 30 deletions src/sagemaker/model_monitor/model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ def create_monitoring_schedule(
network_config_dict = None
if self.network_config is not None:
network_config_dict = self.network_config._to_request_dict()
self._validate_network_config(network_config_dict)

self.sagemaker_session.create_monitoring_schedule(
monitoring_schedule_name=self.monitoring_schedule_name,
Expand Down Expand Up @@ -448,7 +447,6 @@ def update_monitoring_schedule(
network_config_dict = None
if self.network_config is not None:
network_config_dict = self.network_config._to_request_dict()
self._validate_network_config(network_config_dict)

self.sagemaker_session.update_monitoring_schedule(
monitoring_schedule_name=self.monitoring_schedule_name,
Expand Down Expand Up @@ -708,6 +706,9 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
if network_config_dict:
network_config = NetworkConfig(
enable_network_isolation=network_config_dict["EnableNetworkIsolation"],
encrypt_inter_container_traffic=network_config_dict[
"EnableInterContainerTrafficEncryption"
],
security_group_ids=security_group_ids,
subnets=subnets,
)
Expand Down Expand Up @@ -784,6 +785,9 @@ def _attach(clazz, sagemaker_session, schedule_desc, job_desc, tags):
if network_config_dict:
network_config = NetworkConfig(
enable_network_isolation=network_config_dict["EnableNetworkIsolation"],
encrypt_inter_container_traffic=network_config_dict[
"EnableInterContainerTrafficEncryption"
],
security_group_ids=security_group_ids,
subnets=subnets,
)
Expand Down Expand Up @@ -1164,31 +1168,6 @@ def _wait_for_schedule_changes_to_apply(self):
if schedule_desc["MonitoringScheduleStatus"] != "Pending":
break

def _validate_network_config(self, network_config_dict):
"""Function to validate EnableInterContainerTrafficEncryption.

It validates EnableInterContainerTrafficEncryption is not set in the provided
NetworkConfig request dictionary.

Args:
network_config_dict (dict): NetworkConfig request dictionary.
Contains parameters from :class:`~sagemaker.network.NetworkConfig` object
that configures network isolation, encryption of
inter-container traffic, security group IDs, and subnets.

"""
if "EnableInterContainerTrafficEncryption" in network_config_dict:
message = (
"EnableInterContainerTrafficEncryption is not supported in Model Monitor. "
"Please ensure that encrypt_inter_container_traffic=None "
"when creating your NetworkConfig object. "
"Current encrypt_inter_container_traffic value: {}".format(
self.network_config.encrypt_inter_container_traffic
)
)
_LOGGER.info(message)
raise ValueError(message)

@classmethod
def monitoring_type(cls):
"""Type of the monitoring job."""
Expand Down Expand Up @@ -1781,7 +1760,6 @@ def update_monitoring_schedule(
network_config_dict = None
if self.network_config is not None:
network_config_dict = self.network_config._to_request_dict()
super(DefaultModelMonitor, self)._validate_network_config(network_config_dict)

if role is not None:
self.role = role
Expand Down Expand Up @@ -2034,6 +2012,9 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
subnets = vpc_config.get("Subnets")
network_config = NetworkConfig(
enable_network_isolation=network_config_dict["EnableNetworkIsolation"],
encrypt_inter_container_traffic=network_config_dict[
"EnableInterContainerTrafficEncryption"
],
security_group_ids=security_group_ids,
subnets=subnets,
)
Expand Down Expand Up @@ -2304,7 +2285,6 @@ def _build_create_data_quality_job_definition_request(

if network_config is not None:
network_config_dict = network_config._to_request_dict()
self._validate_network_config(network_config_dict)
request_dict["NetworkConfig"] = network_config_dict
elif existing_network_config is not None:
request_dict["NetworkConfig"] = existing_network_config
Expand Down Expand Up @@ -3007,7 +2987,6 @@ def _build_create_model_quality_job_definition_request(

if network_config is not None:
network_config_dict = network_config._to_request_dict()
self._validate_network_config(network_config_dict)
request_dict["NetworkConfig"] = network_config_dict
elif existing_network_config is not None:
request_dict["NetworkConfig"] = existing_network_config
Expand Down
5 changes: 4 additions & 1 deletion tests/integ/test_model_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@
TAG_KEY_1 = "tag_key_1"
TAG_VALUE_1 = "tag_value_1"
TAGS = [{"Key": TAG_KEY_1, "Value": TAG_VALUE_1}]
NETWORK_CONFIG = NetworkConfig(enable_network_isolation=True)
NETWORK_CONFIG = NetworkConfig(
enable_network_isolation=True,
encrypt_inter_container_traffic=True,
)
ENABLE_CLOUDWATCH_METRICS = True

DEFAULT_BASELINING_MAX_RUNTIME_IN_SECONDS = 86400
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
SUBNETS = ["test_subnets"]
NETWORK_CONFIG = NetworkConfig(
enable_network_isolation=False,
encrypt_inter_container_traffic=False,
security_group_ids=SECURITY_GROUP_IDS,
subnets=SUBNETS,
)
Expand Down
81 changes: 1 addition & 80 deletions tests/unit/sagemaker/monitor/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
CronExpressionGenerator,
DefaultModelMonitor,
EndpointInput,
ModelMonitor,
ModelQualityMonitor,
MonitoringOutput,
Statistics,
)

Expand All @@ -52,7 +50,7 @@
TAG_KEY_1 = "tag_key_1"
TAG_VALUE_1 = "tag_value_1"
TAGS = [{"Key": TAG_KEY_1, "Value": TAG_VALUE_1}]
NETWORK_CONFIG = NetworkConfig(enable_network_isolation=False)
NETWORK_CONFIG = NetworkConfig(enable_network_isolation=False, encrypt_inter_container_traffic=True)
ENABLE_CLOUDWATCH_METRICS = True
PROBLEM_TYPE = "Regression"
GROUND_TRUTH_ATTRIBUTE = "TestAttribute"
Expand Down Expand Up @@ -429,53 +427,6 @@ def test_default_model_monitor_suggest_baseline(sagemaker_session):
assert my_default_monitor.env[ENV_KEY_1] == ENV_VALUE_1


def test_default_model_monitor_with_invalid_network_config(sagemaker_session):
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
my_default_monitor = DefaultModelMonitor(
role=ROLE, sagemaker_session=sagemaker_session, network_config=invalid_network_config
)
with pytest.raises(ValueError) as exception:
my_default_monitor.create_monitoring_schedule(endpoint_input="test_endpoint")
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)

with pytest.raises(ValueError) as exception:
my_default_monitor.update_monitoring_schedule()
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)


def test_model_monitor_without_network_config(sagemaker_session):
my_model_monitor = ModelMonitor(
role=ROLE,
image_uri=CUSTOM_IMAGE_URI,
sagemaker_session=sagemaker_session,
)
model_monitor_schedule_name = "model-monitoring-without-network-config"
attached = my_model_monitor.attach(model_monitor_schedule_name, sagemaker_session)
assert attached.network_config is None


def test_model_monitor_with_invalid_network_config(sagemaker_session):
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
my_model_monitor = ModelMonitor(
role=ROLE,
image_uri=CUSTOM_IMAGE_URI,
sagemaker_session=sagemaker_session,
network_config=invalid_network_config,
)
with pytest.raises(ValueError) as exception:
my_model_monitor.create_monitoring_schedule(
endpoint_input="test_endpoint",
output=MonitoringOutput(
source="/opt/ml/processing/output", destination="/opt/ml/processing/output"
),
)
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)

with pytest.raises(ValueError) as exception:
my_model_monitor.update_monitoring_schedule()
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)


def test_data_quality_monitor_suggest_baseline(sagemaker_session, data_quality_monitor):
data_quality_monitor.suggest_baseline(
baseline_dataset=BASELINE_DATASET_PATH,
Expand Down Expand Up @@ -639,20 +590,6 @@ def test_data_quality_monitor_update_failure(data_quality_monitor, sagemaker_ses
data_quality_monitor.update_monitoring_schedule()


def test_data_quality_monitor_with_invalid_network_config(sagemaker_session):
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
data_quality_monitor = DefaultModelMonitor(
role=ROLE,
sagemaker_session=sagemaker_session,
network_config=invalid_network_config,
)
with pytest.raises(ValueError) as exception:
data_quality_monitor.create_monitoring_schedule(
endpoint_input="test_endpoint",
)
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)


def _test_data_quality_monitor_create_schedule(
data_quality_monitor,
sagemaker_session,
Expand Down Expand Up @@ -1053,22 +990,6 @@ def test_model_quality_monitor_update_failure(model_quality_monitor, sagemaker_s
model_quality_monitor.update_monitoring_schedule()


def test_model_quality_monitor_with_invalid_network_config(sagemaker_session):
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
model_quality_monitor = ModelQualityMonitor(
role=ROLE,
sagemaker_session=sagemaker_session,
network_config=invalid_network_config,
)
with pytest.raises(ValueError) as exception:
model_quality_monitor.create_monitoring_schedule(
endpoint_input="test_endpoint",
problem_type=PROBLEM_TYPE,
ground_truth_input=GROUND_TRUTH_S3_URI,
)
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)


def _test_model_quality_monitor_create_schedule(
model_quality_monitor,
sagemaker_session,
Expand Down