
Commit d62043b

yiheng-wang-nv and wyli authored
1077 Adjust handler to rank 0 for dynunet pipeline (#1078)
Fixes #1077.

### Description
This PR adjusts several handlers to run on rank 0 only in the dynunet pipeline, so that distributed runs do not duplicate logging and checkpointing across processes.

### Checks
- [ ] Notebook runs automatically `./runner [-p <regex_pattern>]`

Co-authored-by: Wenqi Li <[email protected]>
1 parent e62f643 commit d62043b
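
For context, the core pattern of this change is to build logging and checkpointing handlers only on rank 0, so a multi-process DDP run produces a single log stream and a single checkpoint. A minimal sketch of that pattern, using the same `ignite.distributed` helper the PR imports (the network and `save_dir` below are placeholders, not taken from the pipeline):

```python
import ignite.distributed as idist
import torch
from monai.handlers import CheckpointSaver, StatsHandler

net = torch.nn.Linear(4, 2)  # placeholder network, not from the pipeline

# idist.get_rank() returns 0 in a plain single-process run, so the same
# code works unchanged outside of distributed training.
val_handlers = (
    [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir="./runs", save_dict={"net": net}, save_key_metric=True),
    ]
    if idist.get_rank() == 0
    else None  # non-zero ranks attach no reporting handlers
)
```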

File tree

1 file changed: +26 / -42 lines changed

modules/dynunet_pipeline/train.py

Lines changed: 26 additions & 42 deletions
@@ -2,17 +2,12 @@
 import os
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
 
+import ignite.distributed as idist
 import torch
 import torch.distributed as dist
 from monai.config import print_config
-from monai.handlers import (
-    CheckpointSaver,
-    LrScheduleHandler,
-    MeanDice,
-    StatsHandler,
-    ValidationHandler,
-    from_engine,
-)
+from monai.handlers import (CheckpointSaver, LrScheduleHandler, MeanDice,
+                            StatsHandler, ValidationHandler, from_engine)
 from monai.inferers import SimpleInferer, SlidingWindowInferer
 from monai.losses import DiceCELoss
 from monai.utils import set_determinism
@@ -91,6 +86,8 @@ def validation(args):
                 "mean dice for label {} is {}".format(i + 1, results[:, i].mean())
             )
 
+    dist.destroy_process_group()
+
 
 def train(args):
     # load hyper parameters
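
The `dist.destroy_process_group()` added above (and again at the end of `train`, see the last hunk) pairs a teardown with the process-group initialization done elsewhere in the pipeline. As a general, hedged sketch of that pairing (the backend and `try/finally` structure are illustrative, not copied from `train.py`):

```python
import torch.distributed as dist

def main():
    # one init per worker process; "gloo" is illustrative (the pipeline targets GPUs)
    dist.init_process_group(backend="gloo", init_method="env://")
    try:
        ...  # run validation(args) or train(args) here
    finally:
        # matching teardown so every worker exits cleanly
        dist.destroy_process_group()
```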
@@ -151,12 +148,16 @@ def train(args):
         optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs) ** 0.9
     )
     # produce evaluator
-    val_handlers = [
-        StatsHandler(output_transform=lambda x: None),
-        CheckpointSaver(
-            save_dir=val_output_dir, save_dict={"net": net}, save_key_metric=True
-        ),
-    ]
+    val_handlers = (
+        [
+            StatsHandler(output_transform=lambda x: None),
+            CheckpointSaver(
+                save_dir=val_output_dir, save_dict={"net": net}, save_key_metric=True
+            ),
+        ]
+        if idist.get_rank() == 0
+        else None
+    )
 
     evaluator = DynUNetEvaluator(
         device=device,
@@ -183,16 +184,18 @@ def train(args):
 
     # produce trainer
     loss = DiceCELoss(to_onehot_y=True, softmax=True, batch=batch_dice)
-    train_handlers = []
+    train_handlers = [
+        ValidationHandler(validator=evaluator, interval=interval, epoch_level=True)
+    ]
     if lr_decay_flag:
         train_handlers += [LrScheduleHandler(lr_scheduler=scheduler, print_lr=True)]
-
-    train_handlers += [
-        ValidationHandler(validator=evaluator, interval=interval, epoch_level=True),
-        StatsHandler(
-            tag_name="train_loss", output_transform=from_engine(["loss"], first=True)
-        ),
-    ]
+    if idist.get_rank() == 0:
+        train_handlers += [
+            StatsHandler(
+                tag_name="train_loss",
+                output_transform=from_engine(["loss"], first=True),
+            )
+        ]
 
     trainer = DynUNetTrainer(
         device=device,
@@ -212,27 +215,8 @@ def train(args):
         evaluator.logger.setLevel(logging.WARNING)
         trainer.logger.setLevel(logging.WARNING)
 
-    logger = logging.getLogger()
-
-    formatter = logging.Formatter(
-        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-    )
-
-    # Setup file handler
-    fhandler = logging.FileHandler(log_filename)
-    fhandler.setLevel(logging.INFO)
-    fhandler.setFormatter(formatter)
-
-    logger.addHandler(fhandler)
-
-    chandler = logging.StreamHandler()
-    chandler.setLevel(logging.INFO)
-    chandler.setFormatter(formatter)
-    logger.addHandler(chandler)
-
-    logger.setLevel(logging.INFO)
-
     trainer.run()
+    dist.destroy_process_group()
 
 
 if __name__ == "__main__":
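
To watch the rank gating in action outside the pipeline, a small standalone script can be launched with `torchrun --nproc_per_node=2 check_rank.py` (everything here is illustrative; `check_rank.py` is a hypothetical file, not part of the repository):

```python
# check_rank.py -- hypothetical demo, not part of the repository
import os

import ignite.distributed as idist


def main():
    if "RANK" in os.environ:  # set by torchrun / torch.distributed.launch
        idist.initialize(backend="gloo")
    rank = idist.get_rank()  # 0 when not running distributed
    handlers = ["StatsHandler", "CheckpointSaver"] if rank == 0 else None
    print(f"rank {rank}: handlers={handlers}")
    if "RANK" in os.environ:
        idist.finalize()


if __name__ == "__main__":
    main()
```

Only rank 0 reports a handler list; every other rank prints `None`, mirroring the gating added in the diff.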
