Skip to content

Commit 64e4690

Browse files
xgchenaBasilBeirouti
authored andcommitted
fix: support specifying a facet by its column index
Currently the Clarify BiasConfig only accepts facet name. Actually Clarify analysis configuration supports both name and index. This commit adds the same support to BiasConfig.
1 parent 9d25fcf commit 64e4690

File tree

2 files changed

+152
-57
lines changed

2 files changed

+152
-57
lines changed

src/sagemaker/clarify.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -111,33 +111,58 @@ def __init__(
111111
"""Initializes a configuration of the sensitive groups in the dataset.
112112
113113
Args:
114-
label_values_or_threshold (Any): List of label values or threshold to indicate positive
115-
outcome used for bias metrics.
116-
facet_name (str or [str]): String or List of strings of sensitive attribute(s) in the
117-
input data for which we like to compare metrics.
118-
facet_values_or_threshold (list): Optional list of values to form a sensitive group or
119-
threshold for a numeric facet column that defines the lower bound of a sensitive
120-
group. Defaults to considering each possible value as sensitive group and
121-
computing metrics vs all the other examples.
122-
If facet_name is a list, this needs to be None or a List consisting of lists or None
123-
with the same length as facet_name list.
114+
label_values_or_threshold ([int or float or str]): List of label value(s) or threshold
115+
to indicate positive outcome used for bias metrics. Dependency on the problem type,
116+
117+
* Binary problem: The list shall include one positive value.
118+
* Categorical problem: The list shall include one or more (but not all) categories
119+
which are the positive values.
120+
* Regression problem: The list shall include one threshold that defines the lower
121+
bound of positive values.
122+
123+
facet_name (str or int or [str] or [int]): Sensitive attribute column name (or index in
124+
the input data) for which you like to compute bias metrics. It can also be a list
125+
of names (or indexes) if you like to compute for multiple sensitive attributes.
126+
facet_values_or_threshold ([int or float or str] or [[int or float or str]]):
127+
The parameter indicates the sensitive group. If facet_name is a scalar, then it can
128+
be None or a list. Depending on the data type of the facet column,
129+
130+
* Binary: None means computing the bias metrics for each binary value. Or add one
131+
binary value to the list, to compute its bias metrics only.
132+
* Categorical: None means computing the bias metrics for each category. Or add one
133+
or more (but not all) categories to the list, to compute their bias metrics v.s.
134+
the other categories.
135+
* Continuous: The list shall include one and only one threshold which defines the
136+
lower bound of a sensitive group.
137+
138+
If facet_name is a list, then it can be None if all facets are of binary type or
139+
categorical type. Otherwise it shall be a list, and each element is the values or
140+
threshold of the corresponding facet.
124141
group_name (str): Optional column name or index to indicate a group column to be used
125142
for the bias metric 'Conditional Demographic Disparity in Labels - CDDL' or
126143
'Conditional Demographic Disparity in Predicted Labels - CDDPL'.
127144
"""
128-
if isinstance(facet_name, str):
145+
if isinstance(facet_name, list):
146+
assert len(facet_name) > 0, "Please provide at least one facet"
147+
if facet_values_or_threshold is None:
148+
facet_list = [
149+
{"name_or_index": single_facet_name} for single_facet_name in facet_name
150+
]
151+
elif len(facet_values_or_threshold) == len(facet_name):
152+
facet_list = []
153+
for i, single_facet_name in enumerate(facet_name):
154+
facet = {"name_or_index": single_facet_name}
155+
if facet_values_or_threshold is not None:
156+
_set(facet_values_or_threshold[i], "value_or_threshold", facet)
157+
facet_list.append(facet)
158+
else:
159+
raise ValueError(
160+
"The number of facet names doesn't match the number of facet values"
161+
)
162+
else:
129163
facet = {"name_or_index": facet_name}
130164
_set(facet_values_or_threshold, "value_or_threshold", facet)
131165
facet_list = [facet]
132-
elif facet_values_or_threshold is None or len(facet_name) == len(facet_values_or_threshold):
133-
facet_list = []
134-
for i, single_facet_name in enumerate(facet_name):
135-
facet = {"name_or_index": single_facet_name}
136-
if facet_values_or_threshold is not None:
137-
_set(facet_values_or_threshold[i], "value_or_threshold", facet)
138-
facet_list.append(facet)
139-
else:
140-
raise ValueError("Wrong combination of argument values passed")
141166
self.analysis_config = {
142167
"label_values_or_threshold": label_values_or_threshold,
143168
"facet": facet_list,

tests/unit/test_clarify.py

Lines changed: 107 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_invalid_data_config():
8282
)
8383

8484

85-
def test_data_bias_config():
85+
def test_bias_config():
8686
label_values = [1]
8787
facet_name = "F1"
8888
facet_threshold = 0.3
@@ -103,52 +103,122 @@ def test_data_bias_config():
103103
assert expected_config == data_bias_config.get_config()
104104

105105

106-
def test_data_bias_config_multi_facet():
107-
label_values = [1]
108-
facet_name = ["Facet1", "Facet2"]
109-
facet_threshold = [[0], [1, 2]]
110-
group_name = "A151"
111-
112-
data_bias_config = BiasConfig(
113-
label_values_or_threshold=label_values,
114-
facet_name=facet_name,
115-
facet_values_or_threshold=facet_threshold,
116-
group_name=group_name,
117-
)
106+
def test_invalid_bias_config():
107+
# Empty facet list,
108+
with pytest.raises(AssertionError, match="Please provide at least one facet"):
109+
BiasConfig(
110+
label_values_or_threshold=[1],
111+
facet_name=[],
112+
)
118113

119-
expected_config = {
120-
"label_values_or_threshold": label_values,
121-
"facet": [
122-
{"name_or_index": facet_name[0], "value_or_threshold": facet_threshold[0]},
123-
{"name_or_index": facet_name[1], "value_or_threshold": facet_threshold[1]},
124-
],
125-
"group_variable": group_name,
126-
}
127-
assert expected_config == data_bias_config.get_config()
114+
# Two facets but only one value
115+
with pytest.raises(
116+
ValueError, match="The number of facet names doesn't match the number of facet values"
117+
):
118+
BiasConfig(
119+
label_values_or_threshold=[1],
120+
facet_name=["Feature1", "Feature2"],
121+
facet_values_or_threshold=[[1]],
122+
)
128123

129124

130-
def test_data_bias_config_multi_facet_not_all_with_value():
125+
@pytest.mark.parametrize(
126+
"facet_name,facet_values_or_threshold,expected_result",
127+
[
128+
# One facet, assume that it is binary and value 1 indicates the sensitive group
129+
[
130+
"Feature1",
131+
[1],
132+
{
133+
"facet": [{"name_or_index": "Feature1", "value_or_threshold": [1]}],
134+
},
135+
],
136+
# The same facet as above, facet value is not specified. (Clarify will compute bias metrics
137+
# for each binary value).
138+
[
139+
"Feature1",
140+
None,
141+
{
142+
"facet": [{"name_or_index": "Feature1"}],
143+
},
144+
],
145+
# Assume that the 2nd column (index 1, zero-based) of the dataset as facet, it has
146+
# four categories and two of them indicate the sensitive group.
147+
[
148+
1,
149+
["category1, category2"],
150+
{
151+
"facet": [{"name_or_index": 1, "value_or_threshold": ["category1, category2"]}],
152+
},
153+
],
154+
# The same facet as above, facet values are not specified. (Clarify will iterate
155+
# the categories and compute bias metrics for each category).
156+
[
157+
1,
158+
None,
159+
{
160+
"facet": [{"name_or_index": 1}],
161+
},
162+
],
163+
# Assume that the facet is numeric value in range [0.0, 1.0]. Given facet threshold 0.5,
164+
# interval (0.5, 1.0] indicates the sensitive group.
165+
[
166+
"Feature3",
167+
[0.5],
168+
{
169+
"facet": [{"name_or_index": "Feature3", "value_or_threshold": [0.5]}],
170+
},
171+
],
172+
# Multiple facets
173+
[
174+
["Feature1", 1, "Feature3"],
175+
[[1], ["category1, category2"], [0.5]],
176+
{
177+
"facet": [
178+
{"name_or_index": "Feature1", "value_or_threshold": [1]},
179+
{"name_or_index": 1, "value_or_threshold": ["category1, category2"]},
180+
{"name_or_index": "Feature3", "value_or_threshold": [0.5]},
181+
],
182+
},
183+
],
184+
# Multiple facets, no value or threshold
185+
[
186+
["Feature1", 1, "Feature3"],
187+
None,
188+
{
189+
"facet": [
190+
{"name_or_index": "Feature1"},
191+
{"name_or_index": 1},
192+
{"name_or_index": "Feature3"},
193+
],
194+
},
195+
],
196+
# Multiple facets, specify values or threshold for some of them
197+
[
198+
["Feature1", 1, "Feature3"],
199+
[[1], None, [0.5]],
200+
{
201+
"facet": [
202+
{"name_or_index": "Feature1", "value_or_threshold": [1]},
203+
{"name_or_index": 1},
204+
{"name_or_index": "Feature3", "value_or_threshold": [0.5]},
205+
],
206+
},
207+
],
208+
],
209+
)
210+
def test_facet_of_bias_config(facet_name, facet_values_or_threshold, expected_result):
131211
label_values = [1]
132-
facet_name = ["Facet1", "Facet2"]
133-
facet_threshold = [[0], None]
134-
group_name = "A151"
135-
136-
data_bias_config = BiasConfig(
212+
bias_config = BiasConfig(
137213
label_values_or_threshold=label_values,
138214
facet_name=facet_name,
139-
facet_values_or_threshold=facet_threshold,
140-
group_name=group_name,
215+
facet_values_or_threshold=facet_values_or_threshold,
141216
)
142-
143217
expected_config = {
144218
"label_values_or_threshold": label_values,
145-
"facet": [
146-
{"name_or_index": facet_name[0], "value_or_threshold": facet_threshold[0]},
147-
{"name_or_index": facet_name[1]},
148-
],
149-
"group_variable": group_name,
219+
**expected_result,
150220
}
151-
assert expected_config == data_bias_config.get_config()
221+
assert bias_config.get_config() == expected_config
152222

153223

154224
def test_model_config():

0 commit comments

Comments
 (0)