Skip to content

Commit 5aefca0

Browse files
authored
feature: Amazon SageMaker Model Monitoring
tl;dr: Added Amazon SageMaker Model Monitoring feature support for sagemaker-python-sdk. Specifically: * Added ModelMonitor and DefaultModelMonitor objects capable of running BaseliningJobs and create/updating schedules. These classes have many convenience functions to suggest baseline statistics and constraints, view monitor-generated statistics and violations, and many more. * Customers are able to pull data generated by these jobs and analyze/modify them. * Customers can attach Monitor objects to existing schedules in order to update them or review the data they generated. * Customers can enable Data Capture on their existing or new endpoints to allow for model monitoring. This can be done by utilizing a DataCaptureConfig object, or by called .enable_data_capture() on a predictor object. * Customers can specify a DatasetFormat for their monitoring schedules utilizing a DatasetFormat object. * Added a CronExpressionGenerator to generate cron expressions supported by the feature. * Added an output capturer as part of the unit tests, in order to test print statements to the customer. * Added a convenience method to upload a string to S3: S3Uploader.upload_string_as_file_body(my_string). * Added a convenience method to read an S3 file into a python string object: S3Downloader.read_file(my_uri). * Added method to create new endpoint_config from existing endpoint_config to streamline the process of updating endpoints. * Added schedule cleanup to endpoint cleanup to allow for cleaning up endpoints that have schedules attached. * Imports greatly simplified for customers to allow them to import all the modules they need from a single namespace: "from sagemaker.model_monitor import *". Note that I do not condone the use of import * in a code-base that is meant to be maintained.
1 parent 140bcf9 commit 5aefca0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+8599
-121
lines changed

buildspec.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ phases:
1919
- start_time=`date +%s`
2020
- |
2121
if has-matching-changes "tests/" "src/*.py" "setup.py" "setup.cfg" "buildspec.yml"; then
22-
tox -e py36 -- tests/integ -m "not local_mode" -n 48 --reruns 3 --reruns-delay 5 --durations 50 --boto-config '{"region_name": "us-east-1"}'
22+
tox -e py36 -- tests/integ -m "not local_mode" -n 48 --reruns 3 --reruns-delay 5 --durations 50 --boto-config '{"region_name": "us-east-2"}'
2323
fi
2424
- ./ci-scripts/displaytime.sh 'py36 tests/integ' $start_time
2525

doc/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def __getattr__(cls, name):
4141
"tensorflow.python.framework",
4242
"tensorflow_serving",
4343
"tensorflow_serving.apis",
44-
"numpy",
4544
"scipy",
4645
"scipy.sparse",
4746
]

