Skip to content

Commit ad1a95f

Browse files
authored
Merge branch 'master' into callback-step
2 parents 6212c8b + 8478297 commit ad1a95f

File tree

5 files changed

+46
-4
lines changed

5 files changed

+46
-4
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changelog
22

3+
## v2.44.0 (2021-06-01)
4+
5+
### Features
6+
7+
* support endpoint_name_prefix, seed and version for Clarify
8+
39
## v2.43.0 (2021-05-31)
410

511
### Features

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.43.1.dev0
1+
2.44.1.dev0

src/sagemaker/clarify.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import json
2020
import os
2121
import tempfile
22-
22+
import re
2323
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
2424
from sagemaker import image_uris, s3, utils
2525

@@ -124,8 +124,9 @@ def __init__(
124124
content_template=None,
125125
custom_attributes=None,
126126
accelerator_type=None,
127+
endpoint_name_prefix=None,
127128
):
128-
"""Initializes a configuration of a model and the endpoint to be created for it.
129+
r"""Initializes a configuration of a model and the endpoint to be created for it.
129130
130131
Args:
131132
model_name (str): Model name (as created by 'CreateModel').
@@ -155,12 +156,21 @@ def __init__(
155156
accelerator_type (str): The Elastic Inference accelerator type to deploy to the model
156157
endpoint instance for making inferences to the model, see
157158
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
159+
endpoint_name_prefix (str): The endpoint name prefix of a new endpoint. Must follow
160+
pattern "^[a-zA-Z0-9](-\*[a-zA-Z0-9]".
158161
"""
159162
self.predictor_config = {
160163
"model_name": model_name,
161164
"instance_type": instance_type,
162165
"initial_instance_count": instance_count,
163166
}
167+
if endpoint_name_prefix is not None:
168+
if re.search("^[a-zA-Z0-9](-*[a-zA-Z0-9])", endpoint_name_prefix) is None:
169+
raise ValueError(
170+
"Invalid endpoint_name_prefix."
171+
" Please follow pattern ^[a-zA-Z0-9](-*[a-zA-Z0-9])."
172+
)
173+
self.predictor_config["endpoint_name_prefix"] = endpoint_name_prefix
164174
if accept_type is not None:
165175
if accept_type not in ["text/csv", "application/jsonlines"]:
166176
raise ValueError(
@@ -277,6 +287,7 @@ def __init__(
277287
agg_method,
278288
use_logit=False,
279289
save_local_shap_values=True,
290+
seed=None,
280291
):
281292
"""Initializes config for SHAP.
282293
@@ -297,6 +308,7 @@ def __init__(
297308
have log-odds units.
298309
save_local_shap_values (bool): Indicator of whether to save the local SHAP values
299310
in the output location. Default is True.
311+
seed (int): seed value to get deterministic SHAP values. Default is None.
300312
"""
301313
if agg_method not in ["mean_abs", "median", "mean_sq"]:
302314
raise ValueError(
@@ -310,6 +322,8 @@ def __init__(
310322
"use_logit": use_logit,
311323
"save_local_shap_values": save_local_shap_values,
312324
}
325+
if seed is not None:
326+
self.shap_config["seed"] = seed
313327

314328
def get_explainability_config(self):
315329
"""Returns config."""
@@ -336,6 +350,7 @@ def __init__(
336350
env=None,
337351
tags=None,
338352
network_config=None,
353+
version=None,
339354
):
340355
"""Initializes a ``Processor`` instance, computing bias metrics and model explanations.
341356
@@ -369,8 +384,9 @@ def __init__(
369384
A :class:`~sagemaker.network.NetworkConfig`
370385
object that configures network isolation, encryption of
371386
inter-container traffic, security group IDs, and subnets.
387+
version (str): Clarify version want to be used.
372388
"""
373-
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name)
389+
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
374390
super(SageMakerClarifyProcessor, self).__init__(
375391
role,
376392
container_uri,

tests/integ/test_clarify.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def model_config(model_name):
144144
instance_type="ml.c5.xlarge",
145145
instance_count=1,
146146
accept_type="application/jsonlines",
147+
endpoint_name_prefix="myprefix",
147148
)
148149

149150

@@ -172,6 +173,7 @@ def shap_config():
172173
],
173174
num_samples=2,
174175
agg_method="mean_sq",
176+
seed=123,
175177
)
176178

177179

tests/unit/test_clarify.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,21 @@ def test_invalid_model_config():
128128
)
129129

130130

131+
def test_invalid_model_config_with_bad_endpoint_name_prefix():
132+
with pytest.raises(ValueError) as error:
133+
ModelConfig(
134+
model_name="xgboost-model",
135+
instance_type="ml.c5.xlarge",
136+
instance_count=1,
137+
accept_type="invalid_accept_type",
138+
endpoint_name_prefix="~invalid_endpoint_prefix",
139+
)
140+
assert (
141+
"Invalid endpoint_name_prefix. Please follow pattern ^[a-zA-Z0-9](-*[a-zA-Z0-9])."
142+
in str(error.value)
143+
)
144+
145+
131146
def test_model_predicted_label_config():
132147
label = "label"
133148
probability = "pr"
@@ -171,11 +186,13 @@ def test_shap_config():
171186
num_samples = 100
172187
agg_method = "mean_sq"
173188
use_logit = True
189+
seed = 123
174190
shap_config = SHAPConfig(
175191
baseline=baseline,
176192
num_samples=num_samples,
177193
agg_method=agg_method,
178194
use_logit=use_logit,
195+
seed=seed,
179196
)
180197
expected_config = {
181198
"shap": {
@@ -184,6 +201,7 @@ def test_shap_config():
184201
"agg_method": agg_method,
185202
"use_logit": use_logit,
186203
"save_local_shap_values": True,
204+
"seed": seed,
187205
}
188206
}
189207
assert expected_config == shap_config.get_explainability_config()

0 commit comments

Comments
 (0)