|
267 | 267 | },
|
268 | 268 | },
|
269 | 269 | SchemaOptional("seed"): int,
|
| 270 | + SchemaOptional("features_to_explain"): [Or(int, str)], |
270 | 271 | },
|
271 | 272 | SchemaOptional("pre_training_bias"): {"methods": Or(str, [str])},
|
272 | 273 | SchemaOptional("post_training_bias"): {"methods": Or(str, [str])},
|
@@ -1308,6 +1309,7 @@ def __init__(
|
1308 | 1309 | num_clusters: Optional[int] = None,
|
1309 | 1310 | text_config: Optional[TextConfig] = None,
|
1310 | 1311 | image_config: Optional[ImageConfig] = None,
|
| 1312 | + features_to_explain: Optional[List[Union[str, int]]] = None, |
1311 | 1313 | ):
|
1312 | 1314 | """Initializes config for SHAP analysis.
|
1313 | 1315 |
|
@@ -1343,6 +1345,14 @@ def __init__(
|
1343 | 1345 | text features. Default is None.
|
1344 | 1346 | image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image
|
1345 | 1347 | features. Default is None.
|
| 1348 | + features_to_explain: A list of names or indices of dataset features to compute SHAP |
| 1349 | + values for. If not provided, SHAP values are computed for all features by default. |
| 1350 | + Currently only supported for tabular datasets. |
| 1351 | +
|
| 1352 | + Raises: |
| 1353 | + ValueError: when ``agg_method`` is invalid, ``baseline`` and ``num_clusters`` are provided |
| 1354 | + together, or ``features_to_explain`` is specified when ``text_config`` or |
| 1355 | + ``image_config`` is provided |
1346 | 1356 | """ # noqa E501 # pylint: disable=c0301
|
1347 | 1357 | if agg_method is not None and agg_method not in [
|
1348 | 1358 | "mean_abs",
|
@@ -1376,6 +1386,13 @@ def __init__(
|
1376 | 1386 | )
|
1377 | 1387 | if image_config:
|
1378 | 1388 | _set(image_config.get_image_config(), "image_config", self.shap_config)
|
| 1389 | + if features_to_explain is not None and ( |
| 1390 | + text_config is not None or image_config is not None |
| 1391 | + ): |
| 1392 | + raise ValueError( |
| 1393 | + "`features_to_explain` is not supported for datasets containing text features or images." |
| 1394 | + ) |
| 1395 | + _set(features_to_explain, "features_to_explain", self.shap_config) |
1379 | 1396 |
|
1380 | 1397 | def get_explainability_config(self):
|
1381 | 1398 | """Returns a shap config dictionary."""
|
|
0 commit comments