11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
"""This module configures the SageMaker Clarify bias and model explainability processor job."""
14
- from __future__ import print_function , absolute_import
14
+ from __future__ import absolute_import , print_function
15
15
16
16
import copy
17
-
18
- from abc import ABC , abstractmethod
19
17
import json
18
+ import logging
20
19
import os
21
- import tempfile
22
20
import re
23
- from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
21
+ import tempfile
22
+ from abc import ABC , abstractmethod
23
+
24
24
from sagemaker import image_uris , s3 , utils
25
+ from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
26
+
27
+ logger = logging .getLogger (__name__ )
25
28
26
29
27
30
class DataConfig :
@@ -338,6 +341,121 @@ def get_explainability_config(self):
338
341
return copy .deepcopy ({"pdp" : self .pdp_config })
339
342
340
343
344
class TextConfig:
    """Config object to handle text features.

    The SHAP analysis will break down longer text into chunks (e.g. tokens, sentences, or
    paragraphs) and replace them with the strings specified in the baseline for that feature.
    The shap value of a chunk then captures how much replacing it affects the prediction.
    """

    # Kept as lists (not sets/tuples): they are rendered verbatim into the error messages
    # below, and list membership also tolerates unhashable arguments without a TypeError.
    _SUPPORTED_GRANULARITIES = ["token", "sentence", "paragraph"]
    _SUPPORTED_LANGUAGES = [
        "chinese",
        "danish",
        "dutch",
        "english",
        "french",
        "german",
        "greek",
        "italian",
        "japanese",
        "lithuanian",
        "multi-language",
        "norwegian bokmål",
        "polish",
        "portuguese",
        "romanian",
        "russian",
        "spanish",
        "afrikaans",
        "albanian",
        "arabic",
        "armenian",
        "basque",
        "bengali",
        "bulgarian",
        "catalan",
        "croatian",
        "czech",
        "estonian",
        "finnish",
        "gujarati",
        "hebrew",
        "hindi",
        "hungarian",
        "icelandic",
        "indonesian",
        "irish",
        "kannada",
        "kyrgyz",
        "latvian",
        "ligurian",
        "luxembourgish",
        "macedonian",
        "malayalam",
        "marathi",
        "nepali",
        "persian",
        "sanskrit",
        "serbian",
        "setswana",
        "sinhala",
        "slovak",
        "slovenian",
        "swedish",
        "tagalog",
        "tamil",
        "tatar",
        "telugu",
        "thai",
        "turkish",
        "ukrainian",
        "urdu",
        "vietnamese",
        "yoruba",
    ]

    def __init__(
        self,
        granularity,
        language,
    ):
        """Initializes a text configuration.

        Args:
            granularity (str): Determines the granularity in which text features are broken
                down to, can be "token", "sentence", or "paragraph". Shap values are computed
                for these units.
            language (str): Specifies the language of the text features, can be "chinese",
                "danish", "dutch", "english", "french", "german", "greek", "italian",
                "japanese", "lithuanian", "multi-language", "norwegian bokmål", "polish",
                "portuguese", "romanian", "russian", "spanish", "afrikaans", "albanian",
                "arabic", "armenian", "basque", "bengali", "bulgarian", "catalan", "croatian",
                "czech", "estonian", "finnish", "gujarati", "hebrew", "hindi", "hungarian",
                "icelandic", "indonesian", "irish", "kannada", "kyrgyz", "latvian", "ligurian",
                "luxembourgish", "macedonian", "malayalam", "marathi", "nepali", "persian",
                "sanskrit", "serbian", "setswana", "sinhala", "slovak", "slovenian", "swedish",
                "tagalog", "tamil", "tatar", "telugu", "thai", "turkish", "ukrainian", "urdu",
                "vietnamese", "yoruba". Use "multi-language" for a mix of multiple languages.

        Raises:
            ValueError: If ``granularity`` or ``language`` is not one of the supported values.
        """
        # Validate both arguments up front so a bad value fails at construction time.
        if granularity not in TextConfig._SUPPORTED_GRANULARITIES:
            raise ValueError(
                f"Invalid granularity {granularity}. Please choose among "
                f"{TextConfig._SUPPORTED_GRANULARITIES}"
            )
        if language not in TextConfig._SUPPORTED_LANGUAGES:
            raise ValueError(
                f"Invalid language {language}. Please choose among "
                f"{TextConfig._SUPPORTED_LANGUAGES}"
            )
        self.text_config = {
            "granularity": granularity,
            "language": language,
        }

    def get_text_config(self):
        """Returns part of an analysis config dictionary.

        The result is a deep copy, so callers may mutate it without affecting this object.
        """
        return copy.deepcopy(self.text_config)
