@@ -88,28 +88,43 @@ def __init__(
88
88
Args:
89
89
label_values_or_threshold (Any): List of label values or threshold to indicate positive
90
90
outcome used for bias metrics.
91
- facet_name (str ): Sensitive attribute in the input data for which we like to compare
92
- metrics.
91
+ facet_name (Any ): String or List of strings of sensitive attribute(s) in the input data
92
+ for which we like to compare metrics.
93
93
facet_values_or_threshold (list): Optional list of values to form a sensitive group or
94
94
threshold for a numeric facet column that defines the lower bound of a sensitive
95
95
group. Defaults to considering each possible value as sensitive group and
96
- computing metrics vs all the other examples.
96
+ computing metrics vs all the other examples.
97
+ If facet_name is a list, this needs to be None or a List of lists or None, with the same length.
97
98
group_name (str): Optional column name or index to indicate a group column to be used
98
99
for the bias metric 'Conditional Demographic Disparity in Labels - CDDL' or
99
100
'Conditional Demographic Disparity in Predicted Labels - CDDPL'.
100
101
"""
101
- facet = {"name_or_index" : facet_name }
102
- _set (facet_values_or_threshold , "value_or_threshold" , facet )
102
+ if isinstance (facet_name , str ):
103
+ facet = {"name_or_index" : facet_name }
104
+ _set (facet_values_or_threshold , "value_or_threshold" , facet )
105
+ facet_list = [facet ]
106
+ elif facet_values_or_threshold is None or len (facet_name ) == len (facet_values_or_threshold ):
107
+ facet_list = []
108
+ for i , single_facet_name in enumerate (facet_name ):
109
+ facet = {"name_or_index" : single_facet_name }
110
+ if facet_values_or_threshold is not None : _set (facet_values_or_threshold [i ], "value_or_threshold" , facet )
111
+ facet_list .append (facet )
112
+ else :
113
+ raise ValueError ("Wrong combination of argument values passed" )
103
114
self .analysis_config = {
104
115
"label_values_or_threshold" : label_values_or_threshold ,
105
- "facet" : [ facet ] ,
116
+ "facet" : facet_list ,
106
117
}
107
118
_set (group_name , "group_variable" , self .analysis_config )
108
119
109
120
def get_config (self ):
110
121
"""Returns part of an analysis config dictionary."""
111
122
return copy .deepcopy (self .analysis_config )
112
123
124
+ def get_config (self ):
125
+ """Returns part of an analysis config dictionary."""
126
+ return copy .deepcopy (self .analysis_config )
127
+
113
128
114
129
class ModelConfig :
115
130
"""Config object related to a model and its endpoint to be created."""
0 commit comments