Skip to content

change: use sagemaker_session when initializing Constraints and Statistics #1314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/sagemaker/model_monitor/model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def create_monitoring_schedule(
normalized_monitoring_output = self._normalize_monitoring_output(output=output)

statistics_object, constraints_object = self._get_baseline_files(
statistics=statistics, constraints=constraints
statistics=statistics, constraints=constraints, sagemaker_session=self.sagemaker_session
)

statistics_s3_uri = None
Expand Down Expand Up @@ -402,7 +402,7 @@ def update_monitoring_schedule(
}

statistics_object, constraints_object = self._get_baseline_files(
statistics=statistics, constraints=constraints
statistics=statistics, constraints=constraints, sagemaker_session=self.sagemaker_session
)

statistics_s3_uri = None
Expand Down Expand Up @@ -781,7 +781,7 @@ def _generate_monitoring_schedule_name(self, schedule_name=None):
return name_from_base(base=base_name)

@staticmethod
def _get_baseline_files(statistics, constraints):
def _get_baseline_files(statistics, constraints, sagemaker_session=None):
"""Populates baseline values if possible.

Args:
Expand All @@ -791,6 +791,9 @@ def _get_baseline_files(statistics, constraints):
constraints (sagemaker.model_monitor.Constraints or str): The constraints object or str.
If none, this method will attempt to retrieve a previously baselined constraints
object.
sagemaker_session (sagemaker.session.Session): Session object which manages interactions
with Amazon SageMaker APIs and any other AWS services needed. If not specified, one
is created using the default AWS configuration chain.

Returns:
sagemaker.model_monitor.Statistics, sagemaker.model_monitor.Constraints: The Statistics
Expand All @@ -799,9 +802,13 @@ def _get_baseline_files(statistics, constraints):

"""
if statistics is not None and isinstance(statistics, string_types):
statistics = Statistics.from_s3_uri(statistics_file_s3_uri=statistics)
statistics = Statistics.from_s3_uri(
statistics_file_s3_uri=statistics, sagemaker_session=sagemaker_session
)
if constraints is not None and isinstance(constraints, string_types):
constraints = Constraints.from_s3_uri(constraints_file_s3_uri=constraints)
constraints = Constraints.from_s3_uri(
constraints_file_s3_uri=constraints, sagemaker_session=sagemaker_session
)

return statistics, constraints

Expand Down Expand Up @@ -1240,7 +1247,7 @@ def create_monitoring_schedule(
)

statistics_object, constraints_object = self._get_baseline_files(
statistics=statistics, constraints=constraints
statistics=statistics, constraints=constraints, sagemaker_session=self.sagemaker_session
)

constraints_s3_uri = None
Expand Down Expand Up @@ -1386,7 +1393,7 @@ def update_monitoring_schedule(
)

statistics_object, constraints_object = self._get_baseline_files(
statistics=statistics, constraints=constraints
statistics=statistics, constraints=constraints, sagemaker_session=self.sagemaker_session
)

