Skip to content

Commit 57f8dbf

Browse files
Nic-Ma and monai-bot authored

Update temp directory function in all tests and examples (#899)

* [DLMED] update tempdir
* [DLMED] update tempdir in examples
* [DLMED] fix typo
* [MONAI] python code formatting

Co-authored-by: monai-bot <[email protected]>
1 parent 52e6278 commit 57f8dbf

40 files changed

+1782
-1904
lines changed

examples/segmentation_3d/unet_evaluation_array.py

Lines changed: 45 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import logging
1313
import os
14-
import shutil
1514
import sys
1615
import tempfile
1716
from glob import glob
@@ -33,58 +32,57 @@ def main():
3332
config.print_config()
3433
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
3534

36-
tempdir = tempfile.mkdtemp()
37-
print(f"generating synthetic data to {tempdir} (this may take a while)")
38-
for i in range(5):
39-
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
35+
with tempfile.TemporaryDirectory() as tempdir:
36+
print(f"generating synthetic data to {tempdir} (this may take a while)")
37+
for i in range(5):
38+
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
4039

41-
n = nib.Nifti1Image(im, np.eye(4))
42-
nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
40+
n = nib.Nifti1Image(im, np.eye(4))
41+
nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
4342

44-
n = nib.Nifti1Image(seg, np.eye(4))
45-
nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
43+
n = nib.Nifti1Image(seg, np.eye(4))
44+
nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
4645

47-
images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
48-
segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
46+
images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
47+
segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
4948

50-
# define transforms for image and segmentation
51-
imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
52-
segtrans = Compose([AddChannel(), ToTensor()])
53-
val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
54-
# sliding window inference for one image at every iteration
55-
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
56-
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
49+
# define transforms for image and segmentation
50+
imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
51+
segtrans = Compose([AddChannel(), ToTensor()])
52+
val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
53+
# sliding window inference for one image at every iteration
54+
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
55+
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
5756

58-
device = torch.device("cuda:0")
59-
model = UNet(
60-
dimensions=3,
61-
in_channels=1,
62-
out_channels=1,
63-
channels=(16, 32, 64, 128, 256),
64-
strides=(2, 2, 2, 2),
65-
num_res_units=2,
66-
).to(device)
57+
device = torch.device("cuda:0")
58+
model = UNet(
59+
dimensions=3,
60+
in_channels=1,
61+
out_channels=1,
62+
channels=(16, 32, 64, 128, 256),
63+
strides=(2, 2, 2, 2),
64+
num_res_units=2,
65+
).to(device)
6766

68-
model.load_state_dict(torch.load("best_metric_model.pth"))
69-
model.eval()
70-
with torch.no_grad():
71-
metric_sum = 0.0
72-
metric_count = 0
73-
saver = NiftiSaver(output_dir="./output")
74-
for val_data in val_loader:
75-
val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
76-
# define sliding window size and batch size for windows inference
77-
roi_size = (96, 96, 96)
78-
sw_batch_size = 4
79-
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
80-
value = dice_metric(y_pred=val_outputs, y=val_labels)
81-
metric_count += len(value)
82-
metric_sum += value.item() * len(value)
83-
val_outputs = (val_outputs.sigmoid() >= 0.5).float()
84-
saver.save_batch(val_outputs, val_data[2])
85-
metric = metric_sum / metric_count
86-
print("evaluation metric:", metric)
87-
shutil.rmtree(tempdir)
67+
model.load_state_dict(torch.load("best_metric_model.pth"))
68+
model.eval()
69+
with torch.no_grad():
70+
metric_sum = 0.0
71+
metric_count = 0
72+
saver = NiftiSaver(output_dir="./output")
73+
for val_data in val_loader:
74+
val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
75+
# define sliding window size and batch size for windows inference
76+
roi_size = (96, 96, 96)
77+
sw_batch_size = 4
78+
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
79+
value = dice_metric(y_pred=val_outputs, y=val_labels)
80+
metric_count += len(value)
81+
metric_sum += value.item() * len(value)
82+
val_outputs = (val_outputs.sigmoid() >= 0.5).float()
83+
saver.save_batch(val_outputs, val_data[2])
84+
metric = metric_sum / metric_count
85+
print("evaluation metric:", metric)
8886

8987

9088
if __name__ == "__main__":

examples/segmentation_3d/unet_evaluation_dict.py

Lines changed: 64 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import logging
1313
import os
14-
import shutil
1514
import sys
1615
import tempfile
1716
from glob import glob
@@ -34,71 +33,70 @@ def main():
3433
monai.config.print_config()
3534
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
3635

37-
tempdir = tempfile.mkdtemp()
38-
print(f"generating synthetic data to {tempdir} (this may take a while)")
39-
for i in range(5):
40-
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
41-
42-
n = nib.Nifti1Image(im, np.eye(4))
43-
nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
44-
45-
n = nib.Nifti1Image(seg, np.eye(4))
46-
nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
47-
48-
images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
49-
segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
50-
val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
51-
52-
# define transforms for image and segmentation
53-
val_transforms = Compose(
54-
[
55-
LoadNiftid(keys=["img", "seg"]),
56-
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
57-
ScaleIntensityd(keys="img"),
58-
ToTensord(keys=["img", "seg"]),
59-
]
60-
)
61-
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
62-
# sliding window inference need to input 1 image in every iteration
63-
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
64-
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
65-
66-
# try to use all the available GPUs
67-
devices = get_devices_spec(None)
68-
model = UNet(
69-
dimensions=3,
70-
in_channels=1,
71-
out_channels=1,
72-
channels=(16, 32, 64, 128, 256),
73-
strides=(2, 2, 2, 2),
74-
num_res_units=2,
75-
).to(devices[0])
76-
77-
model.load_state_dict(torch.load("best_metric_model.pth"))
78-
79-
# if we have multiple GPUs, set data parallel to execute sliding window inference
80-
if len(devices) > 1:
81-
model = torch.nn.DataParallel(model, device_ids=devices)
82-
83-
model.eval()
84-
with torch.no_grad():
85-
metric_sum = 0.0
86-
metric_count = 0
87-
saver = NiftiSaver(output_dir="./output")
88-
for val_data in val_loader:
89-
val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
90-
# define sliding window size and batch size for windows inference
91-
roi_size = (96, 96, 96)
92-
sw_batch_size = 4
93-
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
94-
value = dice_metric(y_pred=val_outputs, y=val_labels)
95-
metric_count += len(value)
96-
metric_sum += value.item() * len(value)
97-
val_outputs = (val_outputs.sigmoid() >= 0.5).float()
98-
saver.save_batch(val_outputs, val_data["img_meta_dict"])
99-
metric = metric_sum / metric_count
100-
print("evaluation metric:", metric)
101-
shutil.rmtree(tempdir)
36+
with tempfile.TemporaryDirectory() as tempdir:
37+
print(f"generating synthetic data to {tempdir} (this may take a while)")
38+
for i in range(5):
39+
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
40+
41+
n = nib.Nifti1Image(im, np.eye(4))
42+
nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
43+
44+
n = nib.Nifti1Image(seg, np.eye(4))
45+
nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
46+
47+
images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
48+
segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
49+
val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
50+
51+
# define transforms for image and segmentation
52+
val_transforms = Compose(
53+
[
54+
LoadNiftid(keys=["img", "seg"]),
55+
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
56+
ScaleIntensityd(keys="img"),
57+
ToTensord(keys=["img", "seg"]),
58+
]
59+
)
60+
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
61+
# sliding window inference need to input 1 image in every iteration
62+
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
63+
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
64+
65+
# try to use all the available GPUs
66+
devices = get_devices_spec(None)
67+
model = UNet(
68+
dimensions=3,
69+
in_channels=1,
70+
out_channels=1,
71+
channels=(16, 32, 64, 128, 256),
72+
strides=(2, 2, 2, 2),
73+
num_res_units=2,
74+
).to(devices[0])
75+
76+
model.load_state_dict(torch.load("best_metric_model.pth"))
77+
78+
# if we have multiple GPUs, set data parallel to execute sliding window inference
79+
if len(devices) > 1:
80+
model = torch.nn.DataParallel(model, device_ids=devices)
81+
82+
model.eval()
83+
with torch.no_grad():
84+
metric_sum = 0.0
85+
metric_count = 0
86+
saver = NiftiSaver(output_dir="./output")
87+
for val_data in val_loader:
88+
val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
89+
# define sliding window size and batch size for windows inference
90+
roi_size = (96, 96, 96)
91+
sw_batch_size = 4
92+
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
93+
value = dice_metric(y_pred=val_outputs, y=val_labels)
94+
metric_count += len(value)
95+
metric_sum += value.item() * len(value)
96+
val_outputs = (val_outputs.sigmoid() >= 0.5).float()
97+
saver.save_batch(val_outputs, val_data["img_meta_dict"])
98+
metric = metric_sum / metric_count
99+
print("evaluation metric:", metric)
102100

103101

104102
if __name__ == "__main__":

Comments (0)