Commit 424d90e

Samsara Counts committed
feature: clarify bias detection when facets not included
* feature: clarify bias detection when facets are not included in the input dataset (and passed through as other parameters).
* feature: clarify model inference with excluded columns in the input dataset
* feature: clarify bias analysis with input predicted label dataset (supports post-training bias analysis without model inference API calls)
* documentation: add clarify DataConfig params for facet not included in input dataset
* documentation: add clarify DataConfig param for analysis with predicted labels
* documentation: add clarify DataConfig param for analysis with excluded columns
* documentation: correct DataConfig label parameter description
* documentation: add details about running SHAP and PDP to run_explainability
1 parent 4879ec3 commit 424d90e
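
The commit message above describes bias analysis when the facet (sensitive attribute) columns live in a separate file from the main dataset. A minimal sketch of that usage follows; the S3 URIs, column names, and the "gender" facet are illustrative placeholders, not values taken from this commit:

    from sagemaker.clarify import DataConfig

    # The facet attribute sits in its own CSV file; `joinsource` names the
    # identifier column used to join it to the main dataset. Clarify keeps the
    # joinsource and facet columns out of model inference requests.
    data_config = DataConfig(
        s3_data_input_path="s3://my-bucket/input/dataset.csv",   # placeholder URI
        s3_output_path="s3://my-bucket/clarify-output/",
        label="target",
        headers=["target", "feature_0", "feature_1", "customer_id"],
        dataset_type="text/csv",
        joinsource="customer_id",
        facet_dataset_uri="s3://my-bucket/facets/facets.csv",    # placeholder URI
        facet_headers=["gender", "customer_id"],
    )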

File tree: 3 files changed (+672, −36 lines)

src/sagemaker/clarify.py

Lines changed: 127 additions & 36 deletions
@@ -45,6 +45,12 @@ def __init__(
         dataset_type="text/csv",
         s3_compression_type="None",
         joinsource=None,
+        facet_dataset_uri=None,
+        facet_headers=None,
+        predicted_label_dataset_uri=None,
+        predicted_label_headers=None,
+        predicted_label=None,
+        excluded_columns=None,
     ):
         """Initializes a configuration of both input and output datasets.
 
@@ -54,22 +60,57 @@ def __init__(
             s3_analysis_config_output_path (str): S3 prefix to store the analysis config output.
                 If this field is None, then the ``s3_output_path`` will be used
                 to store the ``analysis_config`` output.
-            label (str): Target attribute of the model **required** for bias metrics (both pre-
-                and post-training). Optional when running SHAP explainability.
-                Specified as column name or index for CSV dataset, or as JSONPath for JSONLines.
-            headers (list[str]): A list of column names in the input dataset.
+            label (str): Target attribute of the model required by bias metrics.
+                Specified as column name or index for CSV dataset, or as JSONPath for JSONLines.
+                *Required parameter* except when the input dataset does not contain the label.
+                Cannot be used at the same time as ``predicted_label``.
             features (str): JSONPath for locating the feature columns for bias metrics if the
                 dataset format is JSONLines.
             dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV,
                 ``"application/jsonlines"`` for JSONLines, and
                 ``"application/x-parquet"`` for Parquet.
             s3_compression_type (str): Valid options are "None" or ``"Gzip"``.
-            joinsource (str): The name or index of the column in the dataset that acts as an
-                identifier column (for instance, while performing a join). This column is only
-                used as an identifier, and not used for any other computations. This is an
-                optional field in all cases except when the dataset contains more than one file,
-                and ``save_local_shap_values`` is set to True
-                in :class:`~sagemaker.clarify.SHAPConfig`.
+            joinsource (str or int): The name or index of the column in the dataset that
+                acts as an identifier column (for instance, while performing a join).
+                This column is only used as an identifier, and not used for any other computations.
+                This is an optional field in all cases except:
+
+                * The dataset contains more than one file and ``save_local_shap_values``
+                  is set to ``True`` in :class:`~sagemaker.clarify.SHAPConfig`, and/or
+                * The dataset and/or facet dataset and/or predicted label dataset
+                  are in separate files.
+
+            facet_dataset_uri (str): Dataset S3 prefix/object URI that contains facet attribute(s),
+                used for bias analysis on datasets without facets.
+
+                * If the dataset and the facet dataset are one single file each, then
+                  the original dataset and facet dataset must have the same number of rows.
+                * If the dataset and facet dataset are in multiple files (either one), then
+                  an index column, ``joinsource``, is required to join the two datasets.
+
+                Clarify will not use the ``joinsource`` column and columns present in the facet
+                dataset when calling model inference APIs.
+            facet_headers (list[str]): List of column names in the facet dataset.
+            predicted_label_dataset_uri (str): Dataset S3 prefix/object URI with predicted labels,
+                which are used directly for analysis instead of making model inference API calls.
+
+                * If the dataset and the predicted label dataset are one single file each, then the
+                  original dataset and predicted label dataset must have the same number of rows.
+                * If the dataset and predicted label dataset are in multiple files (either one),
+                  then an index column, ``joinsource``, is required to join the two datasets.
+
+            predicted_label_headers (list[str]): List of column names in the predicted label dataset.
+            predicted_label (str or int): Predicted label of the target attribute of the model
+                required for running bias analysis. Specified as column name or index for CSV data.
+                Clarify uses the predicted labels directly instead of making model inference API
+                calls. Cannot be used at the same time as ``label``.
+            excluded_columns (list[int] or list[str]): A list of names or indices of the columns
+                which are to be excluded from making model inference API calls.
+
+        Raises:
+            ValueError: when the ``dataset_type`` is invalid, predicted label dataset parameters
+                are used with an unsupported ``dataset_type``, or facet dataset parameters
+                are used with an unsupported ``dataset_type``.
         """
         if dataset_type not in [
             "text/csv",
@@ -81,6 +122,32 @@ def __init__(
                 f"Invalid dataset_type '{dataset_type}'."
                 f" Please check the API documentation for the supported dataset types."
             )
+        # parameters for analysis on datasets without facets are only supported for CSV datasets
+        if dataset_type != "text/csv":
+            if predicted_label:
+                raise ValueError(
+                    f"The parameter 'predicted_label' is not supported"
+                    f" for dataset_type '{dataset_type}'."
+                    f" Please check the API documentation for the supported dataset types."
+                )
+            if excluded_columns:
+                raise ValueError(
+                    f"The parameter 'excluded_columns' is not supported"
+                    f" for dataset_type '{dataset_type}'."
+                    f" Please check the API documentation for the supported dataset types."
+                )
+            if facet_dataset_uri or facet_headers:
+                raise ValueError(
+                    f"The parameters 'facet_dataset_uri' and 'facet_headers'"
+                    f" are not supported for dataset_type '{dataset_type}'."
+                    f" Please check the API documentation for the supported dataset types."
+                )
+            if predicted_label_dataset_uri or predicted_label_headers:
+                raise ValueError(
+                    f"The parameters 'predicted_label_dataset_uri' and 'predicted_label_headers'"
+                    f" are not supported for dataset_type '{dataset_type}'."
+                    f" Please check the API documentation for the supported dataset types."
+                )
         self.s3_data_input_path = s3_data_input_path
         self.s3_output_path = s3_output_path
         self.s3_analysis_config_output_path = s3_analysis_config_output_path
@@ -89,13 +156,25 @@ def __init__(
         self.label = label
         self.headers = headers
         self.features = features
+        self.facet_dataset_uri = facet_dataset_uri
+        self.facet_headers = facet_headers
+        self.predicted_label_dataset_uri = predicted_label_dataset_uri
+        self.predicted_label_headers = predicted_label_headers
+        self.predicted_label = predicted_label
+        self.excluded_columns = excluded_columns
         self.analysis_config = {
             "dataset_type": dataset_type,
         }
         _set(features, "features", self.analysis_config)
         _set(headers, "headers", self.analysis_config)
         _set(label, "label", self.analysis_config)
         _set(joinsource, "joinsource_name_or_index", self.analysis_config)
+        _set(facet_dataset_uri, "facet_dataset_uri", self.analysis_config)
+        _set(facet_headers, "facet_headers", self.analysis_config)
+        _set(predicted_label_dataset_uri, "predicted_label_dataset_uri", self.analysis_config)
+        _set(predicted_label_headers, "predicted_label_headers", self.analysis_config)
+        _set(predicted_label, "predicted_label", self.analysis_config)
+        _set(excluded_columns, "excluded_columns", self.analysis_config)
 
     def get_config(self):
         """Returns part of an analysis config dictionary."""
@@ -204,21 +283,23 @@ def __init__(
         r"""Initializes a configuration of a model and the endpoint to be created for it.
 
         Args:
-            model_name (str): Model name (as created by 'CreateModel').
+            model_name (str): Model name (as created by
+                `CreateModel <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_).
             instance_count (int): The number of instances of a new endpoint for model inference.
-            instance_type (str): The type of EC2 instance to use for model inference,
-                for example, ``"ml.c5.xlarge"``.
+            instance_type (str): The type of
+                `EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
+                to use for model inference; for example, ``"ml.c5.xlarge"``.
             accept_type (str): The model output format to be used for getting inferences with the
-                shadow endpoint. Valid values are "text/csv" for CSV and "application/jsonlines".
-                Default is the same as content_type.
+                shadow endpoint. Valid values are ``"text/csv"`` for CSV and
+                ``"application/jsonlines"``. Default is the same as ``content_type``.
             content_type (str): The model input format to be used for getting inferences with the
-                shadow endpoint. Valid values are "text/csv" for CSV and "application/jsonlines".
-                Default is the same as dataset format.
+                shadow endpoint. Valid values are ``"text/csv"`` for CSV and
+                ``"application/jsonlines"``. Default is the same as the dataset format.
             content_template (str): A template string to be used to construct the model input from
                 dataset instances. It is only used when ``model_content_type`` is
                 ``"application/jsonlines"``. The template should have one and only one placeholder,
-                "features", which will be replaced by a features list to form the model inference
-                input.
+                ``"features"``, which will be replaced by a features list to form the model
+                inference input.
             custom_attributes (str): Provides additional information about a request for an
                 inference submitted to a model hosted at an Amazon SageMaker endpoint. The
                 information is an opaque value that is forwarded verbatim. You could use this
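
For reference, a minimal construction sketch using the parameters documented above; the model name and instance settings are placeholders:

    from sagemaker.clarify import ModelConfig

    model_config = ModelConfig(
        model_name="my-clarify-model",   # an existing SageMaker model (placeholder name)
        instance_count=1,
        instance_type="ml.c5.xlarge",
        content_type="text/csv",         # model input format for the shadow endpoint
        accept_type="text/csv",          # model output format
    )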
@@ -504,16 +585,20 @@ def __init__(
                 for these units.
             language (str): Specifies the language of the text features. Accepted values are
                 one of the following:
-                "chinese", "danish", "dutch", "english", "french", "german", "greek", "italian",
-                "japanese", "lithuanian", "multi-language", "norwegian bokmål", "polish",
-                "portuguese", "romanian", "russian", "spanish", "afrikaans", "albanian", "arabic",
-                "armenian", "basque", "bengali", "bulgarian", "catalan", "croatian", "czech",
-                "estonian", "finnish", "gujarati", "hebrew", "hindi", "hungarian", "icelandic",
-                "indonesian", "irish", "kannada", "kyrgyz", "latvian", "ligurian", "luxembourgish",
-                "macedonian", "malayalam", "marathi", "nepali", "persian", "sanskrit", "serbian",
-                "setswana", "sinhala", "slovak", "slovenian", "swedish", "tagalog", "tamil",
-                "tatar", "telugu", "thai", "turkish", "ukrainian", "urdu", "vietnamese", "yoruba".
-                Use "multi-language" for a mix of multiple languages.
+                ``"chinese"``, ``"danish"``, ``"dutch"``, ``"english"``, ``"french"``, ``"german"``,
+                ``"greek"``, ``"italian"``, ``"japanese"``, ``"lithuanian"``, ``"multi-language"``,
+                ``"norwegian bokmål"``, ``"polish"``, ``"portuguese"``, ``"romanian"``,
+                ``"russian"``, ``"spanish"``, ``"afrikaans"``, ``"albanian"``, ``"arabic"``,
+                ``"armenian"``, ``"basque"``, ``"bengali"``, ``"bulgarian"``, ``"catalan"``,
+                ``"croatian"``, ``"czech"``, ``"estonian"``, ``"finnish"``, ``"gujarati"``,
+                ``"hebrew"``, ``"hindi"``, ``"hungarian"``, ``"icelandic"``, ``"indonesian"``,
+                ``"irish"``, ``"kannada"``, ``"kyrgyz"``, ``"latvian"``, ``"ligurian"``,
+                ``"luxembourgish"``, ``"macedonian"``, ``"malayalam"``, ``"marathi"``, ``"nepali"``,
+                ``"persian"``, ``"sanskrit"``, ``"serbian"``, ``"setswana"``, ``"sinhala"``,
+                ``"slovak"``, ``"slovenian"``, ``"swedish"``, ``"tagalog"``, ``"tamil"``,
+                ``"tatar"``, ``"telugu"``, ``"thai"``, ``"turkish"``, ``"ukrainian"``, ``"urdu"``,
+                ``"vietnamese"``, ``"yoruba"``.
+                Use ``"multi-language"`` for a mix of multiple languages.
 
         Raises:
             ValueError: when ``granularity`` is not in list of supported values
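
A small usage sketch for the ``language`` values listed above; the ``granularity`` value ``"sentence"`` is an assumption here, since the full list of supported units is documented outside this excerpt:

    from sagemaker.clarify import TextConfig

    # Explain text features sentence by sentence, in English.
    text_config = TextConfig(granularity="sentence", language="english")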
@@ -737,12 +822,15 @@ def __init__(
                 data stored in Amazon S3.
             instance_count (int): The number of instances to run
                 a processing job with.
-            instance_type (str): The type of EC2 instance to use for
-                processing, for example, ``'ml.c4.xlarge'``.
-            volume_size_in_gb (int): Size in GB of the EBS volume
-                to use for storing data during processing (default: 30).
-            volume_kms_key (str): A KMS key for the processing
-                volume (default: None).
+            instance_type (str): The type of
+                `EC2 instance <https://aws.amazon.com/ec2/instance-types/>`_
+                to use for processing; for example, ``"ml.c5.xlarge"``.
+            volume_size_in_gb (int): Size in GB of the
+                `EBS volume <https://docs.aws.amazon.com/sagemaker/latest/dg/host-instance-storage.html>`_
+                to use for storing data during processing (default: 30 GB).
+            volume_kms_key (str): A
+                `KMS key <https://docs.aws.amazon.com/sagemaker/latest/dg/key-management.html>`_
+                for the processing volume (default: None).
             output_kms_key (str): The KMS key ID for processing job outputs (default: None).
             max_runtime_in_seconds (int): Timeout in seconds (default: None).
                 After this amount of time, Amazon SageMaker terminates the job,
@@ -764,7 +852,7 @@ def __init__(
                 inter-container traffic, security group IDs, and subnets.
             job_name_prefix (str): Processing job name prefix.
             version (str): Clarify version to use.
-        """
+        """  # noqa E501  # pylint: disable=c0301
         container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
         self.job_name_prefix = job_name_prefix
         super(SageMakerClarifyProcessor, self).__init__(
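
A construction sketch for the processor documented above; the execution role ARN is assumed (it does not appear in this excerpt) and the other values are placeholders:

    from sagemaker import Session
    from sagemaker.clarify import SageMakerClarifyProcessor

    clarify_processor = SageMakerClarifyProcessor(
        role="arn:aws:iam::111122223333:role/ClarifyExecutionRole",  # assumed/placeholder role
        instance_count=1,
        instance_type="ml.c5.xlarge",
        sagemaker_session=Session(),
        job_name_prefix="clarify-bias",
    )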
@@ -1158,6 +1246,7 @@ def run_explainability(
 
         Currently, only SHAP and Partial Dependence Plots (PDP) are supported
         as explainability methods.
+        You can request both methods or one at a time with the ``explainability_config`` parameter.
 
         When SHAP is requested in the ``explainability_config``,
         the SHAP algorithm calculates the feature importance for each input example
@@ -1183,6 +1272,8 @@ def run_explainability(
                 Config of the specific explainability method or a list of
                 :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
                 Currently, SHAP and PDP are the two methods supported.
+                You can request multiple methods at once by passing in a list of
+                :class:`~sagemaker.clarify.ExplainabilityConfig`.
             model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                 Index or JSONPath to locate the predicted scores in the model output. This is not
                 required if the model output is a single score. Alternatively, it can be an instance
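
Putting the two supported methods together, a hedged sketch of requesting SHAP and PDP in one call, reusing the ``clarify_processor``, ``data_config``, and ``model_config`` objects sketched earlier; the SHAP baseline, sample count, and PDP settings are illustrative values, not defaults from this commit:

    from sagemaker.clarify import PDPConfig, SHAPConfig

    shap_config = SHAPConfig(
        baseline=[[0.5, 0.5, 0.5]],   # one baseline row matching the feature layout (placeholder)
        num_samples=100,
        agg_method="mean_abs",
    )
    pdp_config = PDPConfig(features=["feature_0", "feature_1"], grid_resolution=15)

    # Passing a list of ExplainabilityConfig objects requests both methods at once.
    clarify_processor.run_explainability(
        data_config=data_config,
        model_config=model_config,
        explainability_config=[shap_config, pdp_config],
    )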
