Skip to content

Commit 7620749

Browse files
keerthanvasist authored and ahsan-z-khan committed
feature: add NLP support for SageMaker Clarify
1 parent 3b070ac commit 7620749

File tree

2 files changed

+293
-29
lines changed

2 files changed

+293
-29
lines changed

src/sagemaker/clarify.py

Lines changed: 178 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,20 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""This module configures the SageMaker Clarify bias and model explainability processor job."""
14-
from __future__ import print_function, absolute_import
14+
from __future__ import absolute_import, print_function
1515

1616
import copy
17-
18-
from abc import ABC, abstractmethod
1917
import json
18+
import logging
2019
import os
21-
import tempfile
2220
import re
23-
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
21+
import tempfile
22+
from abc import ABC, abstractmethod
23+
2424
from sagemaker import image_uris, s3, utils
25+
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
26+
27+
logger = logging.getLogger(__name__)
2528

2629

2730
class DataConfig:
@@ -338,6 +341,121 @@ def get_explainability_config(self):
338341
return copy.deepcopy({"pdp": self.pdp_config})
339342

340343

344+
class TextConfig:
    """Config object to handle text features.

    The SHAP analysis will break down longer text into chunks (e.g. tokens,
    sentences, or paragraphs) and replace them with the strings specified in
    the baseline for that feature. The shap value of a chunk then captures how
    much replacing it affects the prediction.
    """

    # Kept as lists (not tuples/sets) because they are interpolated verbatim
    # into the ValueError messages below.
    _SUPPORTED_GRANULARITIES = ["token", "sentence", "paragraph"]
    _SUPPORTED_LANGUAGES = [
        "chinese",
        "danish",
        "dutch",
        "english",
        "french",
        "german",
        "greek",
        "italian",
        "japanese",
        "lithuanian",
        "multi-language",
        "norwegian bokmål",
        "polish",
        "portuguese",
        "romanian",
        "russian",
        "spanish",
        "afrikaans",
        "albanian",
        "arabic",
        "armenian",
        "basque",
        "bengali",
        "bulgarian",
        "catalan",
        "croatian",
        "czech",
        "estonian",
        "finnish",
        "gujarati",
        "hebrew",
        "hindi",
        "hungarian",
        "icelandic",
        "indonesian",
        "irish",
        "kannada",
        "kyrgyz",
        "latvian",
        "ligurian",
        "luxembourgish",
        "macedonian",
        "malayalam",
        "marathi",
        "nepali",
        "persian",
        "sanskrit",
        "serbian",
        "setswana",
        "sinhala",
        "slovak",
        "slovenian",
        "swedish",
        "tagalog",
        "tamil",
        "tatar",
        "telugu",
        "thai",
        "turkish",
        "ukrainian",
        "urdu",
        "vietnamese",
        "yoruba",
    ]

    def __init__(
        self,
        granularity,
        language,
    ):
        """Initializes a text configuration.

        Args:
            granularity (str): Determines the granularity in which text
                features are broken down to, can be "token", "sentence", or
                "paragraph". Shap values are computed for these units.
            language (str): Specifies the language of the text features, can
                be "chinese", "danish", "dutch", "english", "french",
                "german", "greek", "italian", "japanese", "lithuanian",
                "multi-language", "norwegian bokmål", "polish", "portuguese",
                "romanian", "russian", "spanish", "afrikaans", "albanian",
                "arabic", "armenian", "basque", "bengali", "bulgarian",
                "catalan", "croatian", "czech", "estonian", "finnish",
                "gujarati", "hebrew", "hindi", "hungarian", "icelandic",
                "indonesian", "irish", "kannada", "kyrgyz", "latvian",
                "ligurian", "luxembourgish", "macedonian", "malayalam",
                "marathi", "nepali", "persian", "sanskrit", "serbian",
                "setswana", "sinhala", "slovak", "slovenian", "swedish",
                "tagalog", "tamil", "tatar", "telugu", "thai", "turkish",
                "ukrainian", "urdu", "vietnamese", "yoruba". Use
                "multi-language" for a mix of multiple languages.

        Raises:
            ValueError: If ``granularity`` or ``language`` is not one of the
                supported values listed above.
        """
        # Validate both arguments up front so a bad value fails here rather
        # than being written into the stored config.
        if granularity not in TextConfig._SUPPORTED_GRANULARITIES:
            raise ValueError(
                f"Invalid granularity {granularity}. Please choose among "
                f"{TextConfig._SUPPORTED_GRANULARITIES}"
            )
        if language not in TextConfig._SUPPORTED_LANGUAGES:
            raise ValueError(
                f"Invalid language {language}. Please choose among "
                f"{TextConfig._SUPPORTED_LANGUAGES}"
            )
        self.text_config = {
            "granularity": granularity,
            "language": language,
        }

    def get_text_config(self):
        """Returns part of an analysis config dictionary.

        Returns:
            dict: A deep copy of the stored text config with the keys
            "granularity" and "language", so callers cannot mutate this
            object's state through the returned value.
        """
        return copy.deepcopy(self.text_config)
457+
458+
341459
class SHAPConfig(ExplainabilityConfig):
342460
"""Config class of SHAP."""
343461

@@ -350,6 +468,7 @@ def __init__(
350468
save_local_shap_values=True,
351469
seed=None,
352470
num_clusters=None,
471+
text_config=None,
353472
):
354473
"""Initializes config for SHAP.
355474
@@ -378,6 +497,7 @@ def __init__(
378497
computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
379498
num_clusters is a parameter for this algorithm. num_clusters will be the resulting
380499
size of the baseline dataset. If not provided, Clarify job will use a default value.
500+
text_config (:class:`~sagemaker.clarify.TextConfig`): Config to handle text features
381501
"""
382502
if agg_method is not None and agg_method not in ["mean_abs", "median", "mean_sq"]:
383503
raise ValueError(
@@ -402,6 +522,15 @@ def __init__(
402522
self.shap_config["seed"] = seed
403523
if num_clusters is not None:
404524
self.shap_config["num_clusters"] = num_clusters
525+
_set(seed, "seed", self.shap_config)
526+
if text_config:
527+
_set(text_config.get_text_config(), "text_config", self.shap_config)
528+
if not save_local_shap_values:
529+
logger.warning(
530+
"Global aggregation is not yet supported for text features. "
531+
"Consider setting save_local_shap_values=True to inspect local text "
532+
"explanations."
533+
)
405534

406535
def get_explainability_config(self):
407536
"""Returns config."""
@@ -525,7 +654,10 @@ def _run(
525654
will be unassociated.
526655
* `TrialComponentDisplayName` is used for display in Studio.
527656
"""
528-
analysis_config["methods"]["report"] = {"name": "report", "title": "Analysis Report"}
657+
analysis_config["methods"]["report"] = {
658+
"name": "report",
659+
"title": "Analysis Report",
660+
}
529661
with tempfile.TemporaryDirectory() as tmpdirname:
530662
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
531663
with open(analysis_config_file, "w") as f:
@@ -627,7 +759,15 @@ def run_pre_training_bias(
627759
job_name = utils.name_from_base(self.job_name_prefix)
628760
else:
629761
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
630-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
762+
self._run(
763+
data_config,
764+
analysis_config,
765+
wait,
766+
logs,
767+
job_name,
768+
kms_key,
769+
experiment_config,
770+
)
631771

632772
def run_post_training_bias(
633773
self,
@@ -705,7 +845,15 @@ def run_post_training_bias(
705845
job_name = utils.name_from_base(self.job_name_prefix)
706846
else:
707847
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
708-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
848+
self._run(
849+
data_config,
850+
analysis_config,
851+
wait,
852+
logs,
853+
job_name,
854+
kms_key,
855+
experiment_config,
856+
)
709857

710858
def run_bias(
711859
self,
@@ -800,7 +948,15 @@ def run_bias(
800948
job_name = utils.name_from_base(self.job_name_prefix)
801949
else:
802950
job_name = utils.name_from_base("Clarify-Bias")
803-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
951+
self._run(
952+
data_config,
953+
analysis_config,
954+
wait,
955+
logs,
956+
job_name,
957+
kms_key,
958+
experiment_config,
959+
)
804960

805961
def run_explainability(
806962
self,
@@ -861,7 +1017,10 @@ def run_explainability(
8611017
analysis_config = data_config.get_config()
8621018
predictor_config = model_config.get_predictor_config()
8631019
if isinstance(model_scores, ModelPredictedLabelConfig):
864-
probability_threshold, predicted_label_config = model_scores.get_predictor_config()
1020+
(
1021+
probability_threshold,
1022+
predicted_label_config,
1023+
) = model_scores.get_predictor_config()
8651024
_set(probability_threshold, "probability_threshold", analysis_config)
8661025
predictor_config.update(predicted_label_config)
8671026
else:
@@ -896,7 +1055,15 @@ def run_explainability(
8961055
job_name = utils.name_from_base(self.job_name_prefix)
8971056
else:
8981057
job_name = utils.name_from_base("Clarify-Explainability")
899-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
1058+
self._run(
1059+
data_config,
1060+
analysis_config,
1061+
wait,
1062+
logs,
1063+
job_name,
1064+
kms_key,
1065+
experiment_config,
1066+
)
9001067

9011068

9021069
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):

0 commit comments

Comments
 (0)