457
+
458
+
341
459
class SHAPConfig (ExplainabilityConfig ):
342
460
"""Config class of SHAP."""
343
461
@@ -350,6 +468,7 @@ def __init__(
350
468
save_local_shap_values = True ,
351
469
seed = None ,
352
470
num_clusters = None ,
471
+ text_config = None ,
353
472
):
354
473
"""Initializes config for SHAP.
355
474
@@ -378,6 +497,7 @@ def __init__(
378
497
computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
379
498
num_clusters is a parameter for this algorithm. num_clusters will be the resulting
380
499
size of the baseline dataset. If not provided, Clarify job will use a default value.
500
+ text_config (:class:`~sagemaker.clarify.TextConfig`): Config to handle text features
381
501
"""
382
502
if agg_method is not None and agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
383
503
raise ValueError (
@@ -402,6 +522,15 @@ def __init__(
402
522
self .shap_config ["seed" ] = seed
403
523
if num_clusters is not None :
404
524
self .shap_config ["num_clusters" ] = num_clusters
525
+ _set (seed , "seed" , self .shap_config )
526
+ if text_config :
527
+ _set (text_config .get_text_config (), "text_config" , self .shap_config )
528
+ if not save_local_shap_values :
529
+ logger .warning (
530
+ "Global aggregation is not yet supported for text features. "
531
+ "Consider setting save_local_shap_values=True to inspect local text "
532
+ "explanations."
533
+ )
405
534
406
535
def get_explainability_config (self ):
407
536
"""Returns config."""
@@ -525,7 +654,10 @@ def _run(
525
654
will be unassociated.
526
655
* `TrialComponentDisplayName` is used for display in Studio.
527
656
"""
528
- analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
657
+ analysis_config ["methods" ]["report" ] = {
658
+ "name" : "report" ,
659
+ "title" : "Analysis Report" ,
660
+ }
529
661
with tempfile .TemporaryDirectory () as tmpdirname :
530
662
analysis_config_file = os .path .join (tmpdirname , "analysis_config.json" )
531
663
with open (analysis_config_file , "w" ) as f :
@@ -627,7 +759,15 @@ def run_pre_training_bias(
627
759
job_name = utils .name_from_base (self .job_name_prefix )
628
760
else :
629
761
job_name = utils .name_from_base ("Clarify-Pretraining-Bias" )
630
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
762
+ self ._run (
763
+ data_config ,
764
+ analysis_config ,
765
+ wait ,
766
+ logs ,
767
+ job_name ,
768
+ kms_key ,
769
+ experiment_config ,
770
+ )
631
771
632
772
def run_post_training_bias (
633
773
self ,
@@ -705,7 +845,15 @@ def run_post_training_bias(
705
845
job_name = utils .name_from_base (self .job_name_prefix )
706
846
else :
707
847
job_name = utils .name_from_base ("Clarify-Posttraining-Bias" )
708
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
848
+ self ._run (
849
+ data_config ,
850
+ analysis_config ,
851
+ wait ,
852
+ logs ,
853
+ job_name ,
854
+ kms_key ,
855
+ experiment_config ,
856
+ )
709
857
710
858
def run_bias (
711
859
self ,
@@ -800,7 +948,15 @@ def run_bias(
800
948
job_name = utils .name_from_base (self .job_name_prefix )
801
949
else :
802
950
job_name = utils .name_from_base ("Clarify-Bias" )
803
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
951
+ self ._run (
952
+ data_config ,
953
+ analysis_config ,
954
+ wait ,
955
+ logs ,
956
+ job_name ,
957
+ kms_key ,
958
+ experiment_config ,
959
+ )
804
960
805
961
def run_explainability (
806
962
self ,
@@ -861,7 +1017,10 @@ def run_explainability(
861
1017
analysis_config = data_config .get_config ()
862
1018
predictor_config = model_config .get_predictor_config ()
863
1019
if isinstance (model_scores , ModelPredictedLabelConfig ):
864
- probability_threshold , predicted_label_config = model_scores .get_predictor_config ()
1020
+ (
1021
+ probability_threshold ,
1022
+ predicted_label_config ,
1023
+ ) = model_scores .get_predictor_config ()
865
1024
_set (probability_threshold , "probability_threshold" , analysis_config )
866
1025
predictor_config .update (predicted_label_config )
867
1026
else :
@@ -896,7 +1055,15 @@ def run_explainability(
896
1055
job_name = utils .name_from_base (self .job_name_prefix )
897
1056
else :
898
1057
job_name = utils .name_from_base ("Clarify-Explainability" )
899
- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
1058
+ self ._run (
1059
+ data_config ,
1060
+ analysis_config ,
1061
+ wait ,
1062
+ logs ,
1063
+ job_name ,
1064
+ kms_key ,
1065
+ experiment_config ,
1066
+ )
900
1067
901
1068
902
1069
def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
0 commit comments