src/sagemaker/estimator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ def deploy(
557557
wait=True,
558558
model_name=None,
559559
kms_key=None,
560+
data_capture_config=None,
560561
**kwargs
561562
):
562563
"""Deploy the trained model to an Amazon SageMaker endpoint and return a
@@ -599,6 +600,9 @@ def deploy(
599600
kms_key (str): The ARN of the KMS key that is used to encrypt the
600601
data on the storage volume attached to the instance hosting the
601602
endpoint.
603+
data_capture_config (DataCaptureConfig): Specifies configuration
604+
related to Endpoint data capture for use with
605+
Amazon SageMaker Model Monitoring. Default: None.
602606
**kwargs: Passed to invocation of ``create_model()``.
603607
Implementations may customize ``create_model()`` to accept
604608
``**kwargs`` to customize model creation during deploy.
@@ -624,7 +628,9 @@ def deploy(
624628
else:
625629
kwargs["model_kms_key"] = self.output_kms_key
626630
model = self.create_model(**kwargs)
631+
627632
model.name = model_name
633+
628634
return model.deploy(
629635
instance_type=instance_type,
630636
initial_instance_count=initial_instance_count,
@@ -634,6 +640,7 @@ def deploy(
634640
tags=self.tags,
635641
wait=wait,
636642
kms_key=kms_key,
643+
data_capture_config=data_capture_config,
637644
)
638645

639646
@property

src/sagemaker/model.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def deploy(
384384
tags=None,
385385
kms_key=None,
386386
wait=True,
387+
data_capture_config=None,
387388
):
388389
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a
389390
``Predictor``.
@@ -424,6 +425,9 @@ def deploy(
424425
endpoint.
425426
wait (bool): Whether the call should wait until the deployment of
426427
this model completes (default: True).
428+
data_capture_config (DataCaptureConfig): Specifies configuration
429+
related to Endpoint data capture for use with
430+
Amazon SageMaker Model Monitoring. Default: None.
427431
428432
Returns:
429433
callable[string, sagemaker.session.Session] or None: Invocation of
@@ -454,6 +458,10 @@ def deploy(
454458
if self._is_compiled_model and not self.endpoint_name.endswith(compiled_model_suffix):
455459
self.endpoint_name += compiled_model_suffix
456460

461+
data_capture_config_dict = (
462+
data_capture_config.to_request_dict() if data_capture_config else None
463+
)
464+
457465
if update_endpoint:
458466
endpoint_config_name = self.sagemaker_session.create_endpoint_config(
459467
name=self.name,
@@ -463,11 +471,17 @@ def deploy(
463471
accelerator_type=accelerator_type,
464472
tags=tags,
465473
kms_key=kms_key,
474+
data_capture_config_dict=data_capture_config_dict,
466475
)
467476
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
468477
else:
469478
self.sagemaker_session.endpoint_from_production_variants(
470-
self.endpoint_name, [production_variant], tags, kms_key, wait
479+
name=self.endpoint_name,
480+
production_variants=[production_variant],
481+
tags=tags,
482+
kms_key=kms_key,
483+
wait=wait,
484+
data_capture_config_dict=data_capture_config_dict,
471485
)
472486

473487
if self.predictor_cls:
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Imports the classes in this module to simplify customer imports
14+
15+
Example:
16+
>>> from sagemaker.model_monitor import ModelMonitor
17+
18+
"""
19+
from __future__ import absolute_import
20+
21+
from sagemaker.model_monitor.model_monitoring import ModelMonitor # noqa: F401
22+
from sagemaker.model_monitor.model_monitoring import DefaultModelMonitor # noqa: F401
23+
from sagemaker.model_monitor.model_monitoring import MonitoringOutput # noqa: F401
24+
25+
from sagemaker.model_monitor.cron_expression_generator import CronExpressionGenerator # noqa: F401
26+
from sagemaker.model_monitor.monitoring_files import Statistics # noqa: F401
27+
from sagemaker.model_monitor.monitoring_files import Constraints # noqa: F401
28+
from sagemaker.model_monitor.monitoring_files import ConstraintViolations # noqa: F401
29+
30+
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig # noqa: F401
31+
32+
from sagemaker.network import NetworkConfig # noqa: F401
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code related to the CronExpressionGenerator class, which is used
14+
for generating cron expressions compatible with Amazon SageMaker Model Monitoring Schedules.
15+
"""
16+
from __future__ import print_function, absolute_import
17+
18+
19+
class CronExpressionGenerator(object):
20+
"""Generates cron expression strings for use with the Amazon SageMaker Model Monitoring Schedule
21+
API.
22+
"""
23+
24+
@staticmethod
25+
def hourly():
26+
"""Generates hourly cron expression that denotes that a job runs at the top of every hour.
27+
28+
Returns:
29+
str: The cron expression format accepted by the Amazon SageMaker Model Monitoring
30+
Schedule API.
31+
32+
"""
33+
return "cron(0 * ? * * *)"
34+
35+
@staticmethod
36+
def daily(hour=0):
37+
"""Generates daily cron expression that denotes that a job runs at the top of every hour.
38+
39+
Args:
40+
hour (int): The hour in HH24 format (UTC) to run the job at, on a daily schedule.
41+
Examples:
42+
- 00
43+
- 12
44+
- 17
45+
- 23
46+
47+
Returns:
48+
str: The cron expression format accepted by the Amazon SageMaker Model Monitoring
49+
Schedule API.
50+
51+
"""
52+
return "cron(0 {} ? * * *)".format(hour)
53+
54+
@staticmethod
55+
def daily_every_x_hours(hour_interval, starting_hour=0):
56+
"""Generates "daily every x hours" cron expression that denotes that a job runs every day
57+
at the specified hour, and then every x hours, as specified in hour_interval.
58+
59+
Example:
60+
>>> daily_every_x_hours(hour_interval=2, starting_hour=0)
61+
This will run every 2 hours starting at midnight.
62+
63+
>>> daily_every_x_hours(hour_interval=10, starting_hour=0)
64+
This will run at midnight, 10am, and 8pm every day.
65+
66+
Args:
67+
hour_interval (int): The hour interval to run the job at.
68+
starting_hour (int): The hour at which to begin in HH24 format (UTC).
69+
70+
Returns:
71+
str: The cron expression format accepted by the Amazon SageMaker Model Monitoring
72+
Schedule API.
73+
74+
"""
75+
return "cron(0 {}/{} ? * * *)".format(starting_hour, hour_interval)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code related to the DataCaptureConfig class, which is used
14+
for configuring capture, collection, and storage, for prediction requests and responses
15+
for models hosted on SageMaker Endpoints.
16+
"""
17+
from __future__ import print_function, absolute_import
18+
19+
import os
20+
21+
from sagemaker.session import Session
22+
23+
_MODEL_MONITOR_S3_PATH = "model-monitor"
24+
_DATA_CAPTURE_S3_PATH = "data-capture"
25+
26+
27+
class DataCaptureConfig(object):
28+
"""Configuration object passed in when deploying models to Amazon SageMaker Endpoints.
29+
This object specifies configuration related to endpoint data capture for use with
30+
Amazon SageMaker Model Monitoring.
31+
"""
32+
33+
API_MAPPING = {"REQUEST": "Input", "RESPONSE": "Output"}
34+
35+
def __init__(
36+
self,
37+
enable_capture,
38+
sampling_percentage=20,
39+
destination_s3_uri=None,
40+
kms_key_id=None,
41+
capture_options=None,
42+
csv_content_types=None,
43+
json_content_types=None,
44+
):
45+
"""Initialize a DataCaptureConfig object for capturing data from Amazon SageMaker Endpoints.
46+
47+
Args:
48+
enable_capture (bool): Required. Whether data capture should be enabled or not.
49+
sampling_percentage (int): Optional. Default=20. The percentage of data to sample.
50+
Must be between 0 and 100.
51+
destination_s3_uri (str): Optional. Defaults to "s3://<default-session-bucket>/
52+
<model-monitor>/data-capture
53+
kms_key_id (str): Optional. Default=None. The kms key to use when writing to S3.
54+
capture_options ([str]): Optional. Must be a list containing any combination of the
55+
following values: "REQUEST", "RESPONSE". Default=["REQUEST", "RESPONSE"]. Denotes
56+
which data to capture between request and response.
57+
csv_content_types ([str]): Optional. Default=["text/csv"].
58+
json_content_types([str]): Optional. Default=["application/json"].
59+
60+
"""
61+
self.enable_capture = enable_capture
62+
self.sampling_percentage = sampling_percentage
63+
self.destination_s3_uri = destination_s3_uri
64+
if self.destination_s3_uri is None:
65+
self.destination_s3_uri = os.path.join(
66+
"s3://", Session().default_bucket(), _MODEL_MONITOR_S3_PATH, _DATA_CAPTURE_S3_PATH
67+
)
68+
69+
self.kms_key_id = kms_key_id
70+
self.capture_options = capture_options or ["REQUEST", "RESPONSE"]
71+
self.csv_content_types = csv_content_types or ["text/csv"]
72+
self.json_content_types = json_content_types or ["application/json"]
73+
74+
def to_request_dict(self):
75+
"""Generates a request dictionary using the parameters provided to the class."""
76+
request_dict = {
77+
"EnableCapture": self.enable_capture,
78+
"InitialSamplingPercentage": self.sampling_percentage,
79+
"DestinationS3Uri": self.destination_s3_uri,
80+
"CaptureOptions": [
81+
{"CaptureMode": dict(self.API_MAPPING).get(capture_option.upper(), capture_option)}
82+
for capture_option in list(self.capture_options)
83+
],
84+
}
85+
86+
if self.kms_key_id is not None:
87+
request_dict["KmsKeyId"] = self.kms_key_id
88+
89+
if self.csv_content_types is not None or self.json_content_types is not None:
90+
request_dict["CaptureContentTypeHeader"] = {}
91+
92+
if self.csv_content_types is not None:
93+
request_dict["CaptureContentTypeHeader"]["CsvContentTypes"] = self.csv_content_types
94+
95+
if self.json_content_types is not None:
96+
request_dict["CaptureContentTypeHeader"]["JsonContentTypes"] = self.json_content_types
97+
98+
return request_dict
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code related to the DatasetFormat class, which is used
14+
for managing the constraints JSON file generated and consumed by Amazon SageMaker Model Monitoring
15+
Schedules.
16+
"""
17+
from __future__ import print_function, absolute_import
18+
19+
20+
class DatasetFormat(object):
21+
"""Represents a Dataset Format that is used when calling a DefaultModelMonitor.
22+
"""
23+
24+
@staticmethod
25+
def csv(header=True, output_columns_position="START"):
26+
"""Returns a DatasetFormat JSON string for use with a DefaultModelMonitor.
27+
28+
Args:
29+
header (bool): Whether the csv dataset to baseline and monitor has a header.
30+
Default: True.
31+
output_columns_position (str): The position of the output columns.
32+
Must be one of ("START", "END"). Default: "START".
33+
34+
Returns:
35+
dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor.
36+
37+
"""
38+
return {"csv": {"header": header, "output_columns_position": output_columns_position}}
39+
40+
@staticmethod
41+
def json(lines=True):
42+
"""Returns a DatasetFormat JSON string for use with a DefaultModelMonitor.
43+
44+
Args:
45+
lines (bool): Read the file as a json object per line. Default: True.
46+
47+
Returns:
48+
dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor.
49+
50+
"""
51+
return {"json": {"lines": lines}}
52+
53+
@staticmethod
54+
def sagemaker_capture_json():
55+
"""Returns a DatasetFormat SageMaker Capture Json string for use with a DefaultModelMonitor.
56+
57+
Returns:
58+
dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor.
59+
60+
"""
61+
return {"sagemaker_capture_json": {}}

0 commit comments

Comments
 (0)