Skip to content

Commit 6c1a3a1

Browse files
authored
feat: Add features_to_explain to shap config (#3951)
1 parent 201f63e commit 6c1a3a1

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

src/sagemaker/clarify.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@
267267
},
268268
},
269269
SchemaOptional("seed"): int,
270+
SchemaOptional("features_to_explain"): [Or(int, str)],
270271
},
271272
SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])},
272273
SchemaOptional("post_training_bias"): {"methods": Or(str, [str])},
@@ -1308,6 +1309,7 @@ def __init__(
13081309
num_clusters: Optional[int] = None,
13091310
text_config: Optional[TextConfig] = None,
13101311
image_config: Optional[ImageConfig] = None,
1312+
features_to_explain: Optional[List[Union[str, int]]] = None,
13111313
):
13121314
"""Initializes config for SHAP analysis.
13131315
@@ -1343,6 +1345,14 @@ def __init__(
13431345
text features. Default is None.
13441346
image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image
13451347
features. Default is None.
1348+
features_to_explain: A list of names or indices of dataset features to compute SHAP
1349+
values for. If not provided, SHAP values are computed for all features by default.
1350+
Currently only supported for tabular datasets.
1351+
1352+
Raises:
1353+
ValueError: when ``agg_method`` is invalid, ``baseline`` and ``num_clusters`` are provided
1354+
together, or ``features_to_explain`` is specified when ``text_config`` or
1355+
``image_config`` is provided
13461356
""" # noqa E501 # pylint: disable=c0301
13471357
if agg_method is not None and agg_method not in [
13481358
"mean_abs",
@@ -1376,6 +1386,13 @@ def __init__(
13761386
)
13771387
if image_config:
13781388
_set(image_config.get_image_config(), "image_config", self.shap_config)
1389+
if features_to_explain is not None and (
1390+
text_config is not None or image_config is not None
1391+
):
1392+
raise ValueError(
1393+
"`features_to_explain` is not supported for datasets containing text features or images."
1394+
)
1395+
_set(features_to_explain, "features_to_explain", self.shap_config)
13791396

13801397
def get_explainability_config(self):
13811398
"""Returns a shap config dictionary."""

tests/unit/test_clarify.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,37 @@ def test_valid_shap_config(baseline):
716716
assert expected_config == shap_config.get_explainability_config()
717717

718718

719+
def test_shap_config_features_to_explain():
720+
baseline = [1, 2, 3]
721+
num_samples = 100
722+
agg_method = "mean_sq"
723+
use_logit = True
724+
save_local_shap_values = True
725+
seed = 123
726+
features_to_explain = [0, 1]
727+
shap_config = SHAPConfig(
728+
baseline=baseline,
729+
num_samples=num_samples,
730+
agg_method=agg_method,
731+
use_logit=use_logit,
732+
save_local_shap_values=save_local_shap_values,
733+
seed=seed,
734+
features_to_explain=features_to_explain,
735+
)
736+
expected_config = {
737+
"shap": {
738+
"baseline": baseline,
739+
"num_samples": num_samples,
740+
"agg_method": agg_method,
741+
"use_logit": use_logit,
742+
"save_local_shap_values": save_local_shap_values,
743+
"seed": seed,
744+
"features_to_explain": features_to_explain,
745+
}
746+
}
747+
assert expected_config == shap_config.get_explainability_config()
748+
749+
719750
def test_shap_config_no_baseline():
720751
num_samples = 100
721752
agg_method = "mean_sq"
@@ -852,6 +883,17 @@ def test_invalid_shap_config():
852883
"Baseline and num_clusters cannot be provided together. Please specify one of the two."
853884
in str(error.value)
854885
)
886+
with pytest.raises(ValueError) as error:
887+
SHAPConfig(
888+
baseline=[[1, 2]],
889+
num_samples=1,
890+
text_config=TextConfig(granularity="token", language="english"),
891+
features_to_explain=[0],
892+
)
893+
assert (
894+
"`features_to_explain` is not supported for datasets containing text features or images."
895+
in str(error.value)
896+
)
855897

856898

857899
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)