# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
import sys
import time
from datetime import timedelta
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn.functional as F
from monai.networks.utils import copy_model_state
from monai.utils import RankFilter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from utils import binarize_labels, define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp
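# `utils` is assumed to be the local helper module shipped alongside this script,
# providing the dataloader, network factory, label binarization, and DDP setup helpers.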


def main():
    parser = argparse.ArgumentParser(description="maisi.controlnet.training")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./configs/environment_maisi_controlnet_train.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./configs/config_maisi_controlnet_train.json",
        help="config json file that stores hyper-parameters",
    )
    parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
    args = parser.parse_args()
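
    # when --gpus > 1, launch this script with torchrun (or an equivalent launcher)
    # so that LOCAL_RANK and WORLD_SIZE are set in the environment, e.g.
    # (the script name below is illustrative):
    #   torchrun --nproc_per_node=2 maisi_train_controlnet.py --gpus 2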

    # Step 0: configuration
    logger = logging.getLogger("maisi.controlnet.training")
    # whether to use distributed data parallel
    use_ddp = args.gpus > 1
    if use_ddp:
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = setup_ddp(rank, world_size)
        logger.addFilter(RankFilter())
    else:
        rank = 0
        world_size = 1
        device = torch.device(f"cuda:{rank}")

    torch.cuda.set_device(device)
    logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
    logger.info(f"World size: {world_size}")

    with open(args.environment_file, "r") as f:
        env_dict = json.load(f)
    with open(args.config_file, "r") as f:
        config_dict = json.load(f)

    for k, v in env_dict.items():
        setattr(args, k, v)
    for k, v in config_dict.items():
        setattr(args, k, v)
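    # every key of the two JSON files becomes an attribute on args, e.g.
    # args.model_dir, args.tfevent_path, and the args.controlnet_train dict used below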

    # initialize tensorboard writer
    if rank == 0:
        tensorboard_path = os.path.join(args.tfevent_path, args.exp_name)
        Path(tensorboard_path).mkdir(parents=True, exist_ok=True)
        tensorboard_writer = SummaryWriter(tensorboard_path)

    # Step 1: set data loader
    train_loader, _ = prepare_maisi_controlnet_json_dataloader(
        json_data_list=args.json_data_list,
        data_base_dir=args.data_base_dir,
        rank=rank,
        world_size=world_size,
        batch_size=args.controlnet_train["batch_size"],
        cache_rate=args.controlnet_train["cache_rate"],
        fold=args.controlnet_train["fold"],
    )

    # Step 2: define diffusion model and controlnet
    # define diffusion model
    unet = define_instance(args, "diffusion_unet_def").to(device)
    # load the trained diffusion model
    if not os.path.exists(args.trained_diffusion_path):
        raise ValueError("Please download the trained diffusion unet checkpoint.")
    diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device)
    unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"])
    # load scale factor
    scale_factor = diffusion_model_ckpt["scale_factor"]
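    # the scale factor rescales the latent magnitudes before diffusion (the usual
    # latent-diffusion convention); it is applied to batch["image"] in the loop below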
    logger.info(f"Loaded trained diffusion model from {args.trained_diffusion_path}.")
    logger.info(f"Loaded scale_factor from diffusion model ckpt -> {scale_factor}.")
    # define ControlNet
    controlnet = define_instance(args, "controlnet_def").to(device)
    # copy weights from the diffusion model to the controlnet
    copy_model_state(controlnet, unet.state_dict())
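    # copy_model_state only copies parameters whose names and shapes match, so the
    # controlnet is initialized from the trained diffusion-model weights where possible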
    # load a trained controlnet checkpoint if one is provided
    if args.trained_controlnet_path is not None:
        controlnet.load_state_dict(
            torch.load(args.trained_controlnet_path, map_location=device)["controlnet_state_dict"]
        )
        logger.info(f"Loaded trained controlnet model from {args.trained_controlnet_path}.")
    else:
        logger.info("Training controlnet model from scratch.")
    # freeze the parameters of the diffusion model
    for p in unet.parameters():
        p.requires_grad = False
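    # only the controlnet receives gradients; the frozen unet is used at train time
    # purely to predict noise given the controlnet residuals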

    noise_scheduler = define_instance(args, "noise_scheduler")

    if use_ddp:
        controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)

    # Step 3: training config
    weighted_loss = args.controlnet_train["weighted_loss"]
    weighted_loss_label = args.controlnet_train["weighted_loss_label"]
    optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=args.controlnet_train["lr"])
    total_steps = (args.controlnet_train["n_epochs"] * len(train_loader.dataset)) / args.controlnet_train["batch_size"]
    logger.info(f"total number of training steps: {total_steps}.")

    lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)
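    # PolynomialLR decays the learning rate from its initial value toward zero over
    # total_iters steps, scaling by (1 - step / total_iters) ** power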

    # Step 4: training
    n_epochs = args.controlnet_train["n_epochs"]
    scaler = GradScaler()
    total_step = 0
    best_loss = 1e4

    if weighted_loss > 1.0:
        logger.info(f"apply weighted loss = {weighted_loss} on labels: {weighted_loss_label}")

    controlnet.train()
    unet.eval()
    prev_time = time.time()
    for epoch in range(n_epochs):
        epoch_loss_ = 0
        for step, batch in enumerate(train_loader):
            # get the image embedding and label mask; scale the image embedding by the provided scale_factor
            inputs = batch["image"].to(device) * scale_factor
            labels = batch["label"].to(device)
            # get the corresponding conditions
            top_region_index_tensor = batch["top_region_index"].to(device)
            bottom_region_index_tensor = batch["bottom_region_index"].to(device)
            spacing_tensor = batch["spacing"].to(device)

            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=True):
                # generate random noise
                noise_shape = list(inputs.shape)
                noise = torch.randn(noise_shape, dtype=inputs.dtype, device=device)

                # use binary encoding for the segmentation mask
                controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()

                # sample random timesteps
                timesteps = torch.randint(
                    0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=device
                ).long()

                # create the noisy latent
                noisy_latent = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
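                # (assuming a DDPM-style scheduler, add_noise computes
                # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise)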

                # get controlnet output
                down_block_res_samples, mid_block_res_sample = controlnet(
                    x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond
                )
                # get noise prediction from the diffusion unet
                noise_pred = unet(
                    x=noisy_latent,
                    timesteps=timesteps,
                    top_region_index_tensor=top_region_index_tensor,
                    bottom_region_index_tensor=bottom_region_index_tensor,
                    spacing_tensor=spacing_tensor,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                )

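                # when weighted_loss > 1, build a voxel-wise weight map: voxels whose
                # (nearest-neighbor resampled) label is in weighted_loss_label get weight
                # `weighted_loss`, all others weight 1, so the L1 loss emphasizes those
                # regions (e.g. tumors)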
                if weighted_loss > 1.0:
                    weights = torch.ones_like(inputs)
                    roi = torch.zeros([noise_shape[0], 1, *noise_shape[2:]], device=inputs.device)
                    interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest")
                    # assign larger weights to the ROI (tumor)
                    for label in weighted_loss_label:
                        roi[interpolate_label == label] = 1
                    weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = weighted_loss
                    loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean()
                else:
                    loss = F.l1_loss(noise_pred.float(), noise.float())

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            total_step += 1

            if rank == 0:
                # write the train loss for each batch into tensorboard
                tensorboard_writer.add_scalar(
                    "train/train_controlnet_loss_iter", loss.detach().cpu().item(), total_step
                )
                batches_done = step + 1
                batches_left = len(train_loader) - batches_done
                time_left = timedelta(seconds=batches_left * (time.time() - prev_time))
                prev_time = time.time()
                logger.info(
                    "\r[Epoch %d/%d] [Batch %d/%d] [LR: %.8f] [loss: %.4f] ETA: %s "
                    % (
                        epoch + 1,
                        n_epochs,
                        step + 1,
                        len(train_loader),
                        lr_scheduler.get_last_lr()[0],
                        loss.detach().cpu().item(),
                        time_left,
                    )
                )
            epoch_loss_ += loss.detach()

        epoch_loss = epoch_loss_ / (step + 1)

        if use_ddp:
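            # synchronize ranks and average the per-rank epoch losses so that logging
            # and the best-model criterion below use a global mean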
            dist.barrier()
            dist.all_reduce(epoch_loss, op=dist.ReduceOp.AVG)

        if rank == 0:
            tensorboard_writer.add_scalar("train/train_controlnet_loss_epoch", epoch_loss.cpu().item(), total_step)
            # save controlnet only on the master GPU (rank 0)
            controlnet_state_dict = controlnet.module.state_dict() if world_size > 1 else controlnet.state_dict()
            # make sure the checkpoint folder exists
            Path(args.model_dir).mkdir(parents=True, exist_ok=True)
            torch.save(
                {
                    "epoch": epoch + 1,
                    "loss": epoch_loss,
                    "controlnet_state_dict": controlnet_state_dict,
                },
                f"{args.model_dir}/{args.exp_name}_current.pt",
            )

            if epoch_loss < best_loss:
                best_loss = epoch_loss
                logger.info(f"best loss -> {best_loss}.")
                torch.save(
                    {
                        "epoch": epoch + 1,
                        "loss": best_loss,
                        "controlnet_state_dict": controlnet_state_dict,
                    },
                    f"{args.model_dir}/{args.exp_name}_best.pt",
                )

        torch.cuda.empty_cache()
    if use_ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()