
Commit 4bc3870

[DLMED] restore dist for compatible
Signed-off-by: Nic Ma <[email protected]>
1 parent 9794ebd commit 4bc3870

File tree

1 file changed: 9 additions, 5 deletions


acceleration/distributed_training/brats_training_ddp.py

Lines changed: 9 additions & 5 deletions
@@ -165,17 +165,16 @@ def _generate_data_list(self, dataset_dir):


 def main_worker(args):
-    local_rank = int(os.environ["LOCAL_RANK"])
     # disable logging for processes except 0 on every node
-    if local_rank != 0:
+    if args.local_rank != 0:
         f = open(os.devnull, "w")
         sys.stdout = sys.stderr = f
     if not os.path.exists(args.dir):
         raise FileNotFoundError(f"missing directory {args.dir}")

     # initialize the distributed training process, every GPU runs in a process
     dist.init_process_group(backend="nccl", init_method="env://")
-    device = torch.device(f"cuda:{local_rank}")
+    device = torch.device(f"cuda:{args.local_rank}")
     torch.cuda.set_device(device)
     # use amp to accelerate training
     scaler = torch.cuda.amp.GradScaler()
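For context (not part of this commit): torch.distributed.launch passes the local rank to each process as a --local_rank command-line argument, while the newer torchrun exports it as the LOCAL_RANK environment variable. A minimal sketch of a worker that tolerates either launcher; resolve_local_rank is a hypothetical helper, not code from this repository:

import os


def resolve_local_rank(args):
    # Hypothetical helper: prefer the --local_rank argument that
    # torch.distributed.launch puts on the command line, and fall back
    # to the LOCAL_RANK env var that torchrun exports.
    if getattr(args, "local_rank", None) is not None:
        return args.local_rank
    return int(os.environ.get("LOCAL_RANK", 0))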
@@ -369,6 +368,8 @@ def evaluate(model, val_loader, dice_metric, dice_metric_batch, post_trans):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("-d", "--dir", default="./testdata", type=str, help="directory of Brain Tumor dataset")
+    # must parse the command-line argument: ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by DDP
+    parser.add_argument("--local_rank", type=int, help="node rank for distributed training")
     parser.add_argument("--epochs", default=300, type=int, metavar="N", help="number of total epochs to run")
     parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
     parser.add_argument("-b", "--batch_size", default=1, type=int, help="mini-batch size of every GPU")
@@ -391,9 +392,12 @@ def main():
     main_worker(args=args)


-# usage example(refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py):
+# usage example(refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py):

-# torchrun --standalone --nnodes=1 --nproc_per_node=NUM_GPUS_PER_NODE brats_training_ddp.py -d DIR_OF_TESTDATA
+# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
+#        --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
+#        --master_addr="192.168.1.1" --master_port=1234
+#        brats_training_ddp.py -d DIR_OF_TESTDATA

 if __name__ == "__main__":
     main()
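Both launch commands rely on init_method="env://": the launcher populates MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE in each process's environment, and the process group reads them at initialization. A minimal sketch of that setup, mirroring the code in the diff; the function name is illustrative only:

import torch
import torch.distributed as dist


def init_distributed(local_rank):
    # Illustrative sketch: the launcher supplies MASTER_ADDR, MASTER_PORT,
    # RANK, and WORLD_SIZE through the environment, which is what
    # init_method="env://" reads; each process then binds to one GPU.
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    return device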
