Add maisi controlnet training #1750

Merged
merged 22 commits into from
Jul 9, 2024
22 commits
29ef5e2
add train config
guopengf Jul 2, 2024
b189a38
add training script
guopengf Jul 2, 2024
9bf5c7a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 2, 2024
c54778a
Updates to GAN script examples (#1727)
ericspod Jun 28, 2024
13c9143
[pre-commit.ci] pre-commit suggestions (#1746)
pre-commit-ci[bot] Jul 2, 2024
3295944
update
guopengf Jul 2, 2024
3dc31fc
update
guopengf Jul 3, 2024
3787a7c
update config and train loop
guopengf Jul 3, 2024
91b2b30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2024
7cef2e5
Merge branch 'main' into add-maisi-controlnet-training
guopengf Jul 3, 2024
b5e692c
add docstring and using logging
guopengf Jul 4, 2024
58760c6
update train config
guopengf Jul 4, 2024
3f94d19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
2a4a31e
Merge branch 'main' into add-maisi-controlnet-training
KumoLiu Jul 4, 2024
b902c5b
update
guopengf Jul 4, 2024
56720d7
Update generative/maisi/scripts/train_controlnet.py
guopengf Jul 8, 2024
93d5469
Merge branch 'main' into add-maisi-controlnet-training
guopengf Jul 8, 2024
a13d351
update and move util functions to util.py
guopengf Jul 8, 2024
40526e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2024
fde48c4
Update generative/maisi/scripts/train_controlnet.py
guopengf Jul 8, 2024
a04c37b
Update generative/maisi/scripts/train_controlnet.py
guopengf Jul 8, 2024
3234841
Merge branch 'main' into add-maisi-controlnet-training
guopengf Jul 8, 2024
79 changes: 79 additions & 0 deletions generative/maisi/configs/config_maisi_controlnet_train.json
@@ -0,0 +1,79 @@
{
"random_seed": null,
"spatial_dims": 3,
"image_channels": 1,
"latent_channels": 4,
"diffusion_unet_def": {
"_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"out_channels": "@latent_channels",
"num_channels": [
64,
128,
256,
512
],
"attention_levels": [
false,
false,
true,
true
],
"num_head_channels": [
0,
0,
32,
32
],
"num_res_blocks": 2,
"use_flash_attention": true,
"include_top_region_index_input": true,
"include_bottom_region_index_input": true,
"include_spacing_input": true
},
"controlnet_def": {
"_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"num_channels": [
64,
128,
256,
512
],
"attention_levels": [
false,
false,
true,
true
],
"num_head_channels": [
0,
0,
32,
32
],
"num_res_blocks": 2,
"use_flash_attention": true,
"conditioning_embedding_in_channels": 8,
"conditioning_embedding_num_channels": [8, 32, 64]
},
"noise_scheduler": {
"_target_": "generative.networks.schedulers.DDPMScheduler",
"num_train_timesteps": 1000,
"beta_start": 0.0015,
"beta_end": 0.0195,
"schedule": "scaled_linear_beta",
"clip_sample": false
},
"controlnet_train": {
"batch_size": 1,
"cache_rate": 0.0,
"fold": 0,
"lr": 1e-5,
"n_epochs": 100,
"weighted_loss_label": [129],
"weighted_loss": 100
}
}
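Note: the "_target_" class paths and "@" cross-references in this config follow MONAI bundle syntax. The sketch below is illustrative only and not part of this PR (the repository resolves these entries through the define_instance helper in utils, which is not shown in this diff); it assumes monai.bundle.ConfigParser can parse the same file, with MONAI and MONAI Generative installed.

import json

from monai.bundle import ConfigParser

with open("./configs/config_maisi_controlnet_train.json") as f:
    config = json.load(f)

parser = ConfigParser(config)
# "@spatial_dims" and "@latent_channels" references are resolved against the top-level keys
noise_scheduler = parser.get_parsed_content("noise_scheduler")
controlnet = parser.get_parsed_content("controlnet_def")
unet = parser.get_parsed_content("diffusion_unet_def")
print(type(unet).__name__, sum(p.numel() for p in unet.parameters()))

Instantiating the networks with use_flash_attention enabled additionally assumes a compatible GPU/xformers setup.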
11 changes: 11 additions & 0 deletions generative/maisi/configs/environment_maisi_controlnet_train.json
@@ -0,0 +1,11 @@
{
"model_dir": "./models/",
"output_dir": "./output",
"tfevent_path": "./outputs/tfevent",
"trained_autoencoder_path": "./models/autoencoder_epoch273.pt",
"trained_diffusion_path": "./models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt",
"trained_controlnet_path": "./models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt",
"exp_name": "controlnet_kits_finetune",
"data_base_dir": ["./datasets/C4KC-KiTS_subset"],
"json_data_list": ["./datasets/C4KC-KiTS_subset.json"]
}
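Note: this file only stores paths and the experiment name. A small pre-flight check such as the following can confirm the referenced checkpoints exist and create the output directories up front; check_maisi_env is a hypothetical helper, not part of this PR, and the training script itself only validates trained_diffusion_path.

import json
import os


def check_maisi_env(env_file: str) -> None:
    # hypothetical helper: verify checkpoint paths and create output folders
    with open(env_file) as f:
        env = json.load(f)
    for key in ("trained_autoencoder_path", "trained_diffusion_path", "trained_controlnet_path"):
        path = env.get(key)
        if path and not os.path.exists(path):
            print(f"warning: {key} not found: {path}")
    os.makedirs(env["model_dir"], exist_ok=True)
    os.makedirs(os.path.join(env["tfevent_path"], env["exp_name"]), exist_ok=True)


check_maisi_env("./configs/environment_maisi_controlnet_train.json")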
274 changes: 274 additions & 0 deletions generative/maisi/scripts/train_controlnet.py
@@ -0,0 +1,274 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
import sys
import time
from datetime import timedelta
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn.functional as F
from monai.networks.utils import copy_model_state
from monai.utils import RankFilter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from utils import binarize_labels, define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp


def main():
parser = argparse.ArgumentParser(description="maisi.controlnet.training")
parser.add_argument(
"-e",
"--environment-file",
default="./configs/environment_maisi_controlnet_train.json",
help="environment json file that stores environment path",
)
parser.add_argument(
"-c",
"--config-file",
default="./configs/config_maisi_controlnet_train.json",
help="config json file that stores hyper-parameters",
)
parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
args = parser.parse_args()

# Step 0: configuration
logger = logging.getLogger("maisi.controlnet.training")
# whether to use distributed data parallel
use_ddp = args.gpus > 1
if use_ddp:
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = setup_ddp(rank, world_size)
logger.addFilter(RankFilter())
else:
rank = 0
world_size = 1
device = torch.device(f"cuda:{rank}")

torch.cuda.set_device(device)
logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
logger.info(f"World_size: {world_size}")

    with open(args.environment_file, "r") as env_file:
        env_dict = json.load(env_file)
    with open(args.config_file, "r") as config_file:
        config_dict = json.load(config_file)

for k, v in env_dict.items():
setattr(args, k, v)
for k, v in config_dict.items():
setattr(args, k, v)

# initialize tensorboard writer
if rank == 0:
tensorboard_path = os.path.join(args.tfevent_path, args.exp_name)
Path(tensorboard_path).mkdir(parents=True, exist_ok=True)
tensorboard_writer = SummaryWriter(tensorboard_path)

# Step 1: set data loader
train_loader, _ = prepare_maisi_controlnet_json_dataloader(
json_data_list=args.json_data_list,
data_base_dir=args.data_base_dir,
rank=rank,
world_size=world_size,
batch_size=args.controlnet_train["batch_size"],
cache_rate=args.controlnet_train["cache_rate"],
fold=args.controlnet_train["fold"],
)

# Step 2: define diffusion model and controlnet
    # define the diffusion UNet
unet = define_instance(args, "diffusion_unet_def").to(device)
# load trained diffusion model
if not os.path.exists(args.trained_diffusion_path):
raise ValueError("Please download the trained diffusion unet checkpoint.")
diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device)
unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"])
# load scale factor
scale_factor = diffusion_model_ckpt["scale_factor"]
logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.")
logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.")
# define ControlNet
controlnet = define_instance(args, "controlnet_def").to(device)
    # copy weights from the diffusion model to the controlnet
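    # copy_model_state copies the parameters whose names and shapes match, so the controlnet is initialized from the pretrained diffusion weights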
copy_model_state(controlnet, unet.state_dict())
# load trained controlnet model if it is provided
if args.trained_controlnet_path is not None:
controlnet.load_state_dict(
torch.load(args.trained_controlnet_path, map_location=device)["controlnet_state_dict"]
)
logger.info(f"load trained controlnet model from {args.trained_controlnet_path}")
else:
logger.info("train controlnet model from scratch.")
# we freeze the parameters of the diffusion model.
for p in unet.parameters():
p.requires_grad = False

noise_scheduler = define_instance(args, "noise_scheduler")

if use_ddp:
controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)

# Step 3: training config
weighted_loss = args.controlnet_train["weighted_loss"]
weighted_loss_label = args.controlnet_train["weighted_loss_label"]
optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=args.controlnet_train["lr"])
total_steps = (args.controlnet_train["n_epochs"] * len(train_loader.dataset)) / args.controlnet_train["batch_size"]
logger.info(f"total number of training steps: {total_steps}.")

lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)

# Step 4: training
n_epochs = args.controlnet_train["n_epochs"]
scaler = GradScaler()
total_step = 0
best_loss = 1e4

if weighted_loss > 0:
logger.info(f"apply weighted loss = {weighted_loss} on labels: {weighted_loss_label}")

controlnet.train()
unet.eval()
prev_time = time.time()
for epoch in range(n_epochs):
epoch_loss_ = 0
for step, batch in enumerate(train_loader):
# get image embedding and label mask and scale image embedding by the provided scale_factor
inputs = batch["image"].to(device) * scale_factor
labels = batch["label"].to(device)
# get corresponding conditions
top_region_index_tensor = batch["top_region_index"].to(device)
bottom_region_index_tensor = batch["bottom_region_index"].to(device)
spacing_tensor = batch["spacing"].to(device)

optimizer.zero_grad(set_to_none=True)

with autocast(enabled=True):
# generate random noise
noise_shape = list(inputs.shape)
noise = torch.randn(noise_shape, dtype=inputs.dtype).to(device)

# use binary encoding to encode segmentation mask
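                # (the resulting multi-channel binary mask matches conditioning_embedding_in_channels=8 in the config)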
controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()

# create timesteps
timesteps = torch.randint(
0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=device
).long()

# create noisy latent
noisy_latent = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)

# get controlnet output
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond
)
# get noise prediction from diffusion unet
noise_pred = unet(
x=noisy_latent,
timesteps=timesteps,
top_region_index_tensor=top_region_index_tensor,
bottom_region_index_tensor=bottom_region_index_tensor,
spacing_tensor=spacing_tensor,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)

if weighted_loss > 1.0:
weights = torch.ones_like(inputs).to(inputs.device)
roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(inputs.device)
interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest")
# assign larger weights for ROI (tumor)
for label in weighted_loss_label:
roi[interpolate_label == label] = 1
weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = weighted_loss
loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean()
else:
loss = F.l1_loss(noise_pred.float(), noise.float())

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
lr_scheduler.step()
total_step += 1

if rank == 0:
# write train loss for each batch into tensorboard
tensorboard_writer.add_scalar(
"train/train_controlnet_loss_iter", loss.detach().cpu().item(), total_step
)
batches_done = step + 1
batches_left = len(train_loader) - batches_done
time_left = timedelta(seconds=batches_left * (time.time() - prev_time))
prev_time = time.time()
logger.info(
"\r[Epoch %d/%d] [Batch %d/%d] [LR: %.8f] [loss: %.4f] ETA: %s "
% (
epoch + 1,
n_epochs,
step + 1,
len(train_loader),
lr_scheduler.get_last_lr()[0],
loss.detach().cpu().item(),
time_left,
)
)
epoch_loss_ += loss.detach()

epoch_loss = epoch_loss_ / (step + 1)

if use_ddp:
dist.barrier()
dist.all_reduce(epoch_loss, op=torch.distributed.ReduceOp.AVG)

if rank == 0:
tensorboard_writer.add_scalar("train/train_controlnet_loss_epoch", epoch_loss.cpu().item(), total_step)
# save controlnet only on master GPU (rank 0)
controlnet_state_dict = controlnet.module.state_dict() if world_size > 1 else controlnet.state_dict()
torch.save(
{
"epoch": epoch + 1,
"loss": epoch_loss,
"controlnet_state_dict": controlnet_state_dict,
},
f"{args.model_dir}/{args.exp_name}_current.pt",
)

if epoch_loss < best_loss:
best_loss = epoch_loss
logger.info(f"best loss -> {best_loss}.")
torch.save(
{
"epoch": epoch + 1,
"loss": best_loss,
"controlnet_state_dict": controlnet_state_dict,
},
f"{args.model_dir}/{args.exp_name}_best.pt",
)

torch.cuda.empty_cache()
if use_ddp:
dist.destroy_process_group()


if __name__ == "__main__":
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
main()
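Note: setup_ddp, define_instance, binarize_labels and prepare_maisi_controlnet_json_dataloader are imported from utils, which is not part of this diff. A minimal sketch of what setup_ddp presumably does, based on its call sites above (an assumption, not the repository implementation):

from datetime import timedelta

import torch
import torch.distributed as dist


def setup_ddp(rank: int, world_size: int) -> torch.device:
    # assumed behaviour: join the default process group (torchrun provides
    # MASTER_ADDR/MASTER_PORT) and bind this process to its local GPU
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size, timeout=timedelta(seconds=3600)
    )
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device

With -g greater than 1 the script reads LOCAL_RANK and WORLD_SIZE from the environment, which torchrun sets, so a multi-GPU run is expected to be launched through torchrun with one process per GPU.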