# Copyright 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

import numpy as np
import pandas as pd
from sklearn import metrics as sk_metrics

site_names = ["site-1", "site-2", "site-3"]
do_merge_patients = True  # if True, average each patient's predictions over all of their exams


def read_ground_truth(filename):
    """Read a dataset JSON file and return its entries as a pandas DataFrame."""
    with open(filename, "r") as f:
        data = json.load(f)

    df = {"patient_id": [], "image": [], "label": [], "split": []}
    for split in data.keys():
        print(f"loading {split}: {len(data[split])} cases from {filename}")
        for item in data[split]:
            for k in item.keys():
                df[k].append(item[k])
            df["split"].append(split)
            if "label" not in item.keys():
                df["label"].append(np.nan)  # unlabeled case

    return pd.DataFrame(df)


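# Expected ground-truth JSON layout (inferred from read_ground_truth above; all
# example values are placeholders):
#
# {
#     "train": [
#         {"patient_id": "p001", "image": "p001_exam1.png", "label": 2},
#         ...
#     ],
#     "test1": [
#         {"patient_id": "p042", "image": "p042_exam1.png"},  # "label" may be absent
#         ...
#     ]
# }

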
def read_prediction(filename, gt, model_name):
    """Read a predictions JSON file and pair each prediction with its ground-truth entry."""
    with open(filename, "r") as f:
        data = json.load(f)

    result = {}
    for s in site_names:
        result[s] = {
            "pred_probs": [],
            "gt_labels": [],
            "pred_probs_bin": [],
            "gt_labels_bin": [],
            "patient_ids": [],
        }
    for site in data.keys():
        for item in data[site][model_name]["test_probs"]:
            # multi-class
            assert (
                len(item["probs"]) == 4
            ), f"Expected four probs but got {len(item['probs'])}: {item['probs']}"
            result[site]["pred_probs"].append(item["probs"])
            gt_item = gt[gt["image"] == item["image"]]
            gt_label = gt_item["label"]
            assert len(gt_label) == 1, f"gt label was {gt_label}"
            result[site]["patient_ids"].append(gt_item["patient_id"].item())
            result[site]["gt_labels"].append(gt_label.item())

            # binary (non-dense vs. dense)
            result[site]["pred_probs_bin"].append(
                np.sum(item["probs"][2:])
            )  # summed prob for dense (labels 2 and 3, i.e. classes 3 and 4)
            if gt_label.item() in [0, 1]:  # non-dense (classes 1 and 2)
                result[site]["gt_labels_bin"].append(0)
            elif gt_label.item() in [2, 3]:  # dense (classes 3 and 4)
                result[site]["gt_labels_bin"].append(1)
            else:
                raise ValueError(f"didn't expect a label of {gt_label}")
        assert (
            len(result[site]["gt_labels"])
            == len(result[site]["pred_probs"])
            == len(result[site]["gt_labels_bin"])
            == len(result[site]["pred_probs_bin"])
            == len(result[site]["patient_ids"])
        )
        assert len(np.unique(result[site]["gt_labels_bin"])) == 2, (
            f"Expected two kinds of binary labels but got "
            f"unique labels {np.unique(result[site]['gt_labels_bin'])}"
        )
    return result


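# Expected predictions JSON layout (inferred from read_prediction above; the
# image name and probabilities are placeholders):
#
# {
#     "site-1": {
#         "SRV_best_FL_global_model.pt": {
#             "test_probs": [
#                 {"image": "p001_exam1.png", "probs": [0.1, 0.2, 0.3, 0.4]},
#                 ...
#             ]
#         }
#     },
#     "site-2": {...},
#     "site-3": {...}
# }

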
def evaluate(site_result):
    """Compute multi-class, per-image distance, and binary metrics for one result set."""
    gt_labels = site_result["gt_labels"]
    pred_probs = site_result["pred_probs"]
    gt_labels_bin = site_result["gt_labels_bin"]
    pred_probs_bin = site_result["pred_probs_bin"]

    # get pred labels via argmax over the class probabilities
    pred_labels = []
    for prob in pred_probs:
        pred_labels.append(np.argmax(prob))

    assert (
        len(gt_labels) == len(pred_labels) == len(gt_labels_bin) == len(pred_probs_bin)
    )

    # multi-class metrics
    linear_kappa = sk_metrics.cohen_kappa_score(
        gt_labels, pred_labels, weights="linear"
    )
    quadratic_kappa = sk_metrics.cohen_kappa_score(
        gt_labels, pred_labels, weights="quadratic"
    )

    # per-image distance metrics (negated so that higher is better)
    dist = np.abs(np.squeeze(gt_labels) - np.squeeze(pred_labels))
    lin_dist = -dist
    quad_dist = -(dist**2)
    avg_lin_dist = np.mean(lin_dist)
    avg_quad_dist = np.mean(quad_dist)

    # binary metrics
    fpr, tpr, thresholds = sk_metrics.roc_curve(
        gt_labels_bin, pred_probs_bin, pos_label=1
    )
    auc = sk_metrics.auc(fpr, tpr)

    metrics = {
        "linear_kappa": linear_kappa,
        "quadratic_kappa": quadratic_kappa,
        "auc": auc,
        "lin_dist": lin_dist,
        "quad_dist": quad_dist,
        "avg_lin_dist": avg_lin_dist,
        "avg_quad_dist": avg_quad_dist,
    }
    print(
        f"evaluating {len(gt_labels)} predictions: "
        f"lin. kappa {linear_kappa:.3f}, "
        f"quad. kappa {quadratic_kappa:.3f}, "
        f"auc {auc:.3f}, "
        f"avg. lin. dist {avg_lin_dist:.3f}, "
        f"avg. quad. dist {avg_quad_dist:.3f}"
    )

    return metrics


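# Minimal usage sketch for evaluate() on toy values (illustrative only): the
# argmax predictions match the labels exactly, so both kappas and the AUC are
# 1.0 and both average distances are 0.0.
#
#   toy = {
#       "gt_labels": [0, 1, 2, 3],
#       "pred_probs": [
#           [0.7, 0.1, 0.1, 0.1],
#           [0.1, 0.6, 0.2, 0.1],
#           [0.1, 0.2, 0.6, 0.1],
#           [0.2, 0.1, 0.1, 0.6],
#       ],
#       "gt_labels_bin": [0, 0, 1, 1],
#       "pred_probs_bin": [0.2, 0.3, 0.7, 0.7],
#   }
#   evaluate(toy)

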
def merge_patients(site_result):
    """Merge per-exam predictions so that each patient contributes one averaged prediction."""
    merged_results = {}
    for k in site_result.keys():
        merged_results[k] = []
        site_result[k] = np.array(site_result[k])  # needed for index-based merging
    merged_results["counts"] = []

    patient_ids = site_result["patient_ids"]
    unique_patients = np.unique(patient_ids)
    print(
        f"Merging {len(patient_ids)} predictions from {len(unique_patients)} patients."
    )

    for patient in unique_patients:
        idx = np.where(patient_ids == patient)
        assert np.size(idx) > 0, "no matching patient found!"
        merged_results["patient_ids"].append(patient)
        merged_results["counts"].append(np.size(idx))
        # merge labels
        merged_results["gt_labels"].append(np.unique(site_result["gt_labels"][idx]))
        merged_results["gt_labels_bin"].append(
            np.unique(site_result["gt_labels_bin"][idx])
        )
        # all exams of one patient should carry the same label
        assert len(merged_results["gt_labels"][-1]) == 1
        assert len(merged_results["gt_labels_bin"][-1]) == 1
        # average probs
        merged_results["pred_probs"].append(
            np.mean(site_result["pred_probs"][idx], axis=0)
        )
        merged_results["pred_probs_bin"].append(
            np.mean(site_result["pred_probs_bin"][idx])
        )
        assert len(merged_results["pred_probs"][-1]) == 4  # should still be four probs
        assert isinstance(
            merged_results["pred_probs_bin"][-1], float
        )  # should be a single prob
    print(
        f"Found patients with these numbers of exams: {np.unique(merged_results['counts'])}"
    )

    return merged_results


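# Merging sketch with placeholder values: two exams of one patient with probs
# [0.6, 0.2, 0.1, 0.1] and [0.4, 0.2, 0.2, 0.2] are merged into a single entry
# with averaged probs [0.5, 0.2, 0.15, 0.15] and a count of 2.

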
def compute_metrics(args):
    gt1 = read_ground_truth(args.gt1)
    gt2 = read_ground_truth(args.gt2)
    gt3 = read_ground_truth(args.gt3)
    ground_truth = pd.concat((gt1, gt2, gt3))
    # read predictions and match them with the ground truth of the requested test split
    pred_result = read_prediction(
        args.pred,
        gt=ground_truth[ground_truth["split"] == args.test_name],
        model_name=args.model_name,
    )

    print(f"Evaluating {args.model_name} on {args.test_name}:")
    overall_pred_result = {}

    metrics = {}
    for s in site_names:
        if do_merge_patients:
            pred_result[s] = merge_patients(pred_result[s])

        print(f"==={s}===")
        if not overall_pred_result:
            # copy the lists so extending them later does not also modify this site's results
            overall_pred_result = {k: list(v) for k, v in pred_result[s].items()}
        else:
            for k in overall_pred_result.keys():
                overall_pred_result[k].extend(pred_result[s][k])
        metrics[s] = evaluate(pred_result[s])
    print("===overall===")
    metrics["overall"] = evaluate(overall_pred_result)

    return metrics


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gt1", type=str, default="../../../dmist_files/dataset_site-1.json"
    )
    parser.add_argument(
        "--gt2", type=str, default="../../../dmist_files/dataset_site-2.json"
    )
    parser.add_argument(
        "--gt3", type=str, default="../../../dmist_files/dataset_site-3.json"
    )
    parser.add_argument(
        "--pred",
        type=str,
        default="../../../results_acr_5-11-2022/result_server/predictions.json",
    )
    parser.add_argument("--test_name", type=str, default="test1")
    parser.add_argument("--model_name", type=str, default="SRV_best_FL_global_model.pt")
    args = parser.parse_args()

    metrics = compute_metrics(args)

    # print(f"Evaluation metrics for {args.model_name} on {args.test_name}:")
    # print(metrics)


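# Example invocation (the script filename is a placeholder; all arguments fall
# back to the defaults above when omitted):
#
#   python evaluate_predictions.py \
#       --pred ../../../results_acr_5-11-2022/result_server/predictions.json \
#       --test_name test1 --model_name SRV_best_FL_global_model.pt

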
if __name__ == "__main__":
    main()