-
Notifications
You must be signed in to change notification settings - Fork 739
Add maisi controlnet training #1750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
KumoLiu
merged 22 commits into
Project-MONAI:main
from
guopengf:add-maisi-controlnet-training
Jul 9, 2024
Merged
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
29ef5e2
add train config
guopengf b189a38
add training script
guopengf 9bf5c7a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c54778a
Updates to GAN script examples (#1727)
ericspod 13c9143
[pre-commit.ci] pre-commit suggestions (#1746)
pre-commit-ci[bot] 3295944
update
guopengf 3dc31fc
update
guopengf 3787a7c
update config and train loop
guopengf 91b2b30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7cef2e5
Merge branch 'main' into add-maisi-controlnet-training
guopengf b5e692c
add docstring and using logging
guopengf 58760c6
update train config
guopengf 3f94d19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 2a4a31e
Merge branch 'main' into add-maisi-controlnet-training
KumoLiu b902c5b
update
guopengf 56720d7
Update generative/maisi/scripts/train_controlnet.py
guopengf 93d5469
Merge branch 'main' into add-maisi-controlnet-training
guopengf a13d351
update and move util functions to util.py
guopengf 40526e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] fde48c4
Update generative/maisi/scripts/train_controlnet.py
guopengf a04c37b
Update generative/maisi/scripts/train_controlnet.py
guopengf 3234841
Merge branch 'main' into add-maisi-controlnet-training
guopengf File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
79 changes: 79 additions & 0 deletions
79
generative/maisi/configs/config_maisi_controlnet_train.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
{ | ||
"random_seed": null, | ||
"spatial_dims": 3, | ||
"image_channels": 1, | ||
"latent_channels": 4, | ||
"diffusion_unet_def": { | ||
"_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi", | ||
"spatial_dims": "@spatial_dims", | ||
"in_channels": "@latent_channels", | ||
"out_channels": "@latent_channels", | ||
"num_channels": [ | ||
64, | ||
128, | ||
256, | ||
512 | ||
], | ||
"attention_levels": [ | ||
false, | ||
false, | ||
true, | ||
true | ||
], | ||
"num_head_channels": [ | ||
0, | ||
0, | ||
32, | ||
32 | ||
], | ||
"num_res_blocks": 2, | ||
"use_flash_attention": true, | ||
"include_top_region_index_input": true, | ||
"include_bottom_region_index_input": true, | ||
"include_spacing_input": true | ||
}, | ||
"controlnet_def": { | ||
"_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi", | ||
"spatial_dims": "@spatial_dims", | ||
"in_channels": "@latent_channels", | ||
"num_channels": [ | ||
64, | ||
128, | ||
256, | ||
512 | ||
], | ||
"attention_levels": [ | ||
false, | ||
false, | ||
true, | ||
true | ||
], | ||
"num_head_channels": [ | ||
0, | ||
0, | ||
32, | ||
32 | ||
], | ||
"num_res_blocks": 2, | ||
"use_flash_attention": true, | ||
"conditioning_embedding_in_channels": 8, | ||
"conditioning_embedding_num_channels": [8, 32, 64] | ||
}, | ||
"noise_scheduler": { | ||
"_target_": "generative.networks.schedulers.DDPMScheduler", | ||
"num_train_timesteps": 1000, | ||
"beta_start": 0.0015, | ||
"beta_end": 0.0195, | ||
"schedule": "scaled_linear_beta", | ||
"clip_sample": false | ||
}, | ||
"controlnet_train": { | ||
"batch_size": 1, | ||
"cache_rate": 0.0, | ||
"fold": 0, | ||
"lr": 1e-5, | ||
"n_epochs": 100, | ||
"weighted_loss_label": [129], | ||
"weighted_loss": 100 | ||
} | ||
} |
11 changes: 11 additions & 0 deletions
11
generative/maisi/configs/environment_maisi_controlnet_train.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"model_dir": "./models/", | ||
"output_dir": "./output", | ||
"tfevent_path": "./outputs/tfevent", | ||
"trained_autoencoder_path": "./models/autoencoder_epoch273.pt", | ||
"trained_diffusion_path": "./models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt", | ||
"trained_controlnet_path": "./models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt", | ||
"exp_name": "controlnet_kits_finetune", | ||
"data_base_dir": ["./datasets/C4KC-KiTS_subset"], | ||
"json_data_list": ["./datasets/C4KC-KiTS_subset.json"] | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import json | ||
import logging | ||
import os | ||
import sys | ||
import time | ||
from datetime import timedelta | ||
from pathlib import Path | ||
|
||
import torch | ||
import torch.distributed as dist | ||
import torch.nn.functional as F | ||
from monai.networks.utils import copy_model_state | ||
from monai.utils import RankFilter | ||
from torch.cuda.amp import GradScaler, autocast | ||
from torch.nn.parallel import DistributedDataParallel as DDP | ||
from torch.utils.tensorboard import SummaryWriter | ||
from utils import binarize_labels, define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="maisi.controlnet.training") | ||
parser.add_argument( | ||
"-e", | ||
"--environment-file", | ||
default="./configs/environment_maisi_controlnet_train.json", | ||
help="environment json file that stores environment path", | ||
) | ||
parser.add_argument( | ||
"-c", | ||
"--config-file", | ||
default="./configs/config_maisi_controlnet_train.json", | ||
help="config json file that stores hyper-parameters", | ||
) | ||
parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node") | ||
args = parser.parse_args() | ||
|
||
# Step 0: configuration | ||
logger = logging.getLogger("maisi.controlnet.training") | ||
# whether to use distributed data parallel | ||
use_ddp = args.gpus > 1 | ||
if use_ddp: | ||
rank = int(os.environ["LOCAL_RANK"]) | ||
world_size = int(os.environ["WORLD_SIZE"]) | ||
device = setup_ddp(rank, world_size) | ||
logger.addFilter(RankFilter()) | ||
else: | ||
rank = 0 | ||
world_size = 1 | ||
device = torch.device(f"cuda:{rank}") | ||
|
||
torch.cuda.set_device(device) | ||
logger.info(f"Number of GPUs: {torch.cuda.device_count()}") | ||
logger.info(f"World_size: {world_size}") | ||
|
||
env_dict = json.load(open(args.environment_file, "r")) | ||
config_dict = json.load(open(args.config_file, "r")) | ||
|
||
for k, v in env_dict.items(): | ||
setattr(args, k, v) | ||
for k, v in config_dict.items(): | ||
setattr(args, k, v) | ||
|
||
# initialize tensorboard writer | ||
if rank == 0: | ||
tensorboard_path = os.path.join(args.tfevent_path, args.exp_name) | ||
Path(tensorboard_path).mkdir(parents=True, exist_ok=True) | ||
tensorboard_writer = SummaryWriter(tensorboard_path) | ||
|
||
# Step 1: set data loader | ||
train_loader, _ = prepare_maisi_controlnet_json_dataloader( | ||
json_data_list=args.json_data_list, | ||
data_base_dir=args.data_base_dir, | ||
rank=rank, | ||
world_size=world_size, | ||
batch_size=args.controlnet_train["batch_size"], | ||
mingxin-zheng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cache_rate=args.controlnet_train["cache_rate"], | ||
fold=args.controlnet_train["fold"], | ||
) | ||
|
||
# Step 2: define diffusion model and controlnet | ||
# define diffusion Model | ||
unet = define_instance(args, "diffusion_unet_def").to(device) | ||
# load trained diffusion model | ||
if not os.path.exists(args.trained_diffusion_path): | ||
raise ValueError("Please download the trained diffusion unet checkpoint.") | ||
diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device) | ||
KumoLiu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"]) | ||
# load scale factor | ||
scale_factor = diffusion_model_ckpt["scale_factor"] | ||
logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.") | ||
logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.") | ||
# define ControlNet | ||
controlnet = define_instance(args, "controlnet_def").to(device) | ||
# copy weights from the DM to the controlnet | ||
copy_model_state(controlnet, unet.state_dict()) | ||
# load trained controlnet model if it is provided | ||
if args.trained_controlnet_path is not None: | ||
controlnet.load_state_dict( | ||
torch.load(args.trained_controlnet_path, map_location=device)["controlnet_state_dict"] | ||
) | ||
logger.info(f"load trained controlnet model from {args.trained_controlnet_path}") | ||
else: | ||
logger.info("train controlnet model from scratch.") | ||
# we freeze the parameters of the diffusion model. | ||
for p in unet.parameters(): | ||
p.requires_grad = False | ||
|
||
noise_scheduler = define_instance(args, "noise_scheduler") | ||
|
||
if use_ddp: | ||
controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True) | ||
|
||
# Step 3: training config | ||
weighted_loss = args.controlnet_train["weighted_loss"] | ||
weighted_loss_label = args.controlnet_train["weighted_loss_label"] | ||
optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=args.controlnet_train["lr"]) | ||
total_steps = (args.controlnet_train["n_epochs"] * len(train_loader.dataset)) / args.controlnet_train["batch_size"] | ||
logger.info(f"total number of training steps: {total_steps}.") | ||
|
||
lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0) | ||
|
||
# Step 4: training | ||
n_epochs = args.controlnet_train["n_epochs"] | ||
scaler = GradScaler() | ||
total_step = 0 | ||
best_loss = 1e4 | ||
|
||
if weighted_loss > 0: | ||
logger.info(f"apply weighted loss = {weighted_loss} on labels: {weighted_loss_label}") | ||
|
||
controlnet.train() | ||
unet.eval() | ||
prev_time = time.time() | ||
for epoch in range(n_epochs): | ||
epoch_loss_ = 0 | ||
for step, batch in enumerate(train_loader): | ||
# get image embedding and label mask and scale image embedding by the provided scale_factor | ||
inputs = batch["image"].to(device) * scale_factor | ||
labels = batch["label"].to(device) | ||
# get corresponding conditions | ||
top_region_index_tensor = batch["top_region_index"].to(device) | ||
bottom_region_index_tensor = batch["bottom_region_index"].to(device) | ||
spacing_tensor = batch["spacing"].to(device) | ||
|
||
optimizer.zero_grad(set_to_none=True) | ||
|
||
with autocast(enabled=True): | ||
# generate random noise | ||
noise_shape = list(inputs.shape) | ||
noise = torch.randn(noise_shape, dtype=inputs.dtype).to(device) | ||
|
||
# use binary encoding to encode segmentation mask | ||
controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float() | ||
|
||
# create timesteps | ||
timesteps = torch.randint( | ||
0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=device | ||
).long() | ||
|
||
# create noisy latent | ||
noisy_latent = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) | ||
|
||
# get controlnet output | ||
down_block_res_samples, mid_block_res_sample = controlnet( | ||
x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond | ||
) | ||
# get noise prediction from diffusion unet | ||
noise_pred = unet( | ||
x=noisy_latent, | ||
timesteps=timesteps, | ||
top_region_index_tensor=top_region_index_tensor, | ||
bottom_region_index_tensor=bottom_region_index_tensor, | ||
spacing_tensor=spacing_tensor, | ||
down_block_additional_residuals=down_block_res_samples, | ||
mid_block_additional_residual=mid_block_res_sample, | ||
) | ||
|
||
if weighted_loss > 1.0: | ||
weights = torch.ones_like(inputs).to(inputs.device) | ||
roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(inputs.device) | ||
interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest") | ||
# assign larger weights for ROI (tumor) | ||
for label in weighted_loss_label: | ||
roi[interpolate_label == label] = 1 | ||
weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = weighted_loss | ||
loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean() | ||
else: | ||
loss = F.l1_loss(noise_pred.float(), noise.float()) | ||
|
||
scaler.scale(loss).backward() | ||
scaler.step(optimizer) | ||
scaler.update() | ||
lr_scheduler.step() | ||
total_step += 1 | ||
|
||
if rank == 0: | ||
# write train loss for each batch into tensorboard | ||
tensorboard_writer.add_scalar( | ||
"train/train_controlnet_loss_iter", loss.detach().cpu().item(), total_step | ||
) | ||
batches_done = step + 1 | ||
batches_left = len(train_loader) - batches_done | ||
time_left = timedelta(seconds=batches_left * (time.time() - prev_time)) | ||
prev_time = time.time() | ||
logger.info( | ||
"\r[Epoch %d/%d] [Batch %d/%d] [LR: %.8f] [loss: %.4f] ETA: %s " | ||
% ( | ||
epoch + 1, | ||
n_epochs, | ||
step + 1, | ||
len(train_loader), | ||
lr_scheduler.get_last_lr()[0], | ||
loss.detach().cpu().item(), | ||
time_left, | ||
) | ||
) | ||
epoch_loss_ += loss.detach() | ||
|
||
epoch_loss = epoch_loss_ / (step + 1) | ||
|
||
if use_ddp: | ||
dist.barrier() | ||
dist.all_reduce(epoch_loss, op=torch.distributed.ReduceOp.AVG) | ||
|
||
if rank == 0: | ||
tensorboard_writer.add_scalar("train/train_controlnet_loss_epoch", epoch_loss.cpu().item(), total_step) | ||
# save controlnet only on master GPU (rank 0) | ||
controlnet_state_dict = controlnet.module.state_dict() if world_size > 1 else controlnet.state_dict() | ||
torch.save( | ||
{ | ||
"epoch": epoch + 1, | ||
"loss": epoch_loss, | ||
"controlnet_state_dict": controlnet_state_dict, | ||
}, | ||
f"{args.model_dir}/{args.exp_name}_current.pt", | ||
) | ||
|
||
if epoch_loss < best_loss: | ||
best_loss = epoch_loss | ||
logger.info(f"best loss -> {best_loss}.") | ||
torch.save( | ||
{ | ||
"epoch": epoch + 1, | ||
"loss": best_loss, | ||
"controlnet_state_dict": controlnet_state_dict, | ||
}, | ||
f"{args.model_dir}/{args.exp_name}_best.pt", | ||
) | ||
|
||
torch.cuda.empty_cache() | ||
if use_ddp: | ||
dist.destroy_process_group() | ||
|
||
|
||
if __name__ == "__main__": | ||
logging.basicConfig( | ||
stream=sys.stdout, | ||
level=logging.INFO, | ||
format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", | ||
datefmt="%Y-%m-%d %H:%M:%S", | ||
) | ||
main() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.