Skip to content

Commit f14d70a

Browse files
authored
change: use sagemaker_session when initializing Constraints and Statistics (#1314)
1 parent 8643204 commit f14d70a

File tree

4 files changed

+96
-42
lines changed

4 files changed

+96
-42
lines changed

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def create_monitoring_schedule(
278278
normalized_monitoring_output = self._normalize_monitoring_output(output=output)
279279

280280
statistics_object, constraints_object = self._get_baseline_files(
281-
statistics=statistics, constraints=constraints
281+
statistics=statistics, constraints=constraints, sagemaker_session=self.sagemaker_session
282282
)
283283

284284
statistics_s3_uri = None
@@ -402,7 +402,7 @@ def update_monitoring_schedule(
402402
}
403403

404404
statistics_object, constraints_object = self._get_baseline_files(
405-
statistics=statistics, constraints=constraints
405+
statistics=statistics, constraints=constraints, sagemaker_session=self.sagemaker_session
406406
)
407407

408408
statistics_s3_uri = None
@@ -781,7 +781,7 @@ def _generate_monitoring_schedule_name(self, schedule_name=None):
781781
return name_from_base(base=base_name)
782782

783783
@staticmethod
784-
def _get_baseline_files(statistics, constraints):
784+
def _get_baseline_files(statistics, constraints, sagemaker_session=None):
785785
"""Populates baseline values if possible.
786786
787787
Args:
@@ -791,6 +791,9 @@ def _get_baseline_files(statistics, constraints):
791791
constraints (sagemaker.model_monitor.Constraints or str): The constraints object or str.
792792
If none, this method will attempt to retrieve a previously baselined constraints
793793
object.
794+
sagemaker_session (sagemaker.session.Session): Session object which manages interactions
795+
with Amazon SageMaker APIs and any other AWS services needed. If not specified, one
796+
is created using the default AWS configuration chain.
794797
795798
Returns:
796799
sagemaker.model_monitor.Statistics, sagemaker.model_monitor.Constraints: The Statistics
@@ -799,9 +802,13 @@ def _get_baseline_files(statistics, constraints):
799802
800803
"""
801804
if statistics is not None and isinstance(statistics, string_types):
802-
statistics = Statistics.from_s3_uri(statistics_file_s3_uri=statistics)
805+
statistics = Statistics.from_s3_uri(
806+
statistics_file_s3_uri=statistics, sagemaker_session=sagemaker_session
807+
)
803808
if constraints is not None and isinstance(constraints, string_types):
804-
constraints = Constraints.from_s3_uri(constraints_file_s3_uri=constraints)
809+
constraints = Constraints.from_s3_uri(
810+
constraints_file_s3_uri=constraints, sagemaker_session=sagemaker_session
811+
)
805812

806813
return statistics, constraints
807814

@@ -1240,7 +1247,7 @@ def create_monitoring_schedule(
12401247
)
12411248

12421249
statistics_object, constraints_object = self._get_baseline_files(
1243-
statistics=statistics, constraints=constraints
1250+
statistics=statistics, constraints=constraints, sagemaker_session=self.sagemaker_session
12441251
)
12451252

12461253
constraints_s3_uri = None
@@ -1386,7 +1393,7 @@ def update_monitoring_schedule(
13861393
)
13871394

13881395
statistics_object, constraints_object = self._get_baseline_files(
1389-
statistics=statistics, constraints=constraints
1396+
statistics=statistics, constraints=constraints, sagemaker_session=self.sagemaker_session
13901397
)
13911398

13921399
statistics_s3_uri = None
@@ -1829,6 +1836,7 @@ def baseline_statistics(self, file_name=STATISTICS_JSON_DEFAULT_FILE_NAME, kms_k
18291836
return Statistics.from_s3_uri(
18301837
statistics_file_s3_uri=os.path.join(baselining_job_output_s3_path, file_name),
18311838
kms_key=kms_key,
1839+
sagemaker_session=self.sagemaker_session,
18321840
)
18331841
except ClientError as client_error:
18341842
if client_error.response["Error"]["Code"] == "NoSuchKey":
@@ -1866,6 +1874,7 @@ def suggested_constraints(self, file_name=CONSTRAINTS_JSON_DEFAULT_FILE_NAME, km
18661874
return Constraints.from_s3_uri(
18671875
constraints_file_s3_uri=os.path.join(baselining_job_output_s3_path, file_name),
18681876
kms_key=kms_key,
1877+
sagemaker_session=self.sagemaker_session,
18691878
)
18701879
except ClientError as client_error:
18711880
if client_error.response["Error"]["Code"] == "NoSuchKey":
@@ -1981,6 +1990,7 @@ def statistics(self, file_name=STATISTICS_JSON_DEFAULT_FILE_NAME, kms_key=None):
19811990
return Statistics.from_s3_uri(
19821991
statistics_file_s3_uri=os.path.join(baselining_job_output_s3_path, file_name),
19831992
kms_key=kms_key,
1993+
sagemaker_session=self.sagemaker_session,
19841994
)
19851995
except ClientError as client_error:
19861996
if client_error.response["Error"]["Code"] == "NoSuchKey":
@@ -2022,6 +2032,7 @@ def constraint_violations(
20222032
baselining_job_output_s3_path, file_name
20232033
),
20242034
kms_key=kms_key,
2035+
sagemaker_session=self.sagemaker_session,
20252036
)
20262037
except ClientError as client_error:
20272038
if client_error.response["Error"]["Code"] == "NoSuchKey":

src/sagemaker/model_monitor/monitoring_files.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ def save(self, new_save_location_s3_uri=None):
6969
self.file_s3_uri = new_save_location_s3_uri
7070

7171
return S3Uploader.upload_string_as_file_body(
72-
body=json.dumps(self.body_dict), desired_s3_uri=self.file_s3_uri, kms_key=self.kms_key
72+
body=json.dumps(self.body_dict),
73+
desired_s3_uri=self.file_s3_uri,
74+
kms_key=self.kms_key,
75+
session=self.session,
7376
)
7477

7578

@@ -252,7 +255,10 @@ def from_s3_uri(cls, constraints_file_s3_uri, kms_key=None, sagemaker_session=No
252255
raise error
253256

254257
return cls(
255-
body_dict=body_dict, constraints_file_s3_uri=constraints_file_s3_uri, kms_key=kms_key
258+
body_dict=body_dict,
259+
constraints_file_s3_uri=constraints_file_s3_uri,
260+
kms_key=kms_key,
261+
sagemaker_session=sagemaker_session,
256262
)
257263

258264
@classmethod

tests/integ/test_model_monitor.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,13 @@ def default_monitoring_schedule_name(sagemaker_session, output_kms_key, volume_k
136136
)
137137

138138
statistics = Statistics.from_file_path(
139-
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
139+
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
140+
sagemaker_session=sagemaker_session,
140141
)
141142

142143
constraints = Constraints.from_file_path(
143-
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
144+
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
145+
sagemaker_session=sagemaker_session,
144146
)
145147

146148
my_default_monitor.create_monitoring_schedule(
@@ -194,11 +196,13 @@ def byoc_monitoring_schedule_name(sagemaker_session, output_kms_key, volume_kms_
194196
)
195197

196198
statistics = Statistics.from_file_path(
197-
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
199+
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
200+
sagemaker_session=sagemaker_session,
198201
)
199202

200203
constraints = Constraints.from_file_path(
201-
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
204+
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
205+
sagemaker_session=sagemaker_session,
202206
)
203207

204208
my_byoc_monitor.create_monitoring_schedule(
@@ -676,11 +680,13 @@ def test_default_monitor_create_stop_and_start_monitoring_schedule_with_customiz
676680
)
677681

678682
statistics = Statistics.from_file_path(
679-
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
683+
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
684+
sagemaker_session=sagemaker_session,
680685
)
681686

682687
constraints = Constraints.from_file_path(
683-
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
688+
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
689+
sagemaker_session=sagemaker_session,
684690
)
685691

686692
my_default_monitor.create_monitoring_schedule(
@@ -844,11 +850,13 @@ def test_default_monitor_create_and_update_schedule_config_with_customizations(
844850
)
845851

846852
statistics = Statistics.from_file_path(
847-
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
853+
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
854+
sagemaker_session=sagemaker_session,
848855
)
849856

850857
constraints = Constraints.from_file_path(
851-
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
858+
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
859+
sagemaker_session=sagemaker_session,
852860
)
853861

854862
my_default_monitor.create_monitoring_schedule(
@@ -958,11 +966,13 @@ def test_default_monitor_create_and_update_schedule_config_with_customizations(
958966
)
959967

960968
statistics = Statistics.from_file_path(
961-
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
969+
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
970+
sagemaker_session=sagemaker_session,
962971
)
963972

964973
constraints = Constraints.from_file_path(
965-
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
974+
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
975+
sagemaker_session=sagemaker_session,
966976
)
967977

968978
_wait_for_schedule_changes_to_apply(monitor=my_default_monitor)
@@ -1338,11 +1348,13 @@ def test_default_monitor_attach_followed_by_baseline_and_update_monitoring_sched
13381348
)
13391349

13401350
statistics = Statistics.from_file_path(
1341-
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
1351+
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
1352+
sagemaker_session=sagemaker_session,
13421353
)
13431354

13441355
constraints = Constraints.from_file_path(
1345-
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
1356+
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
1357+
sagemaker_session=sagemaker_session,
13461358
)
13471359

13481360
_wait_for_schedule_changes_to_apply(my_attached_monitor)
@@ -1968,11 +1980,13 @@ def test_byoc_monitor_create_and_update_schedule_config_with_customizations(
19681980
)
19691981

19701982
statistics = Statistics.from_file_path(
1971-
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
1983+
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
1984+
sagemaker_session=sagemaker_session,
19721985
)
19731986

19741987
constraints = Constraints.from_file_path(
1975-
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
1988+
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
1989+
sagemaker_session=sagemaker_session,
19761990
)
19771991

19781992
my_byoc_monitor.create_monitoring_schedule(

tests/integ/test_monitoring_files.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ def test_statistics_object_creation_from_file_path_with_customizations(
4444
assert statistics.body_dict["dataset"]["item_count"] == 418
4545

4646

47-
def test_statistics_object_creation_from_file_path_without_customizations():
47+
def test_statistics_object_creation_from_file_path_without_customizations(sagemaker_session):
4848
statistics = Statistics.from_file_path(
49-
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
49+
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
50+
sagemaker_session=sagemaker_session,
5051
)
5152

5253
assert statistics.file_s3_uri.startswith("s3://")
@@ -74,11 +75,13 @@ def test_statistics_object_creation_from_string_with_customizations(
7475
assert statistics.body_dict["dataset"]["item_count"] == 418
7576

7677

77-
def test_statistics_object_creation_from_string_without_customizations():
78+
def test_statistics_object_creation_from_string_without_customizations(sagemaker_session):
7879
with open(os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"), "r") as f:
7980
file_body = f.read()
8081

81-
statistics = Statistics.from_string(statistics_file_string=file_body)
82+
statistics = Statistics.from_string(
83+
statistics_file_string=file_body, sagemaker_session=sagemaker_session
84+
)
8285

8386
assert statistics.file_s3_uri.startswith("s3://")
8487
assert statistics.file_s3_uri.endswith("statistics.json")
@@ -133,9 +136,13 @@ def test_statistics_object_creation_from_s3_uri_without_customizations(sagemaker
133136
file_name,
134137
)
135138

136-
s3_uri = S3Uploader.upload_string_as_file_body(body=file_body, desired_s3_uri=desired_s3_uri)
139+
s3_uri = S3Uploader.upload_string_as_file_body(
140+
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
141+
)
137142

138-
statistics = Statistics.from_s3_uri(statistics_file_s3_uri=s3_uri)
143+
statistics = Statistics.from_s3_uri(
144+
statistics_file_s3_uri=s3_uri, sagemaker_session=sagemaker_session
145+
)
139146

140147
assert statistics.file_s3_uri.startswith("s3://")
141148
assert statistics.file_s3_uri.endswith("statistics.json")
@@ -181,14 +188,17 @@ def test_constraints_object_creation_from_file_path_with_customizations(
181188

182189
constraints.save()
183190

184-
new_constraints = Constraints.from_s3_uri(constraints.file_s3_uri)
191+
new_constraints = Constraints.from_s3_uri(
192+
constraints.file_s3_uri, sagemaker_session=sagemaker_session
193+
)
185194

186195
assert new_constraints.body_dict["monitoring_config"]["evaluate_constraints"] == "Disabled"
187196

188197

189-
def test_constraints_object_creation_from_file_path_without_customizations():
198+
def test_constraints_object_creation_from_file_path_without_customizations(sagemaker_session):
190199
constraints = Constraints.from_file_path(
191-
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
200+
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
201+
sagemaker_session=sagemaker_session,
192202
)
193203

194204
assert constraints.file_s3_uri.startswith("s3://")
@@ -216,11 +226,13 @@ def test_constraints_object_creation_from_string_with_customizations(
216226
assert constraints.body_dict["monitoring_config"]["evaluate_constraints"] == "Enabled"
217227

218228

219-
def test_constraints_object_creation_from_string_without_customizations():
229+
def test_constraints_object_creation_from_string_without_customizations(sagemaker_session):
220230
with open(os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"), "r") as f:
221231
file_body = f.read()
222232

223-
constraints = Constraints.from_string(constraints_file_string=file_body)
233+
constraints = Constraints.from_string(
234+
constraints_file_string=file_body, sagemaker_session=sagemaker_session
235+
)
224236

225237
assert constraints.file_s3_uri.startswith("s3://")
226238
assert constraints.file_s3_uri.endswith("constraints.json")
@@ -275,9 +287,13 @@ def test_constraints_object_creation_from_s3_uri_without_customizations(sagemake
275287
file_name,
276288
)
277289

278-
s3_uri = S3Uploader.upload_string_as_file_body(body=file_body, desired_s3_uri=desired_s3_uri)
290+
s3_uri = S3Uploader.upload_string_as_file_body(
291+
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
292+
)
279293

280-
constraints = Constraints.from_s3_uri(constraints_file_s3_uri=s3_uri)
294+
constraints = Constraints.from_s3_uri(
295+
constraints_file_s3_uri=s3_uri, sagemaker_session=sagemaker_session
296+
)
281297

282298
assert constraints.file_s3_uri.startswith("s3://")
283299
assert constraints.file_s3_uri.endswith("constraints.json")
@@ -302,11 +318,14 @@ def test_constraint_violations_object_creation_from_file_path_with_customization
302318
assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag"
303319

304320

305-
def test_constraint_violations_object_creation_from_file_path_without_customizations():
321+
def test_constraint_violations_object_creation_from_file_path_without_customizations(
322+
sagemaker_session
323+
):
306324
constraint_violations = ConstraintViolations.from_file_path(
307325
constraint_violations_file_path=os.path.join(
308326
tests.integ.DATA_DIR, "monitor/constraint_violations.json"
309-
)
327+
),
328+
sagemaker_session=sagemaker_session,
310329
)
311330

312331
assert constraint_violations.file_s3_uri.startswith("s3://")
@@ -334,12 +353,14 @@ def test_constraint_violations_object_creation_from_string_with_customizations(
334353
assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag"
335354

336355

337-
def test_constraint_violations_object_creation_from_string_without_customizations():
356+
def test_constraint_violations_object_creation_from_string_without_customizations(
357+
sagemaker_session
358+
):
338359
with open(os.path.join(tests.integ.DATA_DIR, "monitor/constraint_violations.json"), "r") as f:
339360
file_body = f.read()
340361

341362
constraint_violations = ConstraintViolations.from_string(
342-
constraint_violations_file_string=file_body
363+
constraint_violations_file_string=file_body, sagemaker_session=sagemaker_session
343364
)
344365

345366
assert constraint_violations.file_s3_uri.startswith("s3://")
@@ -397,10 +418,12 @@ def test_constraint_violations_object_creation_from_s3_uri_without_customization
397418
file_name,
398419
)
399420

400-
s3_uri = S3Uploader.upload_string_as_file_body(body=file_body, desired_s3_uri=desired_s3_uri)
421+
s3_uri = S3Uploader.upload_string_as_file_body(
422+
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
423+
)
401424

402425
constraint_violations = ConstraintViolations.from_s3_uri(
403-
constraint_violations_file_s3_uri=s3_uri
426+
constraint_violations_file_s3_uri=s3_uri, sagemaker_session=sagemaker_session
404427
)
405428

406429
assert constraint_violations.file_s3_uri.startswith("s3://")

0 commit comments

Comments
 (0)