@@ -300,17 +300,49 @@ def get_explainability_config(self):
300
300
return None
301
301
302
302
303
class PDPConfig(ExplainabilityConfig):
    """Explainability config for Partial Dependence Plots (PDP).

    Requesting PDP adds the Partial Dependence Plots to the report and the
    corresponding values to the analysis output.
    """

    def __init__(self, features=None, grid_resolution=15, top_k_features=10):
        """Builds the PDP section of the analysis configuration.

        Args:
            features (None or list): Names or indices of the features for which partial
                dependence plots must be computed and plotted. Optional when a SHAP config
                is also provided, because Clarify will then try to compute the plots for
                the top features based on SHAP attributions. Must be provided when no
                SHAP config is given.
            grid_resolution (int): For numerical features, the number of buckets that the
                range of values must be divided into. This decides the granularity of the
                grid on which the PDP are plotted.
            top_k_features (int): Number of top SHAP attributes to be selected to compute
                partial dependence plots.
        """
        settings = dict(grid_resolution=grid_resolution, top_k_features=top_k_features)
        # "features" is only emitted when the caller supplied it, so that an absent key
        # (rather than a null value) signals "not specified".
        if features is not None:
            settings["features"] = features
        self.pdp_config = settings

    def get_explainability_config(self):
        """Returns a deep copy of the config, nested under the "pdp" key."""
        wrapped = {"pdp": self.pdp_config}
        # Deep copy so callers cannot mutate this object's stored configuration.
        return copy.deepcopy(wrapped)
332
+
333
+
303
334
class SHAPConfig (ExplainabilityConfig ):
304
335
"""Config class of SHAP."""
305
336
306
337
def __init__ (
307
338
self ,
308
- baseline ,
309
- num_samples ,
310
- agg_method ,
339
+ baseline = None ,
340
+ num_samples = None ,
341
+ agg_method = None ,
311
342
use_logit = False ,
312
343
save_local_shap_values = True ,
313
344
seed = None ,
345
+ num_clusters = None ,
314
346
):
315
347
"""Initializes config for SHAP.
316
348
@@ -320,34 +352,49 @@ def __init__(
320
352
be the same as the dataset format. Each row should contain only the feature
321
353
columns/values and omit the label column/values. If None a baseline will be
322
354
calculated automatically by using K-means or K-prototypes in the input dataset.
323
- num_samples (int): Number of samples to be used in the Kernel SHAP algorithm.
355
+ num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
324
356
This number determines the size of the generated synthetic dataset to compute the
325
- SHAP values.
326
- agg_method (str): Aggregation method for global SHAP values. Valid values are
357
+ SHAP values. If not provided then Clarify job will choose a proper value according
358
+ to the count of features.
359
+ agg_method (None or str): Aggregation method for global SHAP values. Valid values are
327
360
"mean_abs" (mean of absolute SHAP values for all instances),
328
361
"median" (median of SHAP values for all instances) and
329
362
"mean_sq" (mean of squared SHAP values for all instances).
363
+ If not provided then Clarify job uses method "mean_abs"
330
364
use_logit (bool): Indicator of whether the logit function is to be applied to the model
331
365
predictions. Default is False. If "use_logit" is true then the SHAP values will
332
366
have log-odds units.
333
367
save_local_shap_values (bool): Indicator of whether to save the local SHAP values
334
368
in the output location. Default is True.
335
369
seed (int): seed value to get deterministic SHAP values. Default is None.
370
+ num_clusters (None or int): If a baseline is not provided, Clarify automatically
371
+ computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
372
+ num_clusters is a parameter for this algorithm. num_clusters will be the resulting
373
+ size of the baseline dataset. If not provided, Clarify job will use a default value.
336
374
"""
337
- if agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
375
+ if agg_method is not None and agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
338
376
raise ValueError (
339
377
f"Invalid agg_method { agg_method } ." f" Please choose mean_abs, median, or mean_sq."
340
378
)
341
-
379
+ if num_clusters is not None and baseline is not None :
380
+ raise ValueError (
381
+ "Baseline and num_clusters cannot be provided together. "
382
+ "Please specify one of the two."
383
+ )
342
384
self .shap_config = {
343
- "baseline" : baseline ,
344
- "num_samples" : num_samples ,
345
- "agg_method" : agg_method ,
346
385
"use_logit" : use_logit ,
347
386
"save_local_shap_values" : save_local_shap_values ,
348
387
}
388
+ if baseline is not None :
389
+ self .shap_config ["baseline" ] = baseline
390
+ if num_samples is not None :
391
+ self .shap_config ["num_samples" ] = num_samples
392
+ if agg_method is not None :
393
+ self .shap_config ["agg_method" ] = agg_method
349
394
if seed is not None :
350
395
self .shap_config ["seed" ] = seed
396
+ if num_clusters is not None :
397
+ self .shap_config ["num_clusters" ] = num_clusters
351
398
352
399
def get_explainability_config (self ):
353
400
"""Returns config."""
@@ -776,8 +823,9 @@ def run_explainability(
776
823
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
777
824
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
778
825
endpoint to be created.
779
- explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
780
- specific explainability method. Currently, only SHAP is supported.
826
+ explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
827
+ Config of the specific explainability method or a list of ExplainabilityConfig
828
+ objects. Currently, SHAP and PDP are the two methods supported.
781
829
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
782
830
model output for the predicted scores to be explained. This is not required if the
783
831
model output is a single score. Alternatively, an instance of
@@ -811,7 +859,30 @@ def run_explainability(
811
859
predictor_config .update (predicted_label_config )
812
860
else :
813
861
_set (model_scores , "label" , predictor_config )
814
- analysis_config ["methods" ] = explainability_config .get_explainability_config ()
862
+
863
+ explainability_methods = {}
864
+ if isinstance (explainability_config , list ):
865
+ if len (explainability_config ) == 0 :
866
+ raise ValueError ("Please provide at least one explainability config." )
867
+ for config in explainability_config :
868
+ explain_config = config .get_explainability_config ()
869
+ explainability_methods .update (explain_config )
870
+ if not len (explainability_methods .keys ()) == len (explainability_config ):
871
+ raise ValueError ("Duplicate explainability configs are provided" )
872
+ if (
873
+ "shap" not in explainability_methods
874
+ and explainability_methods ["pdp" ].get ("features" , None ) is None
875
+ ):
876
+ raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
877
+ else :
878
+ if (
879
+ isinstance (explainability_config , PDPConfig )
880
+ and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
881
+ is None
882
+ ):
883
+ raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
884
+ explainability_methods = explainability_config .get_explainability_config ()
885
+ analysis_config ["methods" ] = explainability_methods
815
886
analysis_config ["predictor" ] = predictor_config
816
887
if job_name is None :
817
888
if self .job_name_prefix :
0 commit comments