statistics_s3_uri = None
Expand Down Expand Up @@ -1829,6 +1836,7 @@ def baseline_statistics(self, file_name=STATISTICS_JSON_DEFAULT_FILE_NAME, kms_k
return Statistics.from_s3_uri(
statistics_file_s3_uri=os.path.join(baselining_job_output_s3_path, file_name),
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
)
except ClientError as client_error:
if client_error.response["Error"]["Code"] == "NoSuchKey":
Expand Down Expand Up @@ -1866,6 +1874,7 @@ def suggested_constraints(self, file_name=CONSTRAINTS_JSON_DEFAULT_FILE_NAME, km
return Constraints.from_s3_uri(
constraints_file_s3_uri=os.path.join(baselining_job_output_s3_path, file_name),
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
)
except ClientError as client_error:
if client_error.response["Error"]["Code"] == "NoSuchKey":
Expand Down Expand Up @@ -1981,6 +1990,7 @@ def statistics(self, file_name=STATISTICS_JSON_DEFAULT_FILE_NAME, kms_key=None):
return Statistics.from_s3_uri(
statistics_file_s3_uri=os.path.join(baselining_job_output_s3_path, file_name),
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
)
except ClientError as client_error:
if client_error.response["Error"]["Code"] == "NoSuchKey":
Expand Down Expand Up @@ -2022,6 +2032,7 @@ def constraint_violations(
baselining_job_output_s3_path, file_name
),
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
)
except ClientError as client_error:
if client_error.response["Error"]["Code"] == "NoSuchKey":
Expand Down
10 changes: 8 additions & 2 deletions src/sagemaker/model_monitor/monitoring_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ def save(self, new_save_location_s3_uri=None):
self.file_s3_uri = new_save_location_s3_uri

return S3Uploader.upload_string_as_file_body(
body=json.dumps(self.body_dict), desired_s3_uri=self.file_s3_uri, kms_key=self.kms_key
body=json.dumps(self.body_dict),
desired_s3_uri=self.file_s3_uri,
kms_key=self.kms_key,
session=self.session,
)


Expand Down Expand Up @@ -252,7 +255,10 @@ def from_s3_uri(cls, constraints_file_s3_uri, kms_key=None, sagemaker_session=No
raise error

return cls(
body_dict=body_dict, constraints_file_s3_uri=constraints_file_s3_uri, kms_key=kms_key
body_dict=body_dict,
constraints_file_s3_uri=constraints_file_s3_uri,
kms_key=kms_key,
sagemaker_session=sagemaker_session,
)

@classmethod
Expand Down
42 changes: 28 additions & 14 deletions tests/integ/test_model_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,13 @@ def default_monitoring_schedule_name(sagemaker_session, output_kms_key, volume_k
)

statistics = Statistics.from_file_path(
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
sagemaker_session=sagemaker_session,
)

constraints = Constraints.from_file_path(
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
sagemaker_session=sagemaker_session,
)

my_default_monitor.create_monitoring_schedule(
Expand Down Expand Up @@ -194,11 +196,13 @@ def byoc_monitoring_schedule_name(sagemaker_session, output_kms_key, volume_kms_
)

statistics = Statistics.from_file_path(
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
sagemaker_session=sagemaker_session,
)

constraints = Constraints.from_file_path(
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
sagemaker_session=sagemaker_session,
)

my_byoc_monitor.create_monitoring_schedule(
Expand Down Expand Up @@ -676,11 +680,13 @@ def test_default_monitor_create_stop_and_start_monitoring_schedule_with_customiz
)

statistics = Statistics.from_file_path(
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
sagemaker_session=sagemaker_session,
)

constraints = Constraints.from_file_path(
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
sagemaker_session=sagemaker_session,
)

my_default_monitor.create_monitoring_schedule(
Expand Down Expand Up @@ -844,11 +850,13 @@ def test_default_monitor_create_and_update_schedule_config_with_customizations(
)

statistics = Statistics.from_file_path(
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
sagemaker_session=sagemaker_session,
)

constraints = Constraints.from_file_path(
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
sagemaker_session=sagemaker_session,
)

my_default_monitor.create_monitoring_schedule(
Expand Down Expand Up @@ -958,11 +966,13 @@ def test_default_monitor_create_and_update_schedule_config_with_customizations(
)

statistics = Statistics.from_file_path(
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
sagemaker_session=sagemaker_session,
)

constraints = Constraints.from_file_path(
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
sagemaker_session=sagemaker_session,
)

_wait_for_schedule_changes_to_apply(monitor=my_default_monitor)
Expand Down Expand Up @@ -1338,11 +1348,13 @@ def test_default_monitor_attach_followed_by_baseline_and_update_monitoring_sched
)

statistics = Statistics.from_file_path(
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
sagemaker_session=sagemaker_session,
)

constraints = Constraints.from_file_path(
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
sagemaker_session=sagemaker_session,
)

_wait_for_schedule_changes_to_apply(my_attached_monitor)
Expand Down Expand Up @@ -1968,11 +1980,13 @@ def test_byoc_monitor_create_and_update_schedule_config_with_customizations(
)

statistics = Statistics.from_file_path(
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
sagemaker_session=sagemaker_session,
)

constraints = Constraints.from_file_path(
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
sagemaker_session=sagemaker_session,
)

my_byoc_monitor.create_monitoring_schedule(
Expand Down
61 changes: 42 additions & 19 deletions tests/integ/test_monitoring_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ def test_statistics_object_creation_from_file_path_with_customizations(
assert statistics.body_dict["dataset"]["item_count"] == 418


def test_statistics_object_creation_from_file_path_without_customizations():
def test_statistics_object_creation_from_file_path_without_customizations(sagemaker_session):
statistics = Statistics.from_file_path(
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json")
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
sagemaker_session=sagemaker_session,
)

assert statistics.file_s3_uri.startswith("s3://")
Expand Down Expand Up @@ -74,11 +75,13 @@ def test_statistics_object_creation_from_string_with_customizations(
assert statistics.body_dict["dataset"]["item_count"] == 418


def test_statistics_object_creation_from_string_without_customizations():
def test_statistics_object_creation_from_string_without_customizations(sagemaker_session):
with open(os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"), "r") as f:
file_body = f.read()

statistics = Statistics.from_string(statistics_file_string=file_body)
statistics = Statistics.from_string(
statistics_file_string=file_body, sagemaker_session=sagemaker_session
)

assert statistics.file_s3_uri.startswith("s3://")
assert statistics.file_s3_uri.endswith("statistics.json")
Expand Down Expand Up @@ -133,9 +136,13 @@ def test_statistics_object_creation_from_s3_uri_without_customizations(sagemaker
file_name,
)

s3_uri = S3Uploader.upload_string_as_file_body(body=file_body, desired_s3_uri=desired_s3_uri)
s3_uri = S3Uploader.upload_string_as_file_body(
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
)

statistics = Statistics.from_s3_uri(statistics_file_s3_uri=s3_uri)
statistics = Statistics.from_s3_uri(
statistics_file_s3_uri=s3_uri, sagemaker_session=sagemaker_session
)

assert statistics.file_s3_uri.startswith("s3://")
assert statistics.file_s3_uri.endswith("statistics.json")
Expand Down Expand Up @@ -181,14 +188,17 @@ def test_constraints_object_creation_from_file_path_with_customizations(

constraints.save()

new_constraints = Constraints.from_s3_uri(constraints.file_s3_uri)
new_constraints = Constraints.from_s3_uri(
constraints.file_s3_uri, sagemaker_session=sagemaker_session
)

assert new_constraints.body_dict["monitoring_config"]["evaluate_constraints"] == "Disabled"


def test_constraints_object_creation_from_file_path_without_customizations():
def test_constraints_object_creation_from_file_path_without_customizations(sagemaker_session):
constraints = Constraints.from_file_path(
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json")
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
sagemaker_session=sagemaker_session,
)

assert constraints.file_s3_uri.startswith("s3://")
Expand Down Expand Up @@ -216,11 +226,13 @@ def test_constraints_object_creation_from_string_with_customizations(
assert constraints.body_dict["monitoring_config"]["evaluate_constraints"] == "Enabled"


def test_constraints_object_creation_from_string_without_customizations():
def test_constraints_object_creation_from_string_without_customizations(sagemaker_session):
with open(os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"), "r") as f:
file_body = f.read()

constraints = Constraints.from_string(constraints_file_string=file_body)
constraints = Constraints.from_string(
constraints_file_string=file_body, sagemaker_session=sagemaker_session
)

assert constraints.file_s3_uri.startswith("s3://")
assert constraints.file_s3_uri.endswith("constraints.json")
Expand Down Expand Up @@ -275,9 +287,13 @@ def test_constraints_object_creation_from_s3_uri_without_customizations(sagemake
file_name,
)

s3_uri = S3Uploader.upload_string_as_file_body(body=file_body, desired_s3_uri=desired_s3_uri)
s3_uri = S3Uploader.upload_string_as_file_body(
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
)

constraints = Constraints.from_s3_uri(constraints_file_s3_uri=s3_uri)
constraints = Constraints.from_s3_uri(
constraints_file_s3_uri=s3_uri, sagemaker_session=sagemaker_session
)

assert constraints.file_s3_uri.startswith("s3://")
assert constraints.file_s3_uri.endswith("constraints.json")
Expand All @@ -302,11 +318,14 @@ def test_constraint_violations_object_creation_from_file_path_with_customization
assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag"


def test_constraint_violations_object_creation_from_file_path_without_customizations():
def test_constraint_violations_object_creation_from_file_path_without_customizations(
sagemaker_session
):
constraint_violations = ConstraintViolations.from_file_path(
constraint_violations_file_path=os.path.join(
tests.integ.DATA_DIR, "monitor/constraint_violations.json"
)
),
sagemaker_session=sagemaker_session,
)

assert constraint_violations.file_s3_uri.startswith("s3://")
Expand Down Expand Up @@ -334,12 +353,14 @@ def test_constraint_violations_object_creation_from_string_with_customizations(
assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag"


def test_constraint_violations_object_creation_from_string_without_customizations():
def test_constraint_violations_object_creation_from_string_without_customizations(
sagemaker_session
):
with open(os.path.join(tests.integ.DATA_DIR, "monitor/constraint_violations.json"), "r") as f:
file_body = f.read()

constraint_violations = ConstraintViolations.from_string(
constraint_violations_file_string=file_body
constraint_violations_file_string=file_body, sagemaker_session=sagemaker_session
)

assert constraint_violations.file_s3_uri.startswith("s3://")
Expand Down Expand Up @@ -397,10 +418,12 @@ def test_constraint_violations_object_creation_from_s3_uri_without_customization
file_name,
)

s3_uri = S3Uploader.upload_string_as_file_body(body=file_body, desired_s3_uri=desired_s3_uri)
s3_uri = S3Uploader.upload_string_as_file_body(
body=file_body, desired_s3_uri=desired_s3_uri, session=sagemaker_session
)

constraint_violations = ConstraintViolations.from_s3_uri(
constraint_violations_file_s3_uri=s3_uri
constraint_violations_file_s3_uri=s3_uri, sagemaker_session=sagemaker_session
)

assert constraint_violations.file_s3_uri.startswith("s3://")
Expand Down