|
| 1 | +# Copyright (c) MediaTek Inc. |
| 2 | +# All rights reserved |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import sys |
| 8 | +import os |
| 9 | +import piq |
| 10 | +import numpy as np |
| 11 | +import argparse |
| 12 | +import torch |
| 13 | +import json |
| 14 | + |
| 15 | + |
def check_data(target_f, predict_f):
    """Validate that the target and prediction folders match one-to-one.

    Target files are named ``golden_sampleId_outId.bin`` and prediction
    files ``output_sampleId_outId.bin``; every target file must have
    exactly one corresponding prediction file.

    Args:
        target_f: folder containing golden (reference) .bin files.
        predict_f: folder containing model prediction .bin files.

    Raises:
        RuntimeError: if the file counts differ or any file is unmatched.
    """
    target_files = os.listdir(target_f)
    predict_files = os.listdir(predict_f)
    if len(target_files) != len(predict_files):
        raise RuntimeError("Data number in target folder and prediction folder must be same")

    predict_set = set(predict_files)
    for f in target_files:
        # target file naming rule is golden_sampleId_outId.bin
        # predict file naming rule is output_sampleId_outId.bin
        pred_name = f.replace("golden", "output")
        try:
            predict_set.remove(pred_name)
        except KeyError:
            raise RuntimeError(f"Cannot find {pred_name} in {predict_f}")

    if predict_set:
        # BUG FIX: a set is not an iterator, so next(predict_set) raised
        # TypeError; next(iter(...)) fetches an arbitrary leftover element.
        target_name = next(iter(predict_set)).replace("output", "golden")
        raise RuntimeError(f"Cannot find {target_name} in {target_f}")
| 35 | + |
| 36 | + |
def eval_topk(target_f, predict_f):
    """Print top-10 and top-50 classification accuracy over a folder pair.

    Each prediction file holds a float32 score vector (one score per
    class); the matching target file holds the golden class index as a
    single int64 value.

    Args:
        target_f: folder with ``golden_*`` int64 label files.
        predict_f: folder with ``output_*`` float32 score files.
    """

    def solve(prob, target, k):
        # Return 1 if the golden index appears among the k highest-scoring
        # entries of prob, else 0.
        _, indices = torch.topk(prob, k=k, sorted=True)
        golden = torch.reshape(target, [-1, 1])
        return 1 if torch.any(golden == indices) else 0

    target_files = os.listdir(target_f)

    cnt10 = 0
    cnt50 = 0
    for target_name in target_files:
        pred_name = target_name.replace("golden", "output")

        pred_npy = np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32)
        # BUG FIX: np.fromfile(...)[0] is a numpy scalar (np.int64), and
        # torch.from_numpy() accepts only ndarrays (it raises TypeError on
        # scalars) — torch.tensor() handles the scalar correctly.
        target_val = np.fromfile(os.path.join(target_f, target_name), dtype=np.int64)[0]
        prob = torch.from_numpy(pred_npy)
        target = torch.tensor(target_val)
        cnt10 += solve(prob, target, 10)
        cnt50 += solve(prob, target, 50)

    print("Top10 acc:", cnt10 * 100.0 / len(target_files))
    print("Top50 acc:", cnt50 * 100.0 / len(target_files))
| 61 | + |
| 62 | + |
def eval_piq(target_f, predict_f, shape=(1, 448, 448, 3)):
    """Print average PSNR and SSIM between golden and predicted images.

    Args:
        target_f: folder with golden raw float32 image files (NHWC layout).
        predict_f: folder with predicted raw float32 image files (NHWC).
        shape: NHWC shape of every raw file; generalized from the
            previously hard-coded (1, 448, 448, 3) default.
    """

    def _load_nchw(path):
        # Raw float32 NHWC file -> NCHW torch tensor (layout piq expects).
        img = np.fromfile(path, dtype=np.float32).reshape(shape)
        return torch.from_numpy(np.moveaxis(img, 3, 1))

    target_files = os.listdir(target_f)

    psnr_list = []
    ssim_list = []
    for target_name in target_files:
        pred_name = target_name.replace("golden", "output")
        hr = _load_nchw(os.path.join(target_f, target_name))
        # Clamp predictions into the [0, 1] range required by piq metrics.
        sr = _load_nchw(os.path.join(predict_f, pred_name)).clamp(0, 1)

        psnr_list.append(piq.psnr(hr, sr))
        ssim_list.append(piq.ssim(hr, sr))

    avg_psnr = sum(psnr_list).item() / len(psnr_list)
    avg_ssim = sum(ssim_list).item() / len(ssim_list)

    print(f"Avg of PSNR is: {avg_psnr}")
    print(f"Avg of SSIM is: {avg_ssim}")
