@@ -45,6 +45,7 @@

 import nibabel as nib
 import numpy as np
+import torch
 import torch.distributed as dist

 from monai.data import create_test_image_3d, partition_dataset
@@ -57,7 +58,7 @@
     KeepLargestConnectedComponentd,
     LoadImaged,
     ScaleIntensityd,
-    EnsureTyped,
+    ToDeviced,
 )
 from monai.utils import string_list_all_gather

@@ -77,8 +78,8 @@ def compute(args):
             n = nib.Nifti1Image(label, np.eye(4))
             nib.save(n, os.path.join(args.dir, f"label{i:d}.nii.gz"))

-    # initialize the distributed evaluation process, change to NCCL backend if computing on GPU
-    dist.init_process_group(backend="gloo", init_method="env://")
+    # initialize the distributed evaluation process, change to gloo backend if computing on CPU
+    dist.init_process_group(backend="nccl", init_method="env://")

     preds = sorted(glob(os.path.join(args.dir, "pred*.nii.gz")))
     labels = sorted(glob(os.path.join(args.dir, "label*.nii.gz")))
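The hunk above switches the process group from gloo to NCCL, which assumes every rank can see a GPU: NCCL only communicates CUDA tensors, while gloo also works on CPU. Below is a minimal sketch (not part of the change) of choosing the backend from GPU availability; it assumes the launcher exports MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE so the "env://" method can read them.

import torch
import torch.distributed as dist

# use NCCL when CUDA devices are present, otherwise fall back to gloo (CPU)
backend = "nccl" if torch.cuda.is_available() else "gloo"
# "env://" reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment
dist.init_process_group(backend=backend, init_method="env://")
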
@@ -92,13 +93,15 @@ def compute(args):
         even_divisible=False,
     )[dist.get_rank()]

+    device = torch.device(f"cuda:{args.local_rank}")
+    torch.cuda.set_device(device)
     # define transforms for predictions and labels
     transforms = Compose(
         [
             LoadImaged(keys=["pred", "label"]),
+            ToDeviced(keys=["pred", "label"], device=device),
             EnsureChannelFirstd(keys=["pred", "label"]),
             ScaleIntensityd(keys="pred"),
-            EnsureTyped(keys=["pred", "label"]),
             AsDiscreted(keys="pred", threshold=0.5),
             KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
         ]
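In the hunk above, EnsureTyped is dropped and ToDeviced moves the loaded "pred" and "label" tensors onto the GPU selected from args.local_rank, so the remaining transforms and the metric computation run on that device. A minimal standalone sketch of ToDeviced on a dictionary sample (the device choice and tensor shapes here are illustrative only):

import torch
from monai.transforms import ToDeviced

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
to_device = ToDeviced(keys=["pred", "label"], device=device)

# dictionary-style sample, as the dictionary transforms above expect
sample = {"pred": torch.rand(1, 8, 8, 8), "label": torch.ones(1, 8, 8, 8)}
sample = to_device(sample)
print(sample["pred"].device, sample["label"].device)
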
@@ -129,6 +132,13 @@ def compute(args):
     dist.destroy_process_group()


+# usage example (refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py):
+
+# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
+#     --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
+#     --master_addr="192.168.1.1" --master_port=1234
+#     compute_metric.py -d DIR_OF_OUTPUT
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("-d", "--dir", default="./output", type=str, help="root directory of labels and predictions.")
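The device selection in the earlier hunk reads args.local_rank, which torch.distributed.launch passes to each spawned process. A sketch of how the argument parsing presumably completes (the default value and the final print are assumptions for illustration; the real script would go on to call compute(args), as the hunk headers indicate):

import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dir", default="./output", type=str, help="root directory of labels and predictions.")
    # torch.distributed.launch injects --local_rank into every process it spawns
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    # the real script would now call compute(args)
    print(args.dir, args.local_rank)

if __name__ == "__main__":
    main()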