Skip to content

Commit b1a25bd

Browse files
committed
Add features_to_explain to shap config
1 parent 6d42cc8 commit b1a25bd

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

src/sagemaker/clarify.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,7 @@ def __init__(
12221222
num_clusters: Optional[int] = None,
12231223
text_config: Optional[TextConfig] = None,
12241224
image_config: Optional[ImageConfig] = None,
1225+
features_to_explain: Optional[List[Union[str, int]]] = None,
12251226
):
12261227
"""Initializes config for SHAP analysis.
12271228
@@ -1257,6 +1258,9 @@ def __init__(
12571258
text features. Default is None.
12581259
image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image
12591260
features. Default is None.
1261+
features_to_explain: A list of names or indices of dataset features to compute SHAP
1262+
values for. If not provided, SHAP values are computed for all features by default.
1263+
Currently only supported for tabular datasets.
12601264
""" # noqa E501 # pylint: disable=c0301
12611265
if agg_method is not None and agg_method not in [
12621266
"mean_abs",
@@ -1290,6 +1294,13 @@ def __init__(
12901294
)
12911295
if image_config:
12921296
_set(image_config.get_image_config(), "image_config", self.shap_config)
1297+
if features_to_explain is not None and (
1298+
text_config is not None or image_config is not None
1299+
):
1300+
raise ValueError(
1301+
"`features_to_explain` is not supported for datasets containing text features or images."
1302+
)
1303+
_set(features_to_explain, "features_to_explain", self.shap_config)
12931304

12941305
def get_explainability_config(self):
12951306
"""Returns a shap config dictionary."""

tests/unit/test_clarify.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,37 @@ def test_valid_shap_config(baseline):
638638
assert expected_config == shap_config.get_explainability_config()
639639

640640

641+
def test_shap_config_features_to_explain():
642+
baseline = [1, 2, 3]
643+
num_samples = 100
644+
agg_method = "mean_sq"
645+
use_logit = True
646+
save_local_shap_values = True
647+
seed = 123
648+
features_to_explain = [0, 1]
649+
shap_config = SHAPConfig(
650+
baseline=baseline,
651+
num_samples=num_samples,
652+
agg_method=agg_method,
653+
use_logit=use_logit,
654+
save_local_shap_values=save_local_shap_values,
655+
seed=seed,
656+
features_to_explain=features_to_explain,
657+
)
658+
expected_config = {
659+
"shap": {
660+
"baseline": baseline,
661+
"num_samples": num_samples,
662+
"agg_method": agg_method,
663+
"use_logit": use_logit,
664+
"save_local_shap_values": save_local_shap_values,
665+
"seed": seed,
666+
"features_to_explain": features_to_explain,
667+
}
668+
}
669+
assert expected_config == shap_config.get_explainability_config()
670+
671+
641672
def test_shap_config_no_baseline():
642673
num_samples = 100
643674
agg_method = "mean_sq"
@@ -774,6 +805,17 @@ def test_invalid_shap_config():
774805
"Baseline and num_clusters cannot be provided together. Please specify one of the two."
775806
in str(error.value)
776807
)
808+
with pytest.raises(ValueError) as error:
809+
SHAPConfig(
810+
baseline=[[1, 2]],
811+
num_samples=1,
812+
text_config=TextConfig(granularity="token", language="english"),
813+
features_to_explain=[0],
814+
)
815+
assert (
816+
"`features_to_explain` is not supported for datasets containing text features or images."
817+
in str(error.value)
818+
)
777819

778820

779821
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)