Skip to content

Add features_to_explain to shap config #3951

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/sagemaker/clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@
},
},
SchemaOptional("seed"): int,
SchemaOptional("features_to_explain"): [Or(int, str)],
},
SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])},
SchemaOptional("post_training_bias"): {"methods": Or(str, [str])},
Expand Down Expand Up @@ -1308,6 +1309,7 @@ def __init__(
num_clusters: Optional[int] = None,
text_config: Optional[TextConfig] = None,
image_config: Optional[ImageConfig] = None,
features_to_explain: Optional[List[Union[str, int]]] = None,
):
"""Initializes config for SHAP analysis.

Expand Down Expand Up @@ -1343,6 +1345,14 @@ def __init__(
text features. Default is None.
image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image
features. Default is None.
features_to_explain: A list of names or indices of dataset features to compute SHAP
values for. If not provided, SHAP values are computed for all features by default.
Currently only supported for tabular datasets.

Raises:
ValueError: when ``agg_method`` is invalid, ``baseline`` and ``num_clusters`` are provided
together, or ``features_to_explain`` is specified when ``text_config`` or
``image_config`` is provided
""" # noqa E501 # pylint: disable=c0301
if agg_method is not None and agg_method not in [
"mean_abs",
Expand Down Expand Up @@ -1376,6 +1386,13 @@ def __init__(
)
if image_config:
_set(image_config.get_image_config(), "image_config", self.shap_config)
if features_to_explain is not None and (
text_config is not None or image_config is not None
):
raise ValueError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add this in method docstring

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in latest commit.

"`features_to_explain` is not supported for datasets containing text features or images."
)
_set(features_to_explain, "features_to_explain", self.shap_config)

def get_explainability_config(self):
"""Returns a shap config dictionary."""
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,37 @@ def test_valid_shap_config(baseline):
assert expected_config == shap_config.get_explainability_config()


def test_shap_config_features_to_explain():
baseline = [1, 2, 3]
num_samples = 100
agg_method = "mean_sq"
use_logit = True
save_local_shap_values = True
seed = 123
features_to_explain = [0, 1]
shap_config = SHAPConfig(
baseline=baseline,
num_samples=num_samples,
agg_method=agg_method,
use_logit=use_logit,
save_local_shap_values=save_local_shap_values,
seed=seed,
features_to_explain=features_to_explain,
)
expected_config = {
"shap": {
"baseline": baseline,
"num_samples": num_samples,
"agg_method": agg_method,
"use_logit": use_logit,
"save_local_shap_values": save_local_shap_values,
"seed": seed,
"features_to_explain": features_to_explain,
}
}
assert expected_config == shap_config.get_explainability_config()


def test_shap_config_no_baseline():
num_samples = 100
agg_method = "mean_sq"
Expand Down Expand Up @@ -852,6 +883,17 @@ def test_invalid_shap_config():
"Baseline and num_clusters cannot be provided together. Please specify one of the two."
in str(error.value)
)
with pytest.raises(ValueError) as error:
SHAPConfig(
baseline=[[1, 2]],
num_samples=1,
text_config=TextConfig(granularity="token", language="english"),
features_to_explain=[0],
)
assert (
"`features_to_explain` is not supported for datasets containing text features or images."
in str(error.value)
)


@pytest.fixture(scope="module")
Expand Down