@@ -85,6 +85,54 @@ def test_data_bias_config():
85
85
assert expected_config == data_bias_config .get_config ()
86
86
87
87
88
+ def test_data_bias_config_multi_facet ():
89
+ label_values = [1 ]
90
+ facet_name = ["Facet1" , "Facet2" ]
91
+ facet_threshold = [[0 ], [1 , 2 ]]
92
+ group_name = "A151"
93
+
94
+ data_bias_config = BiasConfig (
95
+ label_values_or_threshold = label_values ,
96
+ facet_name = facet_name ,
97
+ facet_values_or_threshold = facet_threshold ,
98
+ group_name = group_name ,
99
+ )
100
+
101
+ expected_config = {
102
+ "label_values_or_threshold" : label_values ,
103
+ "facet" : [
104
+ {"name_or_index" : facet_name [0 ], "value_or_threshold" : facet_threshold [0 ]},
105
+ {"name_or_index" : facet_name [1 ], "value_or_threshold" : facet_threshold [1 ]},
106
+ ],
107
+ "group_variable" : group_name
108
+ }
109
+ assert expected_config == data_bias_config .get_config ()
110
+
111
+
112
+ def test_data_bias_config_multi_facet_not_all_with_value ():
113
+ label_values = [1 ]
114
+ facet_name = ["Facet1" , "Facet2" ]
115
+ facet_threshold = [[0 ], None ]
116
+ group_name = "A151"
117
+
118
+ data_bias_config = BiasConfig (
119
+ label_values_or_threshold = label_values ,
120
+ facet_name = facet_name ,
121
+ facet_values_or_threshold = facet_threshold ,
122
+ group_name = group_name ,
123
+ )
124
+
125
+ expected_config = {
126
+ "label_values_or_threshold" : label_values ,
127
+ "facet" : [
128
+ {"name_or_index" : facet_name [0 ], "value_or_threshold" : facet_threshold [0 ]},
129
+ {"name_or_index" : facet_name [1 ]},
130
+ ],
131
+ "group_variable" : group_name
132
+ }
133
+ assert expected_config == data_bias_config .get_config ()
134
+
135
+
88
136
def test_model_config ():
89
137
model_name = "xgboost-model"
90
138
instance_type = "ml.c5.xlarge"
0 commit comments