Skip to content

Commit 2b812c8

Browse files
chuyang-dengChuyang Dengicywang86ruiajaykarpur
authored
fix: check optional keyword before accessing (#1911)
* fix: check optional keyword before accessing * refactor attach method Co-authored-by: Chuyang Deng <[email protected]> Co-authored-by: icywang86rui <[email protected]> Co-authored-by: Ajay Karpur <[email protected]>
1 parent 23d2b7f commit 2b812c8

File tree

2 files changed

+62
-36
lines changed

2 files changed

+62
-36
lines changed

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -641,46 +641,32 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
641641
monitoring_schedule_name=monitor_schedule_name
642642
)
643643

644-
role = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"]["RoleArn"]
645-
image_uri = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
646-
"MonitoringAppSpecification"
647-
]["ImageUri"]
648-
instance_count = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
649-
"MonitoringResources"
650-
]["ClusterConfig"]["InstanceCount"]
651-
instance_type = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
652-
"MonitoringResources"
653-
]["ClusterConfig"]["InstanceType"]
654-
entrypoint = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
655-
"MonitoringAppSpecification"
656-
].get("ContainerEntrypoint")
657-
volume_size_in_gb = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
658-
"MonitoringResources"
659-
]["ClusterConfig"]["VolumeSizeInGB"]
660-
volume_kms_key = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
661-
"MonitoringResources"
662-
]["ClusterConfig"].get("VolumeKmsKeyId")
663-
output_kms_key = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
664-
"MonitoringOutputConfig"
665-
].get("KmsKeyId")
644+
monitoring_job_definition = schedule_desc["MonitoringScheduleConfig"][
645+
"MonitoringJobDefinition"
646+
]
647+
role = monitoring_job_definition["RoleArn"]
648+
image_uri = monitoring_job_definition["MonitoringAppSpecification"].get("ImageUri")
649+
cluster_config = monitoring_job_definition["MonitoringResources"]["ClusterConfig"]
650+
instance_count = cluster_config.get("InstanceCount")
651+
instance_type = cluster_config["InstanceType"]
652+
volume_size_in_gb = cluster_config["VolumeSizeInGB"]
653+
volume_kms_key = cluster_config.get("VolumeKmsKeyId")
654+
entrypoint = monitoring_job_definition["MonitoringAppSpecification"].get(
655+
"ContainerEntrypoint"
656+
)
657+
output_kms_key = monitoring_job_definition["MonitoringOutputConfig"].get("KmsKeyId")
658+
network_config_dict = monitoring_job_definition.get("NetworkConfig")
666659

667660
max_runtime_in_seconds = None
668-
if schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"].get(
669-
"StoppingCondition"
670-
):
671-
max_runtime_in_seconds = schedule_desc["MonitoringScheduleConfig"][
672-
"MonitoringJobDefinition"
673-
]["StoppingCondition"].get("MaxRuntimeInSeconds")
661+
stopping_condition = monitoring_job_definition.get("StoppingCondition")
662+
if stopping_condition:
663+
max_runtime_in_seconds = stopping_condition.get("MaxRuntimeInSeconds")
674664

675-
env = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"]["Environment"]
676-
677-
network_config_dict = schedule_desc["MonitoringScheduleConfig"][
678-
"MonitoringJobDefinition"
679-
].get("NetworkConfig")
665+
env = monitoring_job_definition.get("Environment", None)
680666

681-
vpc_config = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
682-
"NetworkConfig"
683-
].get("VpcConfig")
667+
vpc_config = None
668+
if network_config_dict:
669+
vpc_config = network_config_dict.get("VpcConfig")
684670

685671
security_group_ids = None
686672
if vpc_config is not None:
@@ -690,6 +676,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
690676
if vpc_config is not None:
691677
subnets = vpc_config["Subnets"]
692678

679+
network_config = None
693680
if network_config_dict:
694681
network_config = NetworkConfig(
695682
enable_network_isolation=network_config_dict["EnableNetworkIsolation"],

tests/unit/sagemaker/monitor/test_model_monitoring.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,31 @@
6363
)
6464
"encrypt_inter_container_traffic=None when creating your NetworkConfig object."
6565

66+
MONITORING_SCHEDULE_DESC = {
67+
"MonitoringScheduleArn": "arn:aws:monitoring-schedule",
68+
"MonitoringScheduleName": "my-monitoring-schedule",
69+
"MonitoringScheduleConfig": {
70+
"MonitoringJobDefinition": {
71+
"MonitoringOutputConfig": {},
72+
"MonitoringResources": {
73+
"ClusterConfig": {
74+
"InstanceCount": 1,
75+
"InstanceType": "ml.t3.medium",
76+
"VolumeSizeInGB": 8,
77+
}
78+
},
79+
"MonitoringAppSpecification": {
80+
"ImageUri": "image-uri",
81+
"ContainerEntrypoint": [
82+
"entrypoint.py",
83+
],
84+
},
85+
"RoleArn": ROLE,
86+
}
87+
},
88+
"EndpointName": "my-endpoint",
89+
}
90+
6691

6792
# TODO-reinvent-2019: Continue to flesh these out.
6893
@pytest.fixture()
@@ -80,6 +105,9 @@ def sagemaker_session():
80105
name="upload_data", return_value="mocked_s3_uri_from_upload_data"
81106
)
82107
session_mock.download_data = Mock(name="download_data")
108+
session_mock.describe_monitoring_schedule = Mock(
109+
name="describe_monitoring_schedule", return_value=MONITORING_SCHEDULE_DESC
110+
)
83111
return session_mock
84112

85113

@@ -153,6 +181,17 @@ def test_default_model_monitor_with_invalid_network_config(sagemaker_session):
153181
assert INTER_CONTAINER_ENCRYPTION_EXCEPTION_MSG in str(exception.value)
154182

155183

184+
def test_model_monitor_without_network_config(sagemaker_session):
185+
my_model_monitor = ModelMonitor(
186+
role=ROLE,
187+
image_uri=CUSTOM_IMAGE_URI,
188+
sagemaker_session=sagemaker_session,
189+
)
190+
model_monitor_schedule_name = "model-monitoring-without-network-config"
191+
attached = my_model_monitor.attach(model_monitor_schedule_name, sagemaker_session)
192+
assert attached.network_config is None
193+
194+
156195
def test_model_monitor_with_invalid_network_config(sagemaker_session):
157196
invalid_network_config = NetworkConfig(encrypt_inter_container_traffic=False)
158197
my_model_monitor = ModelMonitor(

0 commit comments

Comments
 (0)