@@ -165,17 +165,16 @@ def _generate_data_list(self, dataset_dir):
 
 
 def main_worker(args):
-    local_rank = int(os.environ["LOCAL_RANK"])
     # disable logging for processes except 0 on every node
-    if local_rank != 0:
+    if args.local_rank != 0:
         f = open(os.devnull, "w")
         sys.stdout = sys.stderr = f
     if not os.path.exists(args.dir):
         raise FileNotFoundError(f"missing directory {args.dir}")
 
     # initialize the distributed training process, every GPU runs in a process
     dist.init_process_group(backend="nccl", init_method="env://")
-    device = torch.device(f"cuda:{local_rank}")
+    device = torch.device(f"cuda:{args.local_rank}")
     torch.cuda.set_device(device)
     # use amp to accelerate training
     scaler = torch.cuda.amp.GradScaler()
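Side note on the GradScaler created at the end of this hunk: it implies the training loop elsewhere in the script runs under automatic mixed precision. As a reminder of the standard pattern only (a generic sketch, not code quoted from this file; model, optimizer, loss_function and the "image"/"label" batch keys are placeholders):

import torch

def amp_train_step(model, batch_data, loss_function, optimizer, scaler, device):
    # forward pass under autocast so eligible ops run in float16
    with torch.cuda.amp.autocast():
        outputs = model(batch_data["image"].to(device))
        loss = loss_function(outputs, batch_data["label"].to(device))
    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)         # unscales gradients; skips the step if they contain inf/nan
    scaler.update()                # adjust the scale factor for the next iteration
    return loss.item()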
@@ -369,6 +368,8 @@ def evaluate(model, val_loader, dice_metric, dice_metric_batch, post_trans):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("-d", "--dir", default="./testdata", type=str, help="directory of Brain Tumor dataset")
+    # must parse the command-line argument: ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by DDP
+    parser.add_argument("--local_rank", type=int, help="node rank for distributed training")
     parser.add_argument("--epochs", default=300, type=int, metavar="N", help="number of total epochs to run")
     parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
     parser.add_argument("-b", "--batch_size", default=1, type=int, help="mini-batch size of every GPU")
@@ -391,9 +392,12 @@ def main():
     main_worker(args=args)
 
 
-# usage example(refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py):
+# usage example(refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py):
 
-# torchrun --standalone --nnodes=1 --nproc_per_node=NUM_GPUS_PER_NODE brats_training_ddp.py -d DIR_OF_TESTDATA
+# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
+#        --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
+#        --master_addr="192.168.1.1" --master_port=1234
+#        brats_training_ddp.py -d DIR_OF_TESTDATA
 
 if __name__ == "__main__":
     main()
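Note on the changed usage comment: the removed torchrun example relied on the LOCAL_RANK environment variable that the first hunk used to read, whereas torch.distributed.launch injects --local_rank into each process's command line, which is why the argparse option is added above. On a single node the multi-node flags can stay at their defaults, so a 2-GPU run would reduce to something like: python -m torch.distributed.launch --nproc_per_node=2 brats_training_ddp.py -d ./testdata (an illustrative command, not taken from the patch).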