Skip to content

Commit 03ec7dd

Browse files
committed
feature: support multiple facets for Clarify
1 parent 125adc3 commit 03ec7dd

File tree

2 files changed

+131
-7
lines changed

2 files changed

+131
-7
lines changed

src/sagemaker/clarify.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def __init__(
4848
headers (list[str]): A list of column names in the input dataset.
4949
features (str): JSONPath for locating the feature columns for bias metrics if the
5050
dataset format is JSONLines.
51-
dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV
52-
and "application/jsonlines" for JSONLines.
51+
dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV,
52+
"application/jsonlines" for JSONLines, and "application/x-parquet" for Parquet.
5353
s3_data_distribution_type (str): Valid options are "FullyReplicated" or
5454
"ShardedByS3Key".
5555
s3_compression_type (str): Valid options are "None" or "Gzip".
@@ -61,6 +61,11 @@ def __init__(
6161
self.label = label
6262
self.headers = headers
6363
self.features = features
64+
if dataset_type not in ["text/csv", "application/jsonlines", "application/x-parquet"]:
65+
raise ValueError(
66+
f"Invalid dataset_type {dataset_type}."
67+
f" Please choose text/csv or application/jsonlines or application/x-parquet."
68+
)
6469
self.analysis_config = {
6570
"dataset_type": dataset_type,
6671
}
@@ -79,8 +84,9 @@ class BiasConfig:
7984
def __init__(
8085
self,
8186
label_values_or_threshold,
82-
facet_name,
87+
facet_name=None,
8388
facet_values_or_threshold=None,
89+
facet_list=None,
8490
group_name=None,
8591
):
8692
"""Initializes a configuration of the sensitive groups in the dataset.
@@ -94,15 +100,55 @@ def __init__(
94100
threshold for a numeric facet column that defines the lower bound of a sensitive
95101
group. Defaults to considering each possible value as sensitive group and
96102
computing metrics vs all the other examples.
103+
facet_list (list[dict]): Optional list of dictionaries that defines the sensitive
104+
attribute(s). Each dictionary may contain the following two keys:
105+
'name_or_index' (int or str) for facet column name or index,
106+
optional 'value_or_threshold' (list[int or float or str]) for list of values or
107+
threshold that the facet column can take which indicates the sensitive group.
108+
This should be defined only if there is more than one sensitive attribute.
97109
group_name (str): Optional column name or index to indicate a group column to be used
98110
for the bias metric 'Conditional Demographic Disparity in Labels - CDDL' or
99111
'Conditional Demographic Disparity in Predicted Labels - CDDPL'.
100112
"""
101-
facet = {"name_or_index": facet_name}
102-
_set(facet_values_or_threshold, "value_or_threshold", facet)
113+
if facet_list:
114+
for facet_object in facet_list:
115+
if not all(
116+
field in ["name_or_index", "value_or_threshold"] for field in facet_object
117+
):
118+
raise ValueError(
119+
f"Invalid facet_list {facet_list}."
120+
f" Please only include 'name_or_index' or 'value_or_threshold'"
121+
f" in dictionary keys."
122+
)
123+
if "name_or_index" not in facet_object or not isinstance(
124+
facet_object["name_or_index"], (str, int)
125+
):
126+
raise ValueError(
127+
f"Invalid facet_list {facet_list}."
128+
f" Please include valid format of 'name_or_index' in dictionary:"
129+
f" str, int."
130+
)
131+
if "value_or_threshold" in facet_object and not (
132+
isinstance(facet_object["value_or_threshold"], list)
133+
and all(
134+
isinstance(v, (str, int, float)) for v in facet_object["value_or_threshold"]
135+
)
136+
):
137+
raise ValueError(
138+
f"Invalid facet_list {facet_list}."
139+
f" Please include valid format of 'value_or_threshold' in dictionary:"
140+
f" list[int or float or str]."
141+
)
142+
elif facet_name is not None:
143+
facet = {"name_or_index": facet_name}
144+
_set(facet_values_or_threshold, "value_or_threshold", facet)
145+
facet_list = [facet]
146+
else:
147+
raise ValueError("Please specify facet_name or facet_list.")
148+
103149
self.analysis_config = {
104150
"label_values_or_threshold": label_values_or_threshold,
105-
"facet": [facet],
151+
"facet": facet_list,
106152
}
107153
_set(group_name, "group_variable", self.analysis_config)
108154

tests/unit/test_clarify.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def test_data_config():
7171
def test_data_bias_config():
7272
label_values = [1]
7373
facet_name = "F1"
74-
facet_threshold = 0.3
74+
facet_name2 = "F2"
75+
facet_threshold = [0.3]
76+
facet_threshold2 = [0.1]
7577
group_name = "A151"
7678

7779
data_bias_config = BiasConfig(
@@ -81,12 +83,88 @@ def test_data_bias_config():
8183
group_name=group_name,
8284
)
8385

86+
data_bias_config_without_value_or_threshold = BiasConfig(
87+
label_values_or_threshold=label_values,
88+
facet_list=[{"name_or_index": facet_name}],
89+
group_name=group_name,
90+
)
91+
92+
data_bias_config_with_facet_list = BiasConfig(
93+
label_values_or_threshold=label_values,
94+
facet_list=[{"name_or_index": facet_name, "value_or_threshold": facet_threshold}],
95+
group_name=group_name,
96+
)
97+
98+
data_bias_config_with_multiple_facets = BiasConfig(
99+
label_values_or_threshold=label_values,
100+
facet_list=[
101+
{"name_or_index": facet_name, "value_or_threshold": facet_threshold},
102+
{"name_or_index": facet_name2, "value_or_threshold": facet_threshold2},
103+
],
104+
group_name=group_name,
105+
)
106+
84107
expected_config = {
85108
"label_values_or_threshold": label_values,
86109
"facet": [{"name_or_index": facet_name, "value_or_threshold": facet_threshold}],
87110
"group_variable": group_name,
88111
}
112+
expected_config_without_value_or_threshold = {
113+
"label_values_or_threshold": label_values,
114+
"facet": [{"name_or_index": facet_name}],
115+
"group_variable": group_name,
116+
}
117+
expected_config_with_multiple_facets = {
118+
"label_values_or_threshold": label_values,
119+
"facet": [
120+
{"name_or_index": facet_name, "value_or_threshold": facet_threshold},
121+
{"name_or_index": facet_name2, "value_or_threshold": facet_threshold2},
122+
],
123+
"group_variable": group_name,
124+
}
89125
assert expected_config == data_bias_config.get_config()
126+
assert (
127+
expected_config_without_value_or_threshold
128+
== data_bias_config_without_value_or_threshold.get_config()
129+
)
130+
assert (
131+
expected_config_with_multiple_facets == data_bias_config_with_multiple_facets.get_config()
132+
)
133+
assert expected_config == data_bias_config_with_facet_list.get_config()
134+
135+
136+
def test_invalid_data_bias_config():
137+
label_values = [1]
138+
facet_name = "F1"
139+
facet_threshold = [0.3]
140+
group_name = "A151"
141+
with pytest.raises(ValueError) as error:
142+
BiasConfig(
143+
label_values_or_threshold=label_values,
144+
facet_list=[
145+
{"random_field": "random_string", "name_or_index": facet_name, "value_or_threshold": facet_threshold}
146+
],
147+
group_name=group_name,
148+
)
149+
assert (
150+
"Please only include 'name_or_index' or 'value_or_threshold' in dictionary keys." in str(error.value)
151+
)
152+
with pytest.raises(ValueError) as error:
153+
BiasConfig(
154+
label_values_or_threshold=label_values,
155+
facet_list=[{"value_or_threshold": facet_threshold}],
156+
group_name=group_name
157+
)
158+
assert (
159+
"Please include valid format of 'name_or_index' in dictionary" in str(error.value)
160+
)
161+
with pytest.raises(ValueError) as error:
162+
BiasConfig(
163+
label_values_or_threshold=label_values,
164+
facet_list=[{"name_or_index": facet_name, "value_or_threshold": True}],
165+
group_name=group_name,
166+
)
167+
assert "Please include valid format of 'value_or_threshold' in dictionary" in str(error.value)
90168

91169

92170
def test_model_config():

0 commit comments

Comments
 (0)