|
29 | 29 | from monai.transforms import AsChannelFirstd, Compose, LoadNiftid, ScaleIntensityd, ToTensord
|
30 | 30 |
|
31 | 31 |
|
32 |
def main(tempdir):
    """Evaluate a pretrained 3D UNet on synthetic NIfTI volumes.

    Generates five synthetic image/segmentation pairs into *tempdir*, loads the
    checkpoint ``best_metric_model.pth`` from the working directory, runs
    sliding-window inference over every volume, prints the mean Dice metric,
    and writes thresholded predictions to ``./output``.

    Args:
        tempdir: directory path where the synthetic ``.nii.gz`` files are written.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # Synthesize a small validation set of paired image/label volumes on disk.
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for idx in range(5):
        image_arr, label_arr = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        nib.save(nib.Nifti1Image(image_arr, np.eye(4)), os.path.join(tempdir, f"im{idx:d}.nii.gz"))
        nib.save(nib.Nifti1Image(label_arr, np.eye(4)), os.path.join(tempdir, f"seg{idx:d}.nii.gz"))

    # Pair each image with its segmentation by sorted filename order.
    image_paths = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    data_dicts = [{"img": i_path, "seg": s_path} for i_path, s_path in zip(image_paths, seg_paths)]

    # Dictionary-based preprocessing applied to both image and segmentation;
    # intensity scaling is applied to the image only.
    preprocessing = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    dataset = monai.data.Dataset(data=data_dicts, transform=preprocessing)
    # Sliding-window inference consumes one volume per iteration, hence batch_size=1.
    loader = DataLoader(dataset, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    # Use every available device reported by MONAI; the model lives on the first.
    device_list = get_devices_spec(None)
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device_list[0])

    model.load_state_dict(torch.load("best_metric_model.pth"))

    # With more than one device, parallelize the sliding-window forward passes.
    if len(device_list) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_list)

    model.eval()
    # Window shape and per-step window batch for the sliding-window scan
    # (loop invariants, so defined once up front).
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    with torch.no_grad():
        dice_total = 0.0
        sample_count = 0
        nifti_writer = NiftiSaver(output_dir="./output")
        for batch in loader:
            inputs = batch["img"].to(device_list[0])
            targets = batch["seg"].to(device_list[0])
            predictions = sliding_window_inference(inputs, roi_size, sw_batch_size, model)
            batch_dice = dice_metric(y_pred=predictions, y=targets)
            sample_count += len(batch_dice)
            dice_total += batch_dice.item() * len(batch_dice)
            # Binarize the logits before saving the segmentation volumes.
            predictions = (predictions.sigmoid() >= 0.5).float()
            nifti_writer.save_batch(predictions, batch["img_meta_dict"])
        print("evaluation metric:", dice_total / sample_count)
100 | 99 |
|
101 | 100 |
|
if __name__ == "__main__":
    # Stage everything inside a scratch directory that is removed automatically
    # when the context manager exits, even on error.
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)
0 commit comments