@@ -17,7 +17,7 @@
 from datetime import timedelta
 import warnings

-warnings.simplefilter('ignore', UserWarning)
+warnings.simplefilter("ignore", UserWarning)

 import os
 import sys
@@ -41,53 +41,56 @@
 from monai.bundle import ConfigParser
 from utils import binarize_labels

+
 def setup_ddp(rank, world_size):
     print(f"Running DDP diffusion example on rank {rank}/world_size {world_size}.")
     print(f"Initing to IP {os.environ['MASTER_ADDR']}")
     dist.init_process_group(
         backend="nccl", init_method="env://", timeout=timedelta(seconds=36000), rank=rank, world_size=world_size
-    )
+    )
     dist.barrier()
     device = torch.device(f"cuda:{rank}")
     return dist, device

+
 def define_instance(args, instance_def_key):
     parser = ConfigParser(vars(args))
     parser.parse(True)
     return parser.get_parsed_content(instance_def_key, instantiate=True)

+
 def add_data_dir2path(list_files, data_dir, fold=None):
     new_list_files = copy.deepcopy(list_files)
     if fold is not None:
         new_list_files_train = []
         new_list_files_val = []
     for d in new_list_files:
         d["image"] = os.path.join(data_dir, d["image"])
-
+
         if "label" in d:
             d["label"] = os.path.join(data_dir, d["label"])
-
+
         if fold is not None:
             if d["fold"] == fold:
                 new_list_files_val.append(copy.deepcopy(d))
             else:
                 new_list_files_train.append(copy.deepcopy(d))
-
+
     if fold is not None:
         return new_list_files_train, new_list_files_val
     else:
         return new_list_files, []


 def prepare_maisi_controlnet_json_dataloader(
-    args,
-    json_data_list,
-    data_base_dir,
-    batch_size=1,
-    fold=0,
-    cache_rate=0.0,
-    rank=0,
-    world_size=1,
+    args,
+    json_data_list,
+    data_base_dir,
+    batch_size=1,
+    fold=0,
+    cache_rate=0.0,
+    rank=0,
+    world_size=1,
 ):
     ddp_bool = world_size > 1
     if isinstance(json_data_list, list):
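The launch path for setup_ddp() is not part of this diff; each worker is assumed to already know its rank, the world size, and the rendezvous address. A minimal driver sketch, assuming a single-node run via torch.multiprocessing with placeholder MASTER_ADDR/MASTER_PORT values:

import os
import torch.multiprocessing as mp

def _worker(rank, world_size):
    # hypothetical rendezvous settings; real runs get these from the launcher
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist, device = setup_ddp(rank, world_size)
    # ... training work happens on `device` ...
    dist.destroy_process_group()

if __name__ == "__main__":
    n_gpus = 2  # assumed GPU count for this sketch
    mp.spawn(_worker, args=(n_gpus,), nprocs=n_gpus)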
@@ -103,37 +106,32 @@ def prepare_maisi_controlnet_json_dataloader(
     else:
         with open(json_data_list, "r") as f:
             json_data = json.load(f)
-        list_train, list_valid = add_data_dir2path(json_data['training'], data_base_dir, fold)
+        list_train, list_valid = add_data_dir2path(json_data["training"], data_base_dir, fold)

     common_transform = [
         LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True),
         Orientationd(keys=["label"], axcodes="RAS"),
         EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True),
-        Lambdad(keys='top_region_index', func=lambda x: torch.FloatTensor(x)),
-        Lambdad(keys='bottom_region_index', func=lambda x: torch.FloatTensor(x)),
-        Lambdad(keys='spacing', func=lambda x: torch.FloatTensor(x)),
-        Lambdad(keys='top_region_index', func=lambda x: x * 1e2),
-        Lambdad(keys='bottom_region_index', func=lambda x: x * 1e2),
-        Lambdad(keys='spacing', func=lambda x: x * 1e2),
+        Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)),
+        Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)),
+        Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
+        Lambdad(keys="top_region_index", func=lambda x: x * 1e2),
+        Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2),
+        Lambdad(keys="spacing", func=lambda x: x * 1e2),
     ]
     train_transforms, val_transforms = Compose(common_transform), Compose(common_transform)

     train_loader = None
-
+
     if ddp_bool:
         list_train = partition_dataset(
             data=list_train,
             shuffle=True,
             num_partitions=world_size,
             even_divisible=True,
         )[rank]
-    train_ds = CacheDataset(
-        data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8
-    )
-    train_loader = DataLoader(
-        train_ds, batch_size=batch_size, shuffle=True,
-        num_workers=8, pin_memory=True
-    )
+    train_ds = CacheDataset(data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8)
+    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
     if ddp_bool:
         list_valid = partition_dataset(
             data=list_valid,
@@ -142,12 +140,14 @@ def prepare_maisi_controlnet_json_dataloader(
             even_divisible=False,
         )[rank]
     val_ds = CacheDataset(
-        data=list_valid, transform=val_transforms, cache_rate=cache_rate, num_workers=8,
-    )
-    val_loader = DataLoader(
-        val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False
+        data=list_valid,
+        transform=val_transforms,
+        cache_rate=cache_rate,
+        num_workers=8,
     )
-    return train_loader, val_loader
+    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False)
+    return train_loader, val_loader
+

 def main():
     parser = argparse.ArgumentParser(description="PyTorch VAE-GAN training")
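For reference, prepare_maisi_controlnet_json_dataloader() reads a datalist shaped like the hypothetical one below; only the key names ("training", "image", "label", "fold", "top_region_index", "bottom_region_index", "spacing") come from the code above, and every value is invented for illustration:

example_json_data = {
    "training": [
        {
            "image": "embeddings/case_0001.nii.gz",   # relative path, joined onto data_base_dir
            "label": "masks/case_0001.nii.gz",        # segmentation mask
            "fold": 0,                                # entries matching `fold` go to validation
            "top_region_index": [0, 1, 0, 0],
            "bottom_region_index": [0, 0, 1, 0],
            "spacing": [1.0, 1.0, 1.2],
        },
    ]
}

add_data_dir2path() joins data_base_dir onto the relative "image"/"label" paths and routes entries whose "fold" field matches the requested fold into the validation split; the three numeric condition fields feed the Lambdad transforms.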
@@ -164,7 +164,14 @@ def main():
         help="config json file that stores hyper-parameters",
     )
     parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
-    parser.add_argument("-w", "--weighted_loss_label", nargs='+', default=[], action="store_true", help="list of lables that use weighted loss")
+    parser.add_argument(
+        "-w",
+        "--weighted_loss_label",
+        nargs="+",
+        default=[],
+        action="store_true",
+        help="list of lables that use weighted loss",
+    )
     parser.add_argument("-l", "--weighted_loss", default=100, type=int, help="loss weight loss for ROI labels")
     args = parser.parse_args()

@@ -189,7 +196,7 @@ def main():
         setattr(args, k, v)
     for k, v in config_dict.items():
         setattr(args, k, v)
-
+
     # initialize tensorboard writer
     if rank == 0:
         Path(args.tfevent_path).mkdir(parents=True, exist_ok=True)
@@ -199,13 +206,13 @@ def main():
     # Step 1: set data loader
     train_loader, val_loader = prepare_maisi_controlnet_json_dataloader(
         args,
-        json_data_list=args.json_data_list,
-        data_base_dir=args.data_base_dir,
+        json_data_list=args.json_data_list,
+        data_base_dir=args.data_base_dir,
         rank=rank,
         world_size=world_size,
         batch_size=args.controlnet_train["batch_size"],
         cache_rate=args.controlnet_train["cache_rate"],
-        fold=args.controlnet_train["fold"]
+        fold=args.controlnet_train["fold"],
     )

     # Step 2: define diffusion model and controlnet
@@ -235,16 +242,16 @@ def main():
     noise_scheduler = define_instance(args, "noise_scheduler")

     if ddp_bool:
-        controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)
+        controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)

     # Step 3: training config
     optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=args.controlnet_train["lr"])
     total_steps = (args.controlnet_train["n_epochs"] * len(train_loader.dataset)) / args.controlnet_train["batch_size"]
-    if rank == 0:
+    if rank == 0:
         print(f"total number of training steps: {total_steps}.")

     lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)
-
+
     # Step 4: training
     n_epochs = args.controlnet_train["n_epochs"]
     scaler = GradScaler()
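The backward pass sits outside the hunks shown here; paired with the autocast block below, the GradScaler created above would be used in the stock PyTorch AMP pattern (a sketch of the standard recipe, not a line from this commit; `loss`, `optimizer`, and `lr_scheduler` are the objects defined in this script):

scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)         # unscale gradients, then run the AdamW step
scaler.update()                # retune the scale factor for the next iteration
lr_scheduler.step()            # advance the polynomial decay schedule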
@@ -259,52 +266,54 @@ def main():
         epoch_loss_ = 0

         for step, batch in enumerate(train_loader):
-            # get image embedding and label mask
+            # get image embedding and label mask
             inputs = batch["image"].to(device)
             labels = batch["label"].to(device)
             # get coresponding condtions
-            top_region_index_tensor = batch['top_region_index'].to(device)
-            bottom_region_index_tensor = batch['bottom_region_index'].to(device)
-            spacing_tensor = batch['spacing'].to(device)
-
+            top_region_index_tensor = batch["top_region_index"].to(device)
+            bottom_region_index_tensor = batch["bottom_region_index"].to(device)
+            spacing_tensor = batch["spacing"].to(device)
+
             optimizer.zero_grad(set_to_none=True)

             with autocast(enabled=True):
                 # generate random noise
                 noise_shape = list(inputs.shape)
                 noise = torch.randn(noise_shape, dtype=inputs.dtype).to(inputs.device)
-
+
                 # use binary encoding to encode segmentation mask
                 controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()
-
+
                 # create timesteps
                 timesteps = torch.randint(
-                    0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=inputs.device
-                ).long()
-
+                    0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=inputs.device
+                ).long()
+
                 # create noisy latent
                 noisy_latent = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-
+
                 # get controlnet output
                 down_block_res_samples, mid_block_res_sample = controlnet(
                     x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond
                 )
                 # get noise prediction from diffusion unet
-                noise_pred = unet(x=noisy_latent,
-                                  timesteps=timesteps,
-                                  top_region_index_tensor=top_region_index_tensor,
-                                  bottom_region_index_tensor=bottom_region_index_tensor,
-                                  spacing_tensor=spacing_tensor,
-                                  down_block_additional_residuals=down_block_res_samples,
-                                  mid_block_additional_residual=mid_block_res_sample)
+                noise_pred = unet(
+                    x=noisy_latent,
+                    timesteps=timesteps,
+                    top_region_index_tensor=top_region_index_tensor,
+                    bottom_region_index_tensor=bottom_region_index_tensor,
+                    spacing_tensor=spacing_tensor,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                )

                 if args.weighted_loss > 1.0:
                     weights = torch.ones_like(inputs).to(inputs.device)
                     roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(inputs.device)
                     interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest")
                     # assign larger weights for ROI (tumor)
                     for label in args.weighted_loss_label:
-                        roi[interpolate_label == label] = 1
+                        roi[interpolate_label == label] = 1
                     weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = args.weighted_loss
                     loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean()
                 else:
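The weighted branch above upweights voxels whose nearest-neighbour-resampled label matches one of --weighted_loss_label. A self-contained sketch of the same masking arithmetic with toy tensors (all shapes and label values are invented for illustration):

import torch
import torch.nn.functional as F

weighted_loss, weighted_loss_label = 100.0, [2]
noise_pred = torch.randn(1, 4, 8, 8, 8)                    # (B, C, D, H, W) latent-sized prediction
noise = torch.randn_like(noise_pred)                       # target noise
labels = torch.randint(0, 3, (1, 1, 16, 16, 16)).float()   # full-resolution mask

weights = torch.ones_like(noise_pred)
roi = torch.zeros(1, 1, 8, 8, 8)
interp = F.interpolate(labels, size=noise_pred.shape[2:], mode="nearest")
for label in weighted_loss_label:                          # mark ROI voxels, e.g. tumor classes
    roi[interp == label] = 1
weights[roi.repeat(1, noise_pred.shape[1], 1, 1, 1) == 1] = weighted_loss
loss = (F.l1_loss(noise_pred, noise, reduction="none") * weights).mean()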
@@ -330,8 +339,7 @@ def main():
                         n_epochs,
                         step,
                         len(train_loader),
-                        lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else optimizer.param_groups[0][
-                            'lr'],
+                        lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else optimizer.param_groups[0]["lr"],
                         loss.detach().cpu().item(),
                         time_left,
                     )
@@ -348,7 +356,6 @@ def main():
             tensorboard_writer.add_scalar("train/train_diffusion_loss_epoch", epoch_loss.cpu().item(), total_step)

         torch.cuda.empty_cache()
-


 if __name__ == "__main__":