Skip to content

Commit 7c8beec

Browse files
authored
Update BiasConfig to accept multiple facet params
BiasConfig now accepts a list of feature/attribute names on which to perform the bias analysis. The service already supports this; with this update, the SDK can make use of it as well.
1 parent d666bbd commit 7c8beec

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

src/sagemaker/clarify.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,28 +88,43 @@ def __init__(
8888
Args:
8989
label_values_or_threshold (Any): List of label values or threshold to indicate positive
9090
outcome used for bias metrics.
91-
facet_name (str): Sensitive attribute in the input data for which we like to compare
92-
metrics.
91+
facet_name (Any): String or List of strings of sensitive attribute(s) in the input data
92+
for which we would like to compare metrics.
9393
facet_values_or_threshold (list): Optional list of values to form a sensitive group or
9494
threshold for a numeric facet column that defines the lower bound of a sensitive
9595
group. Defaults to considering each possible value as sensitive group and
96-
computing metrics vs all the other examples.
96+
computing metrics vs all the other examples.
97+
If facet_name is a list, this must be either None or a list of the same length as facet_name, in which each element is a list or None.
9798
group_name (str): Optional column name or index to indicate a group column to be used
9899
for the bias metric 'Conditional Demographic Disparity in Labels - CDDL' or
99100
'Conditional Demographic Disparity in Predicted Labels - CDDPL'.
100101
"""
101-
facet = {"name_or_index": facet_name}
102-
_set(facet_values_or_threshold, "value_or_threshold", facet)
102+
if isinstance(facet_name, str):
103+
facet = {"name_or_index": facet_name}
104+
_set(facet_values_or_threshold, "value_or_threshold", facet)
105+
facet_list = [facet]
106+
elif facet_values_or_threshold is None or len(facet_name) == len(facet_values_or_threshold):
107+
facet_list = []
108+
for i, single_facet_name in enumerate(facet_name):
109+
facet = {"name_or_index": single_facet_name}
110+
if facet_values_or_threshold is not None: _set(facet_values_or_threshold[i], "value_or_threshold", facet)
111+
facet_list.append(facet)
112+
else:
113+
raise ValueError("Wrong combination of argument values passed")
103114
self.analysis_config = {
104115
"label_values_or_threshold": label_values_or_threshold,
105-
"facet": [facet],
116+
"facet": facet_list,
106117
}
107118
_set(group_name, "group_variable", self.analysis_config)
108119

109120
def get_config(self):
110121
"""Returns part of an analysis config dictionary."""
111122
return copy.deepcopy(self.analysis_config)
112123

124+
def get_config(self):
125+
"""Returns part of an analysis config dictionary."""
126+
return copy.deepcopy(self.analysis_config)
127+
113128

114129
class ModelConfig:
115130
"""Config object related to a model and its endpoint to be created."""

0 commit comments

Comments
 (0)