Skip to content

Commit ca76573

Browse files
committed
add training script
Signed-off-by: Pengfei Guo <[email protected]>
1 parent 650651c commit ca76573

File tree

1 file changed

+361
-0
lines changed

1 file changed

+361
-0
lines changed
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import argparse
13+
import json
14+
import logging
15+
from pathlib import Path
16+
import time
17+
from datetime import timedelta
18+
import warnings
19+
20+
warnings.simplefilter('ignore', UserWarning)
21+
22+
import os
23+
import sys
24+
import copy
25+
import torch
26+
import torch.distributed as dist
27+
import torch.nn.functional as F
28+
from monai.config import print_config
29+
from monai.utils import first, set_determinism
30+
from monai.data import DataLoader, CacheDataset, partition_dataset
31+
from torch.cuda.amp import GradScaler, autocast
32+
from torch.nn.parallel import DistributedDataParallel as DDP
33+
from torch.utils.tensorboard import SummaryWriter
34+
from monai.transforms import (
35+
Compose,
36+
EnsureTyped,
37+
Lambdad,
38+
LoadImaged,
39+
Orientationd,
40+
)
41+
from monai.bundle import ConfigParser
42+
from utils import binarize_labels
43+
44+
def setup_ddp(rank, world_size):
    """Initialize the NCCL distributed process group for this rank.

    Args:
        rank: index of the current process within the group.
        world_size: total number of participating processes.

    Returns:
        A ``(dist, device)`` tuple: the ``torch.distributed`` module and the
        CUDA device assigned to this rank.
    """
    print(f"Running DDP diffusion example on rank {rank}/world_size {world_size}.")
    print(f"Initing to IP {os.environ['MASTER_ADDR']}")
    # Long timeout (10h) so slow data caching on one rank does not kill the group.
    init_kwargs = dict(
        backend="nccl",
        init_method="env://",
        timeout=timedelta(seconds=36000),
        rank=rank,
        world_size=world_size,
    )
    dist.init_process_group(**init_kwargs)
    # Synchronize all ranks before any of them proceeds.
    dist.barrier()
    return dist, torch.device(f"cuda:{rank}")
53+
54+
def define_instance(args, instance_def_key):
    """Instantiate the object described under *instance_def_key* in the config.

    Args:
        args: argument namespace whose attributes form the bundle config.
        instance_def_key: config key of the component definition to build.

    Returns:
        The instantiated component parsed from the config.
    """
    config = vars(args)
    parser = ConfigParser(config)
    parser.parse(True)
    instance = parser.get_parsed_content(instance_def_key, instantiate=True)
    return instance
58+
59+
def add_data_dir2path(list_files, data_dir, fold=None):
    """Prefix datalist paths with *data_dir* and optionally split by fold.

    Args:
        list_files: list of data dictionaries, each with an "image" key and
            optionally "label" and "fold" keys. The input list is not mutated.
        data_dir: directory joined in front of every relative path.
        fold: when given, entries whose "fold" equals this value become the
            validation split; all remaining entries become the training split.

    Returns:
        ``(train_list, val_list)``; ``val_list`` is empty when *fold* is None
        (then ``train_list`` contains every entry).
    """
    entries = copy.deepcopy(list_files)
    train_split = []
    val_split = []
    for entry in entries:
        entry["image"] = os.path.join(data_dir, entry["image"])
        if "label" in entry:
            entry["label"] = os.path.join(data_dir, entry["label"])
        if fold is not None:
            # route the (already path-prefixed) entry into the matching split
            bucket = val_split if entry["fold"] == fold else train_split
            bucket.append(copy.deepcopy(entry))
    if fold is None:
        return entries, []
    return train_split, val_split
80+
81+
82+
def prepare_maisi_controlnet_json_dataloader(
    args,
    json_data_list,
    data_base_dir,
    batch_size=1,
    fold=0,
    cache_rate=0.0,
    rank=0,
    world_size=1,
):
    """Build training and validation DataLoaders from MAISI-style JSON datalists.

    Args:
        args: parsed argument namespace; only ``args.data_base_dir`` is read
            here, for a sanity check in the multi-datalist case.
        json_data_list: path to a JSON datalist, or a list of such paths.
        data_base_dir: base directory (or list of directories, parallel to
            ``json_data_list``) prefixed onto the relative datalist paths.
        batch_size: per-rank batch size for both loaders.
        fold: fold index used to carve out the validation split.
        cache_rate: fraction of each dataset cached in memory by CacheDataset.
        rank: this process's rank; selects its partition under DDP.
        world_size: total number of processes; >1 enables DDP partitioning.

    Returns:
        ``(train_loader, val_loader)``.
    """
    ddp_bool = world_size > 1
    if isinstance(json_data_list, list):
        # Multiple datalists: pair each datalist with its matching base directory.
        # (Fixed: the loop previously reused the name `data_base_dir` as its loop
        # variable, shadowing and then clobbering the parameter.)
        assert isinstance(args.data_base_dir, list)
        list_train = []
        list_valid = []
        for data_list, base_dir in zip(json_data_list, data_base_dir):
            with open(data_list, "r") as f:
                json_data = json.load(f)
            train, val = add_data_dir2path(json_data, base_dir, fold)
            list_train += train
            list_valid += val
    else:
        with open(json_data_list, "r") as f:
            json_data = json.load(f)
        # NOTE(review): this branch reads json_data['training'] while the list
        # branch passes the whole JSON object — confirm both formats are intended.
        list_train, list_valid = add_data_dir2path(json_data['training'], data_base_dir, fold)

    common_transform = [
        LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True),
        Orientationd(keys=["label"], axcodes="RAS"),
        EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True),
        # Cast the scalar conditioning inputs to float tensors ...
        Lambdad(keys='top_region_index', func=lambda x: torch.FloatTensor(x)),
        Lambdad(keys='bottom_region_index', func=lambda x: torch.FloatTensor(x)),
        Lambdad(keys='spacing', func=lambda x: torch.FloatTensor(x)),
        # ... and rescale them by 1e2 before they are fed to the model.
        Lambdad(keys='top_region_index', func=lambda x: x * 1e2),
        Lambdad(keys='bottom_region_index', func=lambda x: x * 1e2),
        Lambdad(keys='spacing', func=lambda x: x * 1e2),
    ]
    train_transforms, val_transforms = Compose(common_transform), Compose(common_transform)

    if ddp_bool:
        # Shard the training list so each rank sees a distinct, evenly sized subset.
        list_train = partition_dataset(
            data=list_train,
            shuffle=True,
            num_partitions=world_size,
            even_divisible=True,
        )[rank]
    train_ds = CacheDataset(
        data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8
    )
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=8, pin_memory=True
    )
    if ddp_bool:
        # Validation shards may be uneven; ranks just skip missing tail items.
        list_valid = partition_dataset(
            data=list_valid,
            shuffle=True,
            num_partitions=world_size,
            even_divisible=False,
        )[rank]
    val_ds = CacheDataset(
        data=list_valid, transform=val_transforms, cache_rate=cache_rate, num_workers=8,
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False
    )
    return train_loader, val_loader