| 88 | + |
| 89 | + |
def eval_segmentation(target_f, predict_f):
    """Print PA / MPA / MIoU / per-class IoU for VOC-style segmentation.

    Target files hold uint8 class-index maps of shape (224, 224);
    prediction files hold float32 score maps of shape
    (224, 224, num_classes) whose argmax over the channel axis gives the
    predicted class of each pixel.

    Args:
        target_f: folder with ``golden_*`` uint8 label-map files.
        predict_f: folder with ``output_*`` float32 score-map files.
    """
    # Pascal VOC class names; list index == class id in the label maps.
    classes = [
        "Background",  # BUG FIX: was misspelled "Backround"
        "Aeroplane",
        "Bicycle",
        "Bird",
        "Boat",
        "Bottle",
        "Bus",
        "Car",
        "Cat",
        "Chair",
        "Cow",
        "DiningTable",
        "Dog",
        "Horse",
        "MotorBike",
        "Person",
        "PottedPlant",
        "Sheep",
        "Sofa",
        "Train",
        "TvMonitor",
    ]

    target_files = os.listdir(target_f)

    def make_confusion(goldens, predictions, num_classes):
        # Accumulate a num_classes x num_classes confusion matrix where
        # confusion[g, p] counts pixels of golden class g predicted as p.
        def histogram(golden, predict):
            mask = golden < num_classes  # ignore out-of-range golden labels
            hist = np.bincount(
                num_classes * golden[mask].astype(int) + predict[mask],
                minlength=num_classes**2,
            ).reshape(num_classes, num_classes)
            return hist

        confusion = np.zeros((num_classes, num_classes))
        for g, p in zip(goldens, predictions):
            confusion += histogram(g.flatten(), p.flatten())

        return confusion

    pred_list = []
    target_list = []
    for target_name in target_files:
        pred_name = target_name.replace("golden", "output")
        target_npy = np.fromfile(os.path.join(target_f, target_name), dtype=np.uint8)
        target_npy = target_npy.reshape((224, 224))
        target_list.append(target_npy)

        pred_npy = np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32)
        pred_npy = pred_npy.reshape((224, 224, len(classes)))
        pred_npy = pred_npy.argmax(2).astype(np.uint8)
        pred_list.append(pred_npy)

    eps = 1e-6  # avoids division by zero for classes absent from the data
    confusion = make_confusion(target_list, pred_list, len(classes))

    # Pixel accuracy, mean per-class accuracy, and intersection-over-union.
    pa = np.diag(confusion).sum() / (confusion.sum() + eps)
    mpa = np.mean(np.diag(confusion) / (confusion.sum(axis=1) + eps))
    iou = np.diag(confusion) / (
        confusion.sum(axis=1) + confusion.sum(axis=0) - np.diag(confusion) + eps
    )
    miou = np.mean(iou)
    cls_iou = dict(zip(classes, iou))

    print(f"PA : {pa}")
    print(f"MPA : {mpa}")
    print(f"MIoU : {miou}")
    print(f"CIoU : \n{json.dumps(cls_iou, indent=2)}")
| 160 | + |
| 161 | + |
if __name__ == "__main__":
    # Command-line entry point: validate the two data folders, then run
    # the evaluation selected by --eval_type.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--target_f",
        help="folder of target data",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--out_f",
        help="folder of model prediction data",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--eval_type",
        help="Choose eval type from: topk, piq, segmentation",
        type=str,
        choices=["topk", "piq", "segmentation"],
        required=True,
    )

    args = parser.parse_args()

    check_data(args.target_f, args.out_f)

    # Dispatch table instead of an if/elif chain; argparse `choices`
    # guarantees the key is present.
    evaluators = {
        "topk": eval_topk,
        "piq": eval_piq,
        "segmentation": eval_segmentation,
    }
    evaluators[args.eval_type](args.target_f, args.out_f)
| 197 | + |
| 198 | + |
0 commit comments