Skip to content

Commit 55d4639

Browse files
authored
[DLMED] update compute metric for MetaTensor (#780)
Signed-off-by: Nic Ma <[email protected]>
1 parent e2afddf commit 55d4639

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

modules/compute_metric.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545

4646
import nibabel as nib
4747
import numpy as np
48+
import torch
4849
import torch.distributed as dist
4950

5051
from monai.data import create_test_image_3d, partition_dataset
@@ -57,7 +58,7 @@
5758
KeepLargestConnectedComponentd,
5859
LoadImaged,
5960
ScaleIntensityd,
60-
EnsureTyped,
61+
ToDeviced,
6162
)
6263
from monai.utils import string_list_all_gather
6364

@@ -77,8 +78,8 @@ def compute(args):
7778
n = nib.Nifti1Image(label, np.eye(4))
7879
nib.save(n, os.path.join(args.dir, f"label{i:d}.nii.gz"))
7980

80-
# initialize the distributed evaluation process, change to NCCL backend if computing on GPU
81-
dist.init_process_group(backend="gloo", init_method="env://")
81+
# initialize the distributed evaluation process; change to the gloo backend if computing on CPU
82+
dist.init_process_group(backend="nccl", init_method="env://")
8283

8384
preds = sorted(glob(os.path.join(args.dir, "pred*.nii.gz")))
8485
labels = sorted(glob(os.path.join(args.dir, "label*.nii.gz")))
@@ -92,13 +93,15 @@ def compute(args):
9293
even_divisible=False,
9394
)[dist.get_rank()]
9495

96+
device = torch.device(f"cuda:{args.local_rank}")
97+
torch.cuda.set_device(device)
9598
# define transforms for predictions and labels
9699
transforms = Compose(
97100
[
98101
LoadImaged(keys=["pred", "label"]),
102+
ToDeviced(keys=["pred", "label"], device=device),
99103
EnsureChannelFirstd(keys=["pred", "label"]),
100104
ScaleIntensityd(keys="pred"),
101-
EnsureTyped(keys=["pred", "label"]),
102105
AsDiscreted(keys="pred", threshold=0.5),
103106
KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
104107
]
@@ -129,6 +132,13 @@ def compute(args):
129132
dist.destroy_process_group()
130133

131134

135+
# usage example (refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py):
136+
137+
# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
138+
# --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
139+
# --master_addr="192.168.1.1" --master_port=1234
140+
# compute_metric.py -d DIR_OF_OUTPUT
141+
132142
def main():
133143
parser = argparse.ArgumentParser()
134144
parser.add_argument("-d", "--dir", default="./output", type=str, help="root directory of labels and predictions.")

0 commit comments

Comments
 (0)