@@ -11,7 +11,6 @@

import logging
import os
-import shutil
import sys
import tempfile
from glob import glob
@@ -34,71 +33,70 @@ def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

-    tempdir = tempfile.mkdtemp()
-    print(f"generating synthetic data to {tempdir} (this may take a while)")
-    for i in range(5):
-        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
-
-        n = nib.Nifti1Image(im, np.eye(4))
-        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
-
-        n = nib.Nifti1Image(seg, np.eye(4))
-        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
-
-    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
-    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
-    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
-
-    # define transforms for image and segmentation
-    val_transforms = Compose(
-        [
-            LoadNiftid(keys=["img", "seg"]),
-            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
-            ScaleIntensityd(keys="img"),
-            ToTensord(keys=["img", "seg"]),
-        ]
-    )
-    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
-    # sliding window inference need to input 1 image in every iteration
-    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
-    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
-
-    # try to use all the available GPUs
-    devices = get_devices_spec(None)
-    model = UNet(
-        dimensions=3,
-        in_channels=1,
-        out_channels=1,
-        channels=(16, 32, 64, 128, 256),
-        strides=(2, 2, 2, 2),
-        num_res_units=2,
-    ).to(devices[0])
-
-    model.load_state_dict(torch.load("best_metric_model.pth"))
-
-    # if we have multiple GPUs, set data parallel to execute sliding window inference
-    if len(devices) > 1:
-        model = torch.nn.DataParallel(model, device_ids=devices)
-
-    model.eval()
-    with torch.no_grad():
-        metric_sum = 0.0
-        metric_count = 0
-        saver = NiftiSaver(output_dir="./output")
-        for val_data in val_loader:
-            val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
-            # define sliding window size and batch size for windows inference
-            roi_size = (96, 96, 96)
-            sw_batch_size = 4
-            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
-            value = dice_metric(y_pred=val_outputs, y=val_labels)
-            metric_count += len(value)
-            metric_sum += value.item() * len(value)
-            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
-            saver.save_batch(val_outputs, val_data["img_meta_dict"])
-        metric = metric_sum / metric_count
-        print("evaluation metric:", metric)
-    shutil.rmtree(tempdir)
+    with tempfile.TemporaryDirectory() as tempdir:
+        print(f"generating synthetic data to {tempdir} (this may take a while)")
+        for i in range(5):
+            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
+
+            n = nib.Nifti1Image(im, np.eye(4))
+            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
+
+            n = nib.Nifti1Image(seg, np.eye(4))
+            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
+
+        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
+        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
+        val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
+
+        # define transforms for image and segmentation
+        val_transforms = Compose(
+            [
+                LoadNiftid(keys=["img", "seg"]),
+                AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
+                ScaleIntensityd(keys="img"),
+                ToTensord(keys=["img", "seg"]),
+            ]
+        )
+        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+        # sliding window inference need to input 1 image in every iteration
+        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
+        dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
+
+        # try to use all the available GPUs
+        devices = get_devices_spec(None)
+        model = UNet(
+            dimensions=3,
+            in_channels=1,
+            out_channels=1,
+            channels=(16, 32, 64, 128, 256),
+            strides=(2, 2, 2, 2),
+            num_res_units=2,
+        ).to(devices[0])
+
+        model.load_state_dict(torch.load("best_metric_model.pth"))
+
+        # if we have multiple GPUs, set data parallel to execute sliding window inference
+        if len(devices) > 1:
+            model = torch.nn.DataParallel(model, device_ids=devices)
+
+        model.eval()
+        with torch.no_grad():
+            metric_sum = 0.0
+            metric_count = 0
+            saver = NiftiSaver(output_dir="./output")
+            for val_data in val_loader:
+                val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0])
+                # define sliding window size and batch size for windows inference
+                roi_size = (96, 96, 96)
+                sw_batch_size = 4
+                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
+                value = dice_metric(y_pred=val_outputs, y=val_labels)
+                metric_count += len(value)
+                metric_sum += value.item() * len(value)
+                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
+                saver.save_batch(val_outputs, val_data["img_meta_dict"])
+            metric = metric_sum / metric_count
+            print("evaluation metric:", metric)


if __name__ == "__main__":
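Note: the change above replaces manual temporary-directory handling (tempfile.mkdtemp() at setup, shutil.rmtree(tempdir) at the end) with the tempfile.TemporaryDirectory context manager, so the synthetic NIfTI files are cleaned up even if the evaluation raises partway through, and the shutil import is no longer needed. A minimal sketch of the pattern in isolation (the file name below is illustrative, not part of the example):

import os
import tempfile

# TemporaryDirectory creates the directory on entry and deletes it, with
# everything inside it, on exit -- including exits caused by exceptions.
with tempfile.TemporaryDirectory() as tempdir:
    sample_path = os.path.join(tempdir, "sample.txt")  # illustrative file name
    with open(sample_path, "w") as f:
        f.write("scratch data")
    print(os.path.exists(sample_path))  # True while inside the block
print(os.path.exists(sample_path))  # False: directory and contents are gone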