2
2
import os
3
3
from argparse import ArgumentDefaultsHelpFormatter , ArgumentParser
4
4
5
+ import ignite .distributed as idist
5
6
import torch
6
7
import torch .distributed as dist
7
8
from monai .config import print_config
8
- from monai .handlers import (
9
- CheckpointSaver ,
10
- LrScheduleHandler ,
11
- MeanDice ,
12
- StatsHandler ,
13
- ValidationHandler ,
14
- from_engine ,
15
- )
9
+ from monai .handlers import (CheckpointSaver , LrScheduleHandler , MeanDice ,
10
+ StatsHandler , ValidationHandler , from_engine )
16
11
from monai .inferers import SimpleInferer , SlidingWindowInferer
17
12
from monai .losses import DiceCELoss
18
13
from monai .utils import set_determinism
@@ -91,6 +86,8 @@ def validation(args):
91
86
"mean dice for label {} is {}" .format (i + 1 , results [:, i ].mean ())
92
87
)
93
88
89
+ dist .destroy_process_group ()
90
+
94
91
95
92
def train (args ):
96
93
# load hyper parameters
@@ -151,12 +148,16 @@ def train(args):
151
148
optimizer , lr_lambda = lambda epoch : (1 - epoch / max_epochs ) ** 0.9
152
149
)
153
150
# produce evaluator
154
- val_handlers = [
155
- StatsHandler (output_transform = lambda x : None ),
156
- CheckpointSaver (
157
- save_dir = val_output_dir , save_dict = {"net" : net }, save_key_metric = True
158
- ),
159
- ]
151
+ val_handlers = (
152
+ [
153
+ StatsHandler (output_transform = lambda x : None ),
154
+ CheckpointSaver (
155
+ save_dir = val_output_dir , save_dict = {"net" : net }, save_key_metric = True
156
+ ),
157
+ ]
158
+ if idist .get_rank () == 0
159
+ else None
160
+ )
160
161
161
162
evaluator = DynUNetEvaluator (
162
163
device = device ,
@@ -183,16 +184,18 @@ def train(args):
183
184
184
185
# produce trainer
185
186
loss = DiceCELoss (to_onehot_y = True , softmax = True , batch = batch_dice )
186
- train_handlers = []
187
+ train_handlers = [
188
+ ValidationHandler (validator = evaluator , interval = interval , epoch_level = True )
189
+ ]
187
190
if lr_decay_flag :
188
191
train_handlers += [LrScheduleHandler (lr_scheduler = scheduler , print_lr = True )]
189
-
190
- train_handlers += [
191
- ValidationHandler ( validator = evaluator , interval = interval , epoch_level = True ),
192
- StatsHandler (
193
- tag_name = "train_loss" , output_transform = from_engine (["loss" ], first = True )
194
- ),
195
- ]
192
+ if idist . get_rank () == 0 :
193
+ train_handlers += [
194
+ StatsHandler (
195
+ tag_name = "train_loss" ,
196
+ output_transform = from_engine (["loss" ], first = True ),
197
+ )
198
+ ]
196
199
197
200
trainer = DynUNetTrainer (
198
201
device = device ,
@@ -212,27 +215,8 @@ def train(args):
212
215
evaluator .logger .setLevel (logging .WARNING )
213
216
trainer .logger .setLevel (logging .WARNING )
214
217
215
- logger = logging .getLogger ()
216
-
217
- formatter = logging .Formatter (
218
- "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
219
- )
220
-
221
- # Setup file handler
222
- fhandler = logging .FileHandler (log_filename )
223
- fhandler .setLevel (logging .INFO )
224
- fhandler .setFormatter (formatter )
225
-
226
- logger .addHandler (fhandler )
227
-
228
- chandler = logging .StreamHandler ()
229
- chandler .setLevel (logging .INFO )
230
- chandler .setFormatter (formatter )
231
- logger .addHandler (chandler )
232
-
233
- logger .setLevel (logging .INFO )
234
-
235
218
trainer .run ()
219
+ dist .destroy_process_group ()
236
220
237
221
238
222
if __name__ == "__main__" :
0 commit comments