
Commit 0d197e6

Nic-Ma and monai-bot authored
Update tempdir logic in examples and fix CI issue (#900)
* [DLMED] update according to comments
* [DLMED] fix tensorboard summary issue
* [MONAI] python code formatting

Co-authored-by: monai-bot <[email protected]>
1 parent 57f8dbf commit 0d197e6

11 files changed: +992 -999 lines
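
Both example diffs shown below apply the same refactor: tempfile.TemporaryDirectory() moves out of main() and into the if __name__ == "__main__": guard, and main takes the directory as a parameter. A minimal sketch of the before/after shape (run_example and the main_before/main_after names are illustrative, not part of the commit):

import tempfile

def run_example(tempdir):
    # illustrative stand-in for the body of each example script
    print(f"working in {tempdir}")

# Before: main() owned the temporary directory, so a caller such as a
# CI test could not substitute a directory it controls.
def main_before():
    with tempfile.TemporaryDirectory() as tempdir:
        run_example(tempdir)

# After: the directory is injected; a direct script run still creates
# and cleans up its own temporary directory at the call site.
def main_after(tempdir):
    run_example(tempdir)

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        main_after(tempdir)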

examples/segmentation_3d/unet_evaluation_array.py

47 additions & 47 deletions

@@ -28,62 +28,62 @@
 from monai.transforms import AddChannel, Compose, ScaleIntensity, ToTensor


-def main():
+def main(tempdir):
     config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)

-    with tempfile.TemporaryDirectory() as tempdir:
-        print(f"generating synthetic data to {tempdir} (this may take a while)")
-        for i in range(5):
-            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
+    print(f"generating synthetic data to {tempdir} (this may take a while)")
+    for i in range(5):
+        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

-            n = nib.Nifti1Image(im, np.eye(4))
-            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
+        n = nib.Nifti1Image(im, np.eye(4))
+        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

-            n = nib.Nifti1Image(seg, np.eye(4))
-            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
+        n = nib.Nifti1Image(seg, np.eye(4))
+        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

-        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
-        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
+    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
+    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

-        # define transforms for image and segmentation
-        imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
-        segtrans = Compose([AddChannel(), ToTensor()])
-        val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
-        # sliding window inference for one image at every iteration
-        val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
-        dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
+    # define transforms for image and segmentation
+    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
+    segtrans = Compose([AddChannel(), ToTensor()])
+    val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
+    # sliding window inference for one image at every iteration
+    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
+    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

-        device = torch.device("cuda:0")
-        model = UNet(
-            dimensions=3,
-            in_channels=1,
-            out_channels=1,
-            channels=(16, 32, 64, 128, 256),
-            strides=(2, 2, 2, 2),
-            num_res_units=2,
-        ).to(device)
+    device = torch.device("cuda:0")
+    model = UNet(
+        dimensions=3,
+        in_channels=1,
+        out_channels=1,
+        channels=(16, 32, 64, 128, 256),
+        strides=(2, 2, 2, 2),
+        num_res_units=2,
+    ).to(device)

-        model.load_state_dict(torch.load("best_metric_model.pth"))
-        model.eval()
-        with torch.no_grad():
-            metric_sum = 0.0
-            metric_count = 0
-            saver = NiftiSaver(output_dir="./output")
-            for val_data in val_loader:
-                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
-                # define sliding window size and batch size for windows inference
-                roi_size = (96, 96, 96)
-                sw_batch_size = 4
-                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
-                value = dice_metric(y_pred=val_outputs, y=val_labels)
-                metric_count += len(value)
-                metric_sum += value.item() * len(value)
-                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
-                saver.save_batch(val_outputs, val_data[2])
-            metric = metric_sum / metric_count
-            print("evaluation metric:", metric)
+    model.load_state_dict(torch.load("best_metric_model.pth"))
+    model.eval()
+    with torch.no_grad():
+        metric_sum = 0.0
+        metric_count = 0
+        saver = NiftiSaver(output_dir="./output")
+        for val_data in val_loader:
+            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
+            # define sliding window size and batch size for windows inference
+            roi_size = (96, 96, 96)
+            sw_batch_size = 4
+            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
+            value = dice_metric(y_pred=val_outputs, y=val_labels)
+            metric_count += len(value)
+            metric_sum += value.item() * len(value)
+            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
+            saver.save_batch(val_outputs, val_data[2])
+        metric = metric_sum / metric_count
+        print("evaluation metric:", metric)


 if __name__ == "__main__":
-    main()
+    with tempfile.TemporaryDirectory() as tempdir:
+        main(tempdir)

examples/segmentation_3d/unet_evaluation_dict.py

66 additions & 66 deletions

@@ -29,75 +29,75 @@
 from monai.transforms import AsChannelFirstd, Compose, LoadNiftid, ScaleIntensityd, ToTensord


-def main():
+def main(tempdir):
     monai.config.print_config()
     logging.basicConfig(stream=sys.stdout, level=logging.INFO)

-    with tempfile.TemporaryDirectory() as tempdir:
-        print(f"generating synthetic data to {tempdir} (this may take a while)")
-        for i in range(5):
-            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
-
-            n = nib.Nifti1Image(im, np.eye(4))
-            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
-
-            n = nib.Nifti1Image(seg, np.eye(4))
-            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
-
-        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
-        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
-        val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
-
-        # define transforms for image and segmentation
-        val_transforms = Compose(
-            [
-                LoadNiftid(keys=["img", "seg"]),
-                AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
-                ScaleIntensityd(keys="img"),
-                ToTensord(keys=["img", "seg"]),
-            ]
-        )
-        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
-        # sliding window inference need to input 1 image in every iteration
-        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
-        dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
-
-        # try to use all the available GPUs
-        devices = get_devices_spec(None)
-        model = UNet(
-            dimensions=3,
-            in_channels=1,
-            out_channels=1,
-            channels=(16, 32, 64, 128, 256),
-            strides=(2, 2, 2, 2),
-            num_res_units=2,
-        ).to(devices[0])
-
-        model.load_state_dict(torch.load("best_metric_model.pth"))
-
-        # if we have multiple GPUs, set data parallel to execute sliding window inference
-        if len(devices) > 1:
-            model = torch.nn.DataParallel(model, device_ids=devices)
-
-        model.eval()
-        with torch.no_grad():
-            metric_sum = 0.0
-            metric_count = 0
-            saver = NiftiSaver(output_dir="./output")
-            for val_data in val_loader:
-                val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
-                # define sliding window size and batch size for windows inference
-                roi_size = (96, 96, 96)
-                sw_batch_size = 4
-                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
-                value = dice_metric(y_pred=val_outputs, y=val_labels)
-                metric_count += len(value)
-                metric_sum += value.item() * len(value)
-                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
-                saver.save_batch(val_outputs, val_data["img_meta_dict"])
-            metric = metric_sum / metric_count
-            print("evaluation metric:", metric)
+    print(f"generating synthetic data to {tempdir} (this may take a while)")
+    for i in range(5):
+        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
+
+        n = nib.Nifti1Image(im, np.eye(4))
+        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
+
+        n = nib.Nifti1Image(seg, np.eye(4))
+        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
+
+    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
+    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
+    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
+
+    # define transforms for image and segmentation
+    val_transforms = Compose(
+        [
+            LoadNiftid(keys=["img", "seg"]),
+            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
+            ScaleIntensityd(keys="img"),
+            ToTensord(keys=["img", "seg"]),
+        ]
+    )
+    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+    # sliding window inference need to input 1 image in every iteration
+    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
+    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
+
+    # try to use all the available GPUs
+    devices = get_devices_spec(None)
+    model = UNet(
+        dimensions=3,
+        in_channels=1,
+        out_channels=1,
+        channels=(16, 32, 64, 128, 256),
+        strides=(2, 2, 2, 2),
+        num_res_units=2,
+    ).to(devices[0])
+
+    model.load_state_dict(torch.load("best_metric_model.pth"))
+
+    # if we have multiple GPUs, set data parallel to execute sliding window inference
+    if len(devices) > 1:
+        model = torch.nn.DataParallel(model, device_ids=devices)
+
+    model.eval()
+    with torch.no_grad():
+        metric_sum = 0.0
+        metric_count = 0
+        saver = NiftiSaver(output_dir="./output")
+        for val_data in val_loader:
+            val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
+            # define sliding window size and batch size for windows inference
+            roi_size = (96, 96, 96)
+            sw_batch_size = 4
+            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
+            value = dice_metric(y_pred=val_outputs, y=val_labels)
+            metric_count += len(value)
+            metric_sum += value.item() * len(value)
+            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
+            saver.save_batch(val_outputs, val_data["img_meta_dict"])
+        metric = metric_sum / metric_count
+        print("evaluation metric:", metric)


 if __name__ == "__main__":
-    main()
+    with tempfile.TemporaryDirectory() as tempdir:
+        main(tempdir)
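
One plausible reason this refactor addresses the CI issue named in the commit title: with the directory injected, a test harness can call main() directly under a directory it manages, instead of each example silently creating and discarding its own. A hypothetical test sketch (the import path and test names are assumptions, and the examples still require best_metric_model.pth and a CUDA device to actually run):

import tempfile
import unittest

# hypothetical import path for the first example above
from examples.segmentation_3d import unet_evaluation_array

class TestUnetEvaluationArray(unittest.TestCase):
    def test_main_accepts_injected_tempdir(self):
        # the test, not main(), now owns the directory lifecycle
        with tempfile.TemporaryDirectory() as tempdir:
            unet_evaluation_array.main(tempdir)

if __name__ == "__main__":
    unittest.main()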
