Skip to content

Commit 8643204

Browse files
authored
change: add sagemaker_session parameter to DataCaptureConfig (#1313)
1 parent 60ada5b commit 8643204

File tree

5 files changed

+32
-9
lines changed

5 files changed

+32
-9
lines changed

src/sagemaker/model_monitor/data_capture_config.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
capture_options=None,
4242
csv_content_types=None,
4343
json_content_types=None,
44+
sagemaker_session=None,
4445
):
4546
"""Initialize a DataCaptureConfig object for capturing data from Amazon SageMaker Endpoints.
4647
@@ -56,14 +57,21 @@ def __init__(
5657
which data to capture between request and response.
5758
csv_content_types ([str]): Optional. Default=["text/csv"].
5859
json_content_types([str]): Optional. Default=["application/json"].
59-
60+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
61+
object, used for SageMaker interactions (default: None). If not
62+
specified, one is created using the default AWS configuration
63+
chain.
6064
"""
6165
self.enable_capture = enable_capture
6266
self.sampling_percentage = sampling_percentage
6367
self.destination_s3_uri = destination_s3_uri
6468
if self.destination_s3_uri is None:
69+
sagemaker_session = sagemaker_session or Session()
6570
self.destination_s3_uri = os.path.join(
66-
"s3://", Session().default_bucket(), _MODEL_MONITOR_S3_PATH, _DATA_CAPTURE_S3_PATH
71+
"s3://",
72+
sagemaker_session.default_bucket(),
73+
_MODEL_MONITOR_S3_PATH,
74+
_DATA_CAPTURE_S3_PATH,
6775
)
6876

6977
self.kms_key_id = kms_key_id

src/sagemaker/predictor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,14 +192,22 @@ def enable_data_capture(self):
192192
to enable data capture. For a more customized experience, refer to
193193
update_data_capture_config, instead.
194194
"""
195-
self.update_data_capture_config(data_capture_config=DataCaptureConfig(enable_capture=True))
195+
self.update_data_capture_config(
196+
data_capture_config=DataCaptureConfig(
197+
enable_capture=True, sagemaker_session=self.sagemaker_session
198+
)
199+
)
196200

197201
def disable_data_capture(self):
198202
"""Updates the DataCaptureConfig for the Predictor's associated Amazon SageMaker Endpoint
199203
to disable data capture. For a more customized experience, refer to
200204
update_data_capture_config, instead.
201205
"""
202-
self.update_data_capture_config(data_capture_config=DataCaptureConfig(enable_capture=False))
206+
self.update_data_capture_config(
207+
data_capture_config=DataCaptureConfig(
208+
enable_capture=False, sagemaker_session=self.sagemaker_session
209+
)
210+
)
203211

204212
def update_data_capture_config(self, data_capture_config):
205213
"""Updates the DataCaptureConfig for the Predictor's associated Amazon SageMaker Endpoint

tests/integ/test_data_capture_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
126126
capture_options=CUSTOM_CAPTURE_OPTIONS,
127127
csv_content_types=CUSTOM_CSV_CONTENT_TYPES,
128128
json_content_types=CUSTOM_JSON_CONTENT_TYPES,
129+
sagemaker_session=sagemaker_session,
129130
),
130131
)
131132

@@ -224,6 +225,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
224225
capture_options=CUSTOM_CAPTURE_OPTIONS,
225226
csv_content_types=CUSTOM_CSV_CONTENT_TYPES,
226227
json_content_types=CUSTOM_JSON_CONTENT_TYPES,
228+
sagemaker_session=sagemaker_session,
227229
)
228230
)
229231

tests/integ/test_model_monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def predictor(sagemaker_session, tf_full_version):
107107
INSTANCE_COUNT,
108108
INSTANCE_TYPE,
109109
endpoint_name=endpoint_name,
110-
data_capture_config=DataCaptureConfig(True),
110+
data_capture_config=DataCaptureConfig(True, sagemaker_session=sagemaker_session),
111111
)
112112
yield predictor
113113

tests/unit/sagemaker/monitor/test_data_capture_config.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
from mock import Mock
16+
1517
from sagemaker.model_monitor import DataCaptureConfig
1618

1719
DEFAULT_ENABLE_CAPTURE = True
1820
DEFAULT_SAMPLING_PERCENTAGE = 20
1921
DEFAULT_BUCKET_NAME = "default-bucket"
20-
DEFAULT_DESTINATION_S3_URI = "s3://" + DEFAULT_BUCKET_NAME + "/model-monitor/data-capture"
22+
DEFAULT_DESTINATION_S3_URI = "s3://{}/model-monitor/data-capture".format(DEFAULT_BUCKET_NAME)
2123
DEFAULT_KMS_KEY_ID = None
2224
DEFAULT_CAPTURE_MODES = ["REQUEST", "RESPONSE"]
2325
DEFAULT_CSV_CONTENT_TYPES = ["text/csv"]
@@ -33,7 +35,7 @@
3335
NON_DEFAULT_JSON_CONTENT_TYPES = ["custom/json-format"]
3436

3537

36-
def test_to_request_dict_returns_correct_params_when_non_defaults_provided():
38+
def test_init_when_non_defaults_provided():
3739
data_capture_config = DataCaptureConfig(
3840
enable_capture=NON_DEFAULT_ENABLE_CAPTURE,
3941
sampling_percentage=NON_DEFAULT_SAMPLING_PERCENTAGE,
@@ -51,9 +53,12 @@ def test_to_request_dict_returns_correct_params_when_non_defaults_provided():
5153
assert data_capture_config.json_content_types == NON_DEFAULT_JSON_CONTENT_TYPES
5254

5355

54-
def test_to_request_dict_returns_correct_default_params_when_optionals_not_provided():
56+
def test_init_when_optionals_not_provided():
57+
sagemaker_session = Mock()
58+
sagemaker_session.default_bucket.return_value = DEFAULT_BUCKET_NAME
59+
5560
data_capture_config = DataCaptureConfig(
56-
enable_capture=DEFAULT_ENABLE_CAPTURE, destination_s3_uri=DEFAULT_DESTINATION_S3_URI
61+
enable_capture=DEFAULT_ENABLE_CAPTURE, sagemaker_session=sagemaker_session
5762
)
5863

5964
assert data_capture_config.enable_capture == DEFAULT_ENABLE_CAPTURE

0 commit comments

Comments
 (0)