
Commit 2c52a98

add challenge evaluation script (Project-MONAI#765)
1 parent 61a5bdc commit 2c52a98

2 files changed: +254 -0 lines changed


federated_learning/breast_density_challenge/README.md

Lines changed: 3 additions & 0 deletions
@@ -174,3 +174,6 @@ for the global model (should be named `SRV_best_FL_global_model.pt`).
 ...
 }
 ```
+
+## Challenge evaluation
+The script used for evaluating different submissions is available at [challenge_evaluate.py](./challenge_evaluate.py)
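
As a rough usage sketch, the evaluation can also be driven programmatically by passing the same arguments that the script's command-line parser defines. This is a minimal sketch, assuming `challenge_evaluate.py` is importable from the current working directory; all paths below are placeholders for the per-site ground-truth jsons and the submission's `predictions.json`:

```python
# Minimal usage sketch (assumption: challenge_evaluate.py is importable from the
# current working directory). All file paths below are placeholders.
from argparse import Namespace

from challenge_evaluate import compute_metrics

args = Namespace(
    gt1="dataset_site-1.json",  # ground-truth json for site-1
    gt2="dataset_site-2.json",  # ground-truth json for site-2
    gt3="dataset_site-3.json",  # ground-truth json for site-3
    pred="predictions.json",    # submission predictions
    test_name="test1",          # which split to evaluate
    model_name="SRV_best_FL_global_model.pt",
)
metrics = compute_metrics(args)
print(metrics["overall"]["linear_kappa"], metrics["overall"]["auc"])
```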
federated_learning/breast_density_challenge/challenge_evaluate.py (new file)

Lines changed: 251 additions & 0 deletions
```python
# Copyright 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

import numpy as np
import pandas as pd
from sklearn import metrics as sk_metrics

site_names = ["site-1", "site-2", "site-3"]
MERGE_PATIENTS = True  # if True, average per-image predictions into per-patient predictions


def read_ground_truth(filename):
    """Read one site's dataset json and return it as a DataFrame with a `split` column."""
    with open(filename, "r") as f:
        data = json.load(f)

    df = {"patient_id": [], "image": [], "label": [], "split": []}
    for split in data.keys():
        print(f"loading {split}: {len(data[split])} cases from {filename}")
        for item in data[split]:
            for k in item.keys():
                df[k].append(item[k])
            df["split"].append(split)
            if "label" not in item.keys():
                df["label"].append(np.nan)

    return pd.DataFrame(df)


def read_prediction(filename, gt, model_name):
    """Read the submitted predictions and match them with the ground truth by image id."""
    with open(filename, "r") as f:
        data = json.load(f)

    result = {}
    for s in site_names:
        result[s] = {
            "pred_probs": [],
            "gt_labels": [],
            "pred_probs_bin": [],
            "gt_labels_bin": [],
            "patient_ids": [],
        }
    for site in data.keys():
        for item in data[site][model_name]["test_probs"]:
            # multi-class
            assert (
                len(item["probs"]) == 4
            ), f"Expected four probs but got {len(item['probs'])}: {item['probs']}"
            result[site]["pred_probs"].append(item["probs"])
            gt_item = gt[gt["image"] == item["image"]]
            gt_label = gt_item["label"]
            assert len(gt_label) == 1, f"gt label was {gt_label}"
            result[site]["patient_ids"].append(gt_item["patient_id"].item())
            result[site]["gt_labels"].append(gt_label.item())

            # binary (non-dense vs. dense)
            result[site]["pred_probs_bin"].append(
                np.sum(item["probs"][2:])
            )  # prob for dense (classes 3 and 4)
            if gt_label.item() in [0, 1]:  # non-dense (classes 1 and 2)
                result[site]["gt_labels_bin"].append(0)
            elif gt_label.item() in [2, 3]:  # dense (classes 3 and 4)
                result[site]["gt_labels_bin"].append(1)
            else:
                raise ValueError(f"didn't expect a label of {gt_label}")
        assert (
            len(result[site]["gt_labels"])
            == len(result[site]["pred_probs"])
            == len(result[site]["gt_labels_bin"])
            == len(result[site]["pred_probs_bin"])
            == len(result[site]["patient_ids"])
        )
        assert len(np.unique(result[site]["gt_labels_bin"])) == 2, (
            f"Expected two kinds of binary labels but got "
            f"unique labels {np.unique(result[site]['gt_labels_bin'])}"
        )
    return result


def evaluate(site_result):
    """Compute multi-class kappas, per-image distances, and the binary AUC for one result set."""
    gt_labels = site_result["gt_labels"]
    pred_probs = site_result["pred_probs"]
    gt_labels_bin = site_result["gt_labels_bin"]
    pred_probs_bin = site_result["pred_probs_bin"]

    # get pred labels
    pred_labels = []
    for prob in pred_probs:
        pred_labels.append(np.argmax(prob))

    assert (
        len(gt_labels) == len(pred_labels) == len(gt_labels_bin) == len(pred_probs_bin)
    )

    # multi-class metrics
    linear_kappa = sk_metrics.cohen_kappa_score(
        gt_labels, pred_labels, weights="linear"
    )
    quadratic_kappa = sk_metrics.cohen_kappa_score(
        gt_labels, pred_labels, weights="quadratic"
    )

    # per-image distance metrics
    dist = np.abs(np.squeeze(gt_labels) - np.squeeze(pred_labels))
    lin_dist = -dist
    quad_dist = -(dist ** 2)
    avg_lin_dist = np.mean(lin_dist)
    avg_quad_dist = np.mean(quad_dist)

    # binary metrics
    fpr, tpr, _thresholds = sk_metrics.roc_curve(
        gt_labels_bin, pred_probs_bin, pos_label=1
    )
    auc = sk_metrics.auc(fpr, tpr)

    metrics = {
        "linear_kappa": linear_kappa,
        "quadratic_kappa": quadratic_kappa,
        "auc": auc,
        "lin_dist": lin_dist,
        "quad_dist": quad_dist,
        "avg_lin_dist": avg_lin_dist,
        "avg_quad_dist": avg_quad_dist,
    }
    print(
        f"evaluating {len(gt_labels)} predictions: "
        f"lin. kappa {linear_kappa:.3f}, "
        f"quad. kappa {quadratic_kappa:.3f}, "
        f"auc {auc:.3f}, "
        f"avg. lin. dist {avg_lin_dist:.3f}, "
        f"avg. quad. dist {avg_quad_dist:.3f}"
    )

    return metrics


def merge_patients(site_result):
    """Merge per-image predictions of the same patient by averaging their probabilities."""
    merged_results = {}
    for k in site_result.keys():
        merged_results[k] = []
        site_result[k] = np.array(site_result[k])  # needed for merging
    merged_results["counts"] = []

    patient_ids = site_result["patient_ids"]
    unique_patients = np.unique(patient_ids)
    print(
        f"Merging {len(patient_ids)} predictions from {len(unique_patients)} patients."
    )

    for patient in unique_patients:
        idx = np.where(patient_ids == patient)
        assert np.size(idx) > 0, "no matching patient found!"
        merged_results["patient_ids"].append(patient)
        merged_results["counts"].append(np.size(idx))
        # merge labels
        merged_results["gt_labels"].append(np.unique(site_result["gt_labels"][idx]))
        merged_results["gt_labels_bin"].append(
            np.unique(site_result["gt_labels_bin"][idx])
        )
        # merged labels should all be the same
        assert len(merged_results["gt_labels"][-1]) == 1
        assert len(merged_results["gt_labels_bin"][-1]) == 1
        # average probs
        merged_results["pred_probs"].append(
            np.mean(site_result["pred_probs"][idx], axis=0)
        )
        merged_results["pred_probs_bin"].append(
            np.mean(site_result["pred_probs_bin"][idx])
        )
        assert len(merged_results["pred_probs"][-1]) == 4  # should still be four probs
        assert isinstance(
            merged_results["pred_probs_bin"][-1], float
        )  # should be just one prob
    print(
        f"Found patients with these nr of exams: {np.unique(merged_results['counts'])}"
    )

    return merged_results


def compute_metrics(args):
    """Evaluate the submitted predictions per site and overall."""
    gt1 = read_ground_truth(args.gt1)
    gt2 = read_ground_truth(args.gt2)
    gt3 = read_ground_truth(args.gt3)
    ground_truth = pd.concat((gt1, gt2, gt3))
    pred_result = read_prediction(
        args.pred,
        gt=ground_truth[ground_truth["split"] == args.test_name],
        model_name=args.model_name,
    )  # read predictions and merge with ground truth

    print(f"Evaluating {args.model_name} on {args.test_name}:")
    overall_pred_result = {}

    metrics = {}
    for s in site_names:
        if MERGE_PATIENTS:
            pred_result[s] = merge_patients(pred_result[s])

        print(f"==={s}===")
        if not overall_pred_result:
            overall_pred_result = pred_result[s]
        else:
            for k in overall_pred_result.keys():
                overall_pred_result[k].extend(pred_result[s][k])
        metrics[s] = evaluate(pred_result[s])
    print("===overall===")
    metrics["overall"] = evaluate(overall_pred_result)

    return metrics


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gt1", type=str, default="../../../dmist_files/dataset_site-1.json"
    )
    parser.add_argument(
        "--gt2", type=str, default="../../../dmist_files/dataset_site-2.json"
    )
    parser.add_argument(
        "--gt3", type=str, default="../../../dmist_files/dataset_site-3.json"
    )
    parser.add_argument(
        "--pred",
        type=str,
        default="../../../results_acr_5-11-2022/result_server/predictions.json",
    )
    parser.add_argument("--test_name", type=str, default="test1")
    parser.add_argument("--model_name", type=str, default="SRV_best_FL_global_model.pt")
    args = parser.parse_args()

    metrics = compute_metrics(args)

    # print(f"Evaluation metrics for {args.model_name} on {args.test_name}:")
    # print(metrics)


if __name__ == "__main__":
    main()
```
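
For reference, a minimal sketch of the input layout that `read_ground_truth` and `read_prediction` above parse, inferred from the parsing code; the patient/image ids and probabilities below are made-up placeholders:

```python
# Inferred input layout; ids and probabilities below are placeholders only.
example_ground_truth = {  # one such json per site (--gt1/--gt2/--gt3)
    "test1": [  # split name, selected via --test_name
        {"patient_id": "P0001", "image": "P0001_img0", "label": 2},
        # entries without a "label" are read as NaN and cannot be evaluated
    ],
}

example_predictions = {  # --pred (predictions.json)
    "site-1": {
        "SRV_best_FL_global_model.pt": {  # keyed by --model_name
            "test_probs": [
                # four class probabilities per image; the last two classes
                # are summed as the "dense" score for the binary AUC
                {"image": "P0001_img0", "probs": [0.05, 0.15, 0.55, 0.25]},
            ]
        }
    },
    # "site-2" and "site-3" follow the same layout
}
```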
