|
| 1 | +import os |
| 2 | +import glob |
| 3 | +import logging |
| 4 | +import torch |
| 5 | +from argparse import ArgumentParser |
| 6 | +from monai.data import DataLoader, CacheDataset |
| 7 | +from monai.networks.nets import HoVerNet |
| 8 | +from monai.engines import SupervisedEvaluator |
| 9 | +from monai.transforms import ( |
| 10 | + LoadImaged, |
| 11 | + Lambdad, |
| 12 | + Activationsd, |
| 13 | + Compose, |
| 14 | + CastToTyped, |
| 15 | + ComputeHoVerMapsd, |
| 16 | + ScaleIntensityRanged, |
| 17 | + CenterSpatialCropd, |
| 18 | +) |
| 19 | +from monai.handlers import ( |
| 20 | + MeanDice, |
| 21 | + StatsHandler, |
| 22 | + CheckpointLoader, |
| 23 | +) |
| 24 | +from monai.utils.enums import HoVerNetBranch |
| 25 | +from monai.apps.pathology.handlers.utils import from_engine_hovernet |
| 26 | +from monai.apps.pathology.engines.utils import PrepareBatchHoVerNet |
| 27 | +from skimage import measure |
| 28 | + |
| 29 | + |
| 30 | +def prepare_data(data_dir, phase): |
| 31 | + data_dir = os.path.join(data_dir, phase) |
| 32 | + |
| 33 | + images = list(sorted( |
| 34 | + glob.glob(os.path.join(data_dir, "*/*image.npy")))) |
| 35 | + inst_maps = list(sorted( |
| 36 | + glob.glob(os.path.join(data_dir, "*/*inst_map.npy")))) |
| 37 | + type_maps = list(sorted( |
| 38 | + glob.glob(os.path.join(data_dir, "*/*type_map.npy")))) |
| 39 | + |
| 40 | + data_dicts = [ |
| 41 | + {"image": _image, "label_inst": _inst_map, "label_type": _type_map} |
| 42 | + for _image, _inst_map, _type_map in zip(images, inst_maps, type_maps) |
| 43 | + ] |
| 44 | + |
| 45 | + return data_dicts |
| 46 | + |
| 47 | + |
| 48 | +def run(cfg): |
| 49 | + if cfg["mode"].lower() == "original": |
| 50 | + cfg["patch_size"] = [270, 270] |
| 51 | + cfg["out_size"] = [80, 80] |
| 52 | + elif cfg["mode"].lower() == "fast": |
| 53 | + cfg["patch_size"] = [256, 256] |
| 54 | + cfg["out_size"] = [164, 164] |
| 55 | + |
| 56 | + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 57 | + val_transforms = Compose( |
| 58 | + [ |
| 59 | + LoadImaged(keys=["image", "label_inst", "label_type"], image_only=True), |
| 60 | + Lambdad(keys="label_inst", func=lambda x: measure.label(x)), |
| 61 | + CastToTyped(keys=["image", "label_inst"], dtype=torch.int), |
| 62 | + CenterSpatialCropd( |
| 63 | + keys="image", |
| 64 | + roi_size=cfg["patch_size"], |
| 65 | + ), |
| 66 | + ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True), |
| 67 | + ComputeHoVerMapsd(keys="label_inst"), |
| 68 | + Lambdad(keys="label_inst", func=lambda x: x > 0, overwrite="label"), |
| 69 | + CenterSpatialCropd( |
| 70 | + keys=["label", "hover_label_inst", "label_inst", "label_type"], |
| 71 | + roi_size=cfg["out_size"], |
| 72 | + ), |
| 73 | + CastToTyped(keys=["image", "label_inst", "label_type"], dtype=torch.float32), |
| 74 | + ] |
| 75 | + ) |
| 76 | + |
| 77 | + # Create MONAI DataLoaders |
| 78 | + valid_data = prepare_data(cfg["root"], "valid") |
| 79 | + valid_ds = CacheDataset(data=valid_data, transform=val_transforms, cache_rate=1.0, num_workers=4) |
| 80 | + val_loader = DataLoader( |
| 81 | + valid_ds, |
| 82 | + batch_size=cfg["batch_size"], |
| 83 | + num_workers=cfg["num_workers"], |
| 84 | + pin_memory=torch.cuda.is_available() |
| 85 | + ) |
| 86 | + |
| 87 | + # initialize model |
| 88 | + model = HoVerNet( |
| 89 | + mode=cfg["mode"], |
| 90 | + in_channels=3, |
| 91 | + out_classes=cfg["out_classes"], |
| 92 | + act=("relu", {"inplace": True}), |
| 93 | + norm="batch", |
| 94 | + pretrained_url=None, |
| 95 | + freeze_encoder=False, |
| 96 | + ).to(device) |
| 97 | + |
| 98 | + post_process_np = Compose([ |
| 99 | + Activationsd(keys=HoVerNetBranch.NP.value, softmax=True), |
| 100 | + Lambdad(keys=HoVerNetBranch.NP.value, func=lambda x: x[1: 2, ...] > 0.5)]) |
| 101 | + post_process = Lambdad(keys="pred", func=post_process_np) |
| 102 | + |
| 103 | + # Evaluator |
| 104 | + val_handlers = [ |
| 105 | + CheckpointLoader(load_path=cfg["ckpt_path"], load_dict={"net": model}), |
| 106 | + StatsHandler(output_transform=lambda x: None), |
| 107 | + ] |
| 108 | + evaluator = SupervisedEvaluator( |
| 109 | + device=device, |
| 110 | + val_data_loader=val_loader, |
| 111 | + prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']), |
| 112 | + network=model, |
| 113 | + postprocessing=post_process, |
| 114 | + key_val_metric={"val_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))}, |
| 115 | + val_handlers=val_handlers, |
| 116 | + amp=cfg["amp"], |
| 117 | + ) |
| 118 | + |
| 119 | + state = evaluator.run() |
| 120 | + print(state) |
| 121 | + |
| 122 | + |
| 123 | +def main(): |
| 124 | + parser = ArgumentParser(description="Tumor detection on whole slide pathology images.") |
| 125 | + parser.add_argument( |
| 126 | + "--root", |
| 127 | + type=str, |
| 128 | + default="/workspace/Data/CoNSeP/Prepared/consep", |
| 129 | + help="root data dir", |
| 130 | + ) |
| 131 | + |
| 132 | + parser.add_argument("--bs", type=int, default=16, dest="batch_size", help="batch size") |
| 133 | + parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp") |
| 134 | + parser.add_argument("--classes", type=int, default=5, dest="out_classes", help="output classes") |
| 135 | + parser.add_argument("--mode", type=str, default="original", help="choose either `original` or `fast`") |
| 136 | + |
| 137 | + parser.add_argument("--cpu", type=int, default=8, dest="num_workers", help="number of workers") |
| 138 | + parser.add_argument("--use_gpu", type=bool, default=True, dest="use_gpu", help="whether to use gpu") |
| 139 | + parser.add_argument("--ckpt", type=str, dest="ckpt_path", help="checkpoint path") |
| 140 | + |
| 141 | + args = parser.parse_args() |
| 142 | + cfg = vars(args) |
| 143 | + print(cfg) |
| 144 | + |
| 145 | + logging.basicConfig(level=logging.INFO) |
| 146 | + run(cfg) |
| 147 | + |
| 148 | + |
| 149 | +if __name__ == "__main__": |
| 150 | + main() |
0 commit comments