151+
152+
def main():
    """Train a MAISI ControlNet on top of a frozen, pre-trained diffusion UNet.

    Reads environment/config JSON files into ``args``, builds the dataloaders,
    loads the trained diffusion model, initializes the controlnet from its
    weights, and runs the (optionally weighted) L1 noise-prediction training
    loop with AMP, polynomial LR decay, and optional DDP.
    """
    parser = argparse.ArgumentParser(description="PyTorch VAE-GAN training")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./config/environment_maisi_controlnet_train.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./config/config_maisi_controlnet_train.json",
        help="config json file that stores hyper-parameters",
    )
    parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
    # BUG FIX: the original combined nargs='+' with action="store_true", which makes
    # argparse raise TypeError at startup (store_true actions accept no nargs).
    # The option is meant to collect label values, so parse it as a list of ints.
    parser.add_argument(
        "-w",
        "--weighted_loss_label",
        nargs='+',
        type=int,
        default=[],
        help="list of lables that use weighted loss",
    )
    parser.add_argument("-l", "--weighted_loss", default=100, type=int, help="loss weight loss for ROI labels")
    args = parser.parse_args()

    # Step 0: configuration
    ddp_bool = args.gpus > 1  # whether to use distributed data parallel
    if ddp_bool:
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        dist, device = setup_ddp(rank, world_size)
    else:
        rank = 0
        world_size = 1
        device = 0

    torch.cuda.set_device(device)
    print(f"Using {device}")

    # Merge the environment and hyper-parameter JSON files into args.
    with open(args.environment_file, "r") as f:
        env_dict = json.load(f)
    with open(args.config_file, "r") as f:
        config_dict = json.load(f)
    for k, v in env_dict.items():
        setattr(args, k, v)
    for k, v in config_dict.items():
        setattr(args, k, v)

    # Initialize tensorboard writer (rank 0 only).
    if rank == 0:
        Path(args.tfevent_path).mkdir(parents=True, exist_ok=True)
        # NOTE(review): the directory created above is args.tfevent_path, but the
        # writer logs under args.output_dir — confirm which location is intended.
        tensorboard_path = os.path.join(args.output_dir, "controlnet_tfevent")
        tensorboard_writer = SummaryWriter(tensorboard_path)

    # Step 1: set data loader
    train_loader, val_loader = prepare_maisi_controlnet_json_dataloader(
        args,
        json_data_list=args.json_data_list,
        data_base_dir=args.data_base_dir,
        rank=rank,
        world_size=world_size,
        batch_size=args.controlnet_train["batch_size"],
        cache_rate=args.controlnet_train["cache_rate"],
        fold=args.controlnet_train["fold"],
    )

    # Step 2: define diffusion model and controlnet

    # define diffusion model (config key spelling "difusion_unet_def" matches the
    # bundle config file — do not "fix" it here without changing the config too)
    unet = define_instance(args, "difusion_unet_def").to(device)
    # load trained diffusion model; remap rank-0 checkpoints onto this rank's GPU
    map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
    unet.load_state_dict(torch.load(args.trained_diffusion_path, map_location=map_location))
    if rank == 0:
        print("Load trained diffusion model from", args.trained_diffusion_path)
    # define ControlNet
    controlnet = define_instance(args, "controlnet_def").to(device)
    # copy weights from the DM to the controlnet (non-matching keys are skipped)
    controlnet.load_state_dict(unet.state_dict(), strict=False)
    if args.trained_controlnet_path is not None:
        controlnet.load_state_dict(torch.load(args.trained_controlnet_path, map_location=map_location))
        if rank == 0:
            print("load trained controlnet model from", args.trained_controlnet_path)
    else:
        if rank == 0:
            print("train controlnet model from scratch.")
    # we freeze the parameters of the diffusion model.
    for p in unet.parameters():
        p.requires_grad = False

    noise_scheduler = define_instance(args, "noise_scheduler")

    if ddp_bool:
        controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)

    # Step 3: training config
    optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=args.controlnet_train["lr"])
    total_steps = (args.controlnet_train["n_epochs"] * len(train_loader.dataset)) / args.controlnet_train["batch_size"]
    if rank == 0:
        print(f"total number of training steps: {total_steps}.")

    lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)

    # Step 4: training
    n_epochs = args.controlnet_train["n_epochs"]
    scaler = GradScaler()
    total_step = 0

    if args.weighted_loss > 0 and rank == 0:
        print(f"apply weighted loss = {args.weighted_loss} on labels: {args.weighted_loss_label}")

    prev_time = time.time()
    for epoch in range(n_epochs):
        # BUG FIX: the original called unet.train() here, but the unet is frozen;
        # the module being optimized is the controlnet, so put it in train mode
        # and keep the frozen diffusion model in eval mode.
        controlnet.train()
        unet.eval()
        epoch_loss_ = 0

        for step, batch in enumerate(train_loader):
            # get latent image embedding and label mask
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)
            # get corresponding conditioning inputs
            top_region_index_tensor = batch['top_region_index'].to(device)
            bottom_region_index_tensor = batch['bottom_region_index'].to(device)
            spacing_tensor = batch['spacing'].to(device)

            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=True):
                # generate random noise matching the latent shape
                noise_shape = list(inputs.shape)
                noise = torch.randn(noise_shape, dtype=inputs.dtype).to(inputs.device)

                # use binary encoding to encode segmentation mask
                controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()

                # sample one random diffusion timestep per batch element
                timesteps = torch.randint(
                    0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=inputs.device
                ).long()

                # create noisy latent
                noisy_latent = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)

                # get controlnet residuals conditioned on the binarized mask
                down_block_res_samples, mid_block_res_sample = controlnet(
                    x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond
                )
                # get noise prediction from the frozen diffusion unet
                noise_pred = unet(
                    x=noisy_latent,
                    timesteps=timesteps,
                    top_region_index_tensor=top_region_index_tensor,
                    bottom_region_index_tensor=bottom_region_index_tensor,
                    spacing_tensor=spacing_tensor,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                )

                if args.weighted_loss > 1.0:
                    # up-weight the L1 loss inside the requested ROI labels
                    weights = torch.ones_like(inputs).to(inputs.device)
                    roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(inputs.device)
                    # resample the label mask to the latent resolution
                    interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest")
                    # assign larger weights for ROI (tumor)
                    for label in args.weighted_loss_label:
                        roi[interpolate_label == label] = 1
                    weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = args.weighted_loss
                    loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean()
                else:
                    loss = F.l1_loss(noise_pred.float(), noise.float())

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()

            # accumulate on every rank so the all_reduce below averages real values
            epoch_loss_ += loss.detach()

            if rank == 0:
                # write train loss for each batch into tensorboard
                total_step += 1
                tensorboard_writer.add_scalar(
                    "train/train_diffusion_loss_iter", loss.detach().cpu().item(), total_step
                )
                batches_done = step
                batches_left = len(train_loader) - batches_done
                time_left = timedelta(seconds=batches_left * (time.time() - prev_time))
                prev_time = time.time()
                sys.stdout.write(
                    "\r[Epoch %d/%d] [Batch %d/%d] [LR: %f] [loss: %04f] ETA: %s "
                    % (
                        epoch,
                        n_epochs,
                        step,
                        len(train_loader),
                        lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else optimizer.param_groups[0]['lr'],
                        loss.detach().cpu().item(),
                        time_left,
                    )
                )

        epoch_loss = epoch_loss_ / (step + 1)

        if ddp_bool:
            # average the per-rank epoch losses before logging
            dist.barrier()
            dist.all_reduce(epoch_loss, op=torch.distributed.ReduceOp.AVG)

        if rank == 0:
            tensorboard_writer.add_scalar("train/train_diffusion_loss_epoch", epoch_loss.cpu().item(), total_step)

        torch.cuda.empty_cache()
351+
352+
353+
354+
if __name__ == "__main__":
    # Configure root logging to stdout before entering the training loop.
    log_config = {
        "stream": sys.stdout,
        "level": logging.INFO,
        "format": "[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    logging.basicConfig(**log_config)
    main()

0 commit comments

Comments
 (0)