Skip to content

Commit d4d14e6

Browse files
guopengfpre-commit-ci[bot]ericspodKumoLiumingxin-zheng
authored
Add maisi controlnet training (#1750)
Fixes # . ### Description Add training script for MAISI ControlNet. ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Avoid including large-size files in the PR. - [ ] Clean up long text outputs from code cells in the notebook. - [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>` --------- Signed-off-by: Pengfei Guo <[email protected]> Signed-off-by: Eric Kerfoot <[email protected]> Signed-off-by: Pengfei Guo <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <[email protected]> Co-authored-by: YunLiu <[email protected]> Co-authored-by: Mingxin Zheng <[email protected]>
1 parent bc063ad commit d4d14e6

File tree

4 files changed

+538
-1
lines changed

4 files changed

+538
-1
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
{
2+
"random_seed": null,
3+
"spatial_dims": 3,
4+
"image_channels": 1,
5+
"latent_channels": 4,
6+
"diffusion_unet_def": {
7+
"_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
8+
"spatial_dims": "@spatial_dims",
9+
"in_channels": "@latent_channels",
10+
"out_channels": "@latent_channels",
11+
"num_channels": [
12+
64,
13+
128,
14+
256,
15+
512
16+
],
17+
"attention_levels": [
18+
false,
19+
false,
20+
true,
21+
true
22+
],
23+
"num_head_channels": [
24+
0,
25+
0,
26+
32,
27+
32
28+
],
29+
"num_res_blocks": 2,
30+
"use_flash_attention": true,
31+
"include_top_region_index_input": true,
32+
"include_bottom_region_index_input": true,
33+
"include_spacing_input": true
34+
},
35+
"controlnet_def": {
36+
"_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
37+
"spatial_dims": "@spatial_dims",
38+
"in_channels": "@latent_channels",
39+
"num_channels": [
40+
64,
41+
128,
42+
256,
43+
512
44+
],
45+
"attention_levels": [
46+
false,
47+
false,
48+
true,
49+
true
50+
],
51+
"num_head_channels": [
52+
0,
53+
0,
54+
32,
55+
32
56+
],
57+
"num_res_blocks": 2,
58+
"use_flash_attention": true,
59+
"conditioning_embedding_in_channels": 8,
60+
"conditioning_embedding_num_channels": [8, 32, 64]
61+
},
62+
"noise_scheduler": {
63+
"_target_": "generative.networks.schedulers.DDPMScheduler",
64+
"num_train_timesteps": 1000,
65+
"beta_start": 0.0015,
66+
"beta_end": 0.0195,
67+
"schedule": "scaled_linear_beta",
68+
"clip_sample": false
69+
},
70+
"controlnet_train": {
71+
"batch_size": 1,
72+
"cache_rate": 0.0,
73+
"fold": 0,
74+
"lr": 1e-5,
75+
"n_epochs": 100,
76+
"weighted_loss_label": [129],
77+
"weighted_loss": 100
78+
}
79+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"model_dir": "./models/",
3+
"output_dir": "./output",
4+
"tfevent_path": "./outputs/tfevent",
5+
"trained_autoencoder_path": "./models/autoencoder_epoch273.pt",
6+
"trained_diffusion_path": "./models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt",
7+
"trained_controlnet_path": "./models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt",
8+
"exp_name": "controlnet_kits_finetune",
9+
"data_base_dir": ["./datasets/C4KC-KiTS_subset"],
10+
"json_data_list": ["./datasets/C4KC-KiTS_subset.json"]
11+
}
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import argparse
13+
import json
14+
import logging
15+
import os
16+
import sys
17+
import time
18+
from datetime import timedelta
19+
from pathlib import Path
20+
21+
import torch
22+
import torch.distributed as dist
23+
import torch.nn.functional as F
24+
from monai.networks.utils import copy_model_state
25+
from monai.utils import RankFilter
26+
from torch.cuda.amp import GradScaler, autocast
27+
from torch.nn.parallel import DistributedDataParallel as DDP
28+
from torch.utils.tensorboard import SummaryWriter
29+
from utils import binarize_labels, define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp
30+
31+
32+
def main():
33+
parser = argparse.ArgumentParser(description="maisi.controlnet.training")
34+
parser.add_argument(
35+
"-e",
36+
"--environment-file",
37+
default="./configs/environment_maisi_controlnet_train.json",
38+
help="environment json file that stores environment path",
39+
)
40+
parser.add_argument(
41+
"-c",
42+
"--config-file",
43+
default="./configs/config_maisi_controlnet_train.json",
44+
help="config json file that stores hyper-parameters",
45+
)
46+
parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
47+
args = parser.parse_args()
48+
49+
# Step 0: configuration
50+
logger = logging.getLogger("maisi.controlnet.training")
51+
# whether to use distributed data parallel
52+
use_ddp = args.gpus > 1
53+
if use_ddp:
54+
rank = int(os.environ["LOCAL_RANK"])
55+
world_size = int(os.environ["WORLD_SIZE"])
56+
device = setup_ddp(rank, world_size)
57+
logger.addFilter(RankFilter())
58+
else:
59+
rank = 0
60+
world_size = 1
61+
device = torch.device(f"cuda:{rank}")
62+
63+
torch.cuda.set_device(device)
64+
logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
65+
logger.info(f"World_size: {world_size}")
66+
67+
env_dict = json.load(open(args.environment_file, "r"))
68+
config_dict = json.load(open(args.config_file, "r"))
69+
70+
for k, v in env_dict.items():
71+
setattr(args, k, v)
72+
for k, v in config_dict.items():
73+
setattr(args, k, v)
74+
75+
# initialize tensorboard writer
76+
if rank == 0:
77+
tensorboard_path = os.path.join(args.tfevent_path, args.exp_name)
78+
Path(tensorboard_path).mkdir(parents=True, exist_ok=True)
79+
tensorboard_writer = SummaryWriter(tensorboard_path)
80+
81+
# Step 1: set data loader
82+
train_loader, _ = prepare_maisi_controlnet_json_dataloader(
83+
json_data_list=args.json_data_list,
84+
data_base_dir=args.data_base_dir,
85+
rank=rank,
86+
world_size=world_size,
87+
batch_size=args.controlnet_train["batch_size"],
88+
cache_rate=args.controlnet_train["cache_rate"],
89+
fold=args.controlnet_train["fold"],
90+
)
91+
92+
# Step 2: define diffusion model and controlnet
93+
# define diffusion Model
94+
unet = define_instance(args, "diffusion_unet_def").to(device)
95+
# load trained diffusion model
96+
if not os.path.exists(args.trained_diffusion_path):
97+
raise ValueError("Please download the trained diffusion unet checkpoint.")
98+
diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device)
99+
unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"])
100+
# load scale factor
101+
scale_factor = diffusion_model_ckpt["scale_factor"]
102+
logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.")
103+
logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.")
104+
# define ControlNet
105+
controlnet = define_instance(args, "controlnet_def").to(device)
106+
# copy weights from the DM to the controlnet
107+
copy_model_state(controlnet, unet.state_dict())
108+
# load trained controlnet model if it is provided
109+
if args.trained_controlnet_path is not None:
110+
controlnet.load_state_dict(
111+
torch.load(args.trained_controlnet_path, map_location=device)["controlnet_state_dict"]
112+
)
113+
logger.info(f"load trained controlnet model from {args.trained_controlnet_path}")
114+
else:
115+
logger.info("train controlnet model from scratch.")
116+
# we freeze the parameters of the diffusion model.
117+
for p in unet.parameters():
118+
p.requires_grad = False
119+
120+
noise_scheduler = define_instance(args, "noise_scheduler")
121+
122+
if use_ddp:
123+
controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)
124+
125+
# Step 3: training config
126+
weighted_loss = args.controlnet_train["weighted_loss"]
127+
weighted_loss_label = args.controlnet_train["weighted_loss_label"]
128+
optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=args.controlnet_train["lr"])
129+
total_steps = (args.controlnet_train["n_epochs"] * len(train_loader.dataset)) / args.controlnet_train["batch_size"]
130+
logger.info(f"total number of training steps: {total_steps}.")
131+
132+
lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)
133+
134+
# Step 4: training
135+
n_epochs = args.controlnet_train["n_epochs"]
136+
scaler = GradScaler()
137+
total_step = 0
138+
best_loss = 1e4
139+
140+
if weighted_loss > 0:
141+
logger.info(f"apply weighted loss = {weighted_loss} on labels: {weighted_loss_label}")
142+
143+
controlnet.train()
144+
unet.eval()
145+
prev_time = time.time()
146+
for epoch in range(n_epochs):
147+
epoch_loss_ = 0
148+
for step, batch in enumerate(train_loader):
149+
# get image embedding and label mask and scale image embedding by the provided scale_factor
150+
inputs = batch["image"].to(device) * scale_factor
151+
labels = batch["label"].to(device)
152+
# get corresponding conditions
153+
top_region_index_tensor = batch["top_region_index"].to(device)
154+
bottom_region_index_tensor = batch["bottom_region_index"].to(device)
155+
spacing_tensor = batch["spacing"].to(device)
156+
157+
optimizer.zero_grad(set_to_none=True)
158+
159+
with autocast(enabled=True):
160+
# generate random noise
161+
noise_shape = list(inputs.shape)
162+
noise = torch.randn(noise_shape, dtype=inputs.dtype).to(device)
163+
164+
# use binary encoding to encode segmentation mask
165+
controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()
166+
167+
# create timesteps
168+
timesteps = torch.randint(
169+
0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=device
170+
).long()
171+
172+
# create noisy latent
173+
noisy_latent = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
174+
175+
# get controlnet output
176+
down_block_res_samples, mid_block_res_sample = controlnet(
177+
x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond
178+
)
179+
# get noise prediction from diffusion unet
180+
noise_pred = unet(
181+
x=noisy_latent,
182+
timesteps=timesteps,
183+
top_region_index_tensor=top_region_index_tensor,
184+
bottom_region_index_tensor=bottom_region_index_tensor,
185+
spacing_tensor=spacing_tensor,
186+
down_block_additional_residuals=down_block_res_samples,
187+
mid_block_additional_residual=mid_block_res_sample,
188+
)
189+
190+
if weighted_loss > 1.0:
191+
weights = torch.ones_like(inputs).to(inputs.device)
192+
roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(inputs.device)
193+
interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest")
194+
# assign larger weights for ROI (tumor)
195+
for label in weighted_loss_label:
196+
roi[interpolate_label == label] = 1
197+
weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = weighted_loss
198+
loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean()
199+
else:
200+
loss = F.l1_loss(noise_pred.float(), noise.float())
201+
202+
scaler.scale(loss).backward()
203+
scaler.step(optimizer)
204+
scaler.update()
205+
lr_scheduler.step()
206+
total_step += 1
207+
208+
if rank == 0:
209+
# write train loss for each batch into tensorboard
210+
tensorboard_writer.add_scalar(
211+
"train/train_controlnet_loss_iter", loss.detach().cpu().item(), total_step
212+
)
213+
batches_done = step + 1
214+
batches_left = len(train_loader) - batches_done
215+
time_left = timedelta(seconds=batches_left * (time.time() - prev_time))
216+
prev_time = time.time()
217+
logger.info(
218+
"\r[Epoch %d/%d] [Batch %d/%d] [LR: %.8f] [loss: %.4f] ETA: %s "
219+
% (
220+
epoch + 1,
221+
n_epochs,
222+
step + 1,
223+
len(train_loader),
224+
lr_scheduler.get_last_lr()[0],
225+
loss.detach().cpu().item(),
226+
time_left,
227+
)
228+
)
229+
epoch_loss_ += loss.detach()
230+
231+
epoch_loss = epoch_loss_ / (step + 1)
232+
233+
if use_ddp:
234+
dist.barrier()
235+
dist.all_reduce(epoch_loss, op=torch.distributed.ReduceOp.AVG)
236+
237+
if rank == 0:
238+
tensorboard_writer.add_scalar("train/train_controlnet_loss_epoch", epoch_loss.cpu().item(), total_step)
239+
# save controlnet only on master GPU (rank 0)
240+
controlnet_state_dict = controlnet.module.state_dict() if world_size > 1 else controlnet.state_dict()
241+
torch.save(
242+
{
243+
"epoch": epoch + 1,
244+
"loss": epoch_loss,
245+
"controlnet_state_dict": controlnet_state_dict,
246+
},
247+
f"{args.model_dir}/{args.exp_name}_current.pt",
248+
)
249+
250+
if epoch_loss < best_loss:
251+
best_loss = epoch_loss
252+
logger.info(f"best loss -> {best_loss}.")
253+
torch.save(
254+
{
255+
"epoch": epoch + 1,
256+
"loss": best_loss,
257+
"controlnet_state_dict": controlnet_state_dict,
258+
},
259+
f"{args.model_dir}/{args.exp_name}_best.pt",
260+
)
261+
262+
torch.cuda.empty_cache()
263+
if use_ddp:
264+
dist.destroy_process_group()
265+
266+
267+
if __name__ == "__main__":
268+
logging.basicConfig(
269+
stream=sys.stdout,
270+
level=logging.INFO,
271+
format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
272+
datefmt="%Y-%m-%d %H:%M:%S",
273+
)
274+
main()

0 commit comments

Comments
 (0)