# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import json
import logging
import os
import sys
import time
import warnings
from datetime import timedelta
from pathlib import Path

# suppress UserWarnings emitted while importing the heavyweight dependencies below
warnings.simplefilter("ignore", UserWarning)

import torch
import torch.distributed as dist
import torch.nn.functional as F
from monai.bundle import ConfigParser
from monai.data import CacheDataset, DataLoader, partition_dataset
from monai.transforms import (
    Compose,
    EnsureTyped,
    Lambdad,
    LoadImaged,
    Orientationd,
)
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from utils import binarize_labels

def setup_ddp(rank, world_size):
    print(f"Running DDP diffusion example on rank {rank} of world size {world_size}.")
    print(f"Initializing process group with MASTER_ADDR {os.environ['MASTER_ADDR']}")
    dist.init_process_group(
        backend="nccl", init_method="env://", timeout=timedelta(seconds=36000), rank=rank, world_size=world_size
    )
    dist.barrier()
    device = torch.device(f"cuda:{rank}")
    return dist, device
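
# setup_ddp assumes a torchrun-style launcher that sets MASTER_ADDR, MASTER_PORT,
# RANK, WORLD_SIZE, and LOCAL_RANK; init_method="env://" tells PyTorch to read the
# rendezvous information from those environment variables.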

def define_instance(args, instance_def_key):
    parser = ConfigParser(vars(args))
    parser.parse(True)
    return parser.get_parsed_content(instance_def_key, instantiate=True)
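
# A minimal sketch of the kind of entry define_instance resolves via MONAI's bundle
# ConfigParser; the actual keys, class paths, and arguments live in the user-provided
# config JSON (the names below are illustrative assumptions, not the shipped config):
#   "controlnet_def": {
#       "_target_": "some.module.ControlNet",
#       "spatial_dims": 3,
#       "in_channels": 4
#   }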

def add_data_dir2path(list_files, data_dir, fold=None):
    new_list_files = copy.deepcopy(list_files)
    new_list_files_train = []
    new_list_files_val = []
    for d in new_list_files:
        d["image"] = os.path.join(data_dir, d["image"])

        if "label" in d:
            d["label"] = os.path.join(data_dir, d["label"])

        if fold is not None:
            if d["fold"] == fold:
                new_list_files_val.append(copy.deepcopy(d))
            else:
                new_list_files_train.append(copy.deepcopy(d))

    if fold is not None:
        return new_list_files_train, new_list_files_val
    else:
        return new_list_files, []
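
# Each datalist entry is expected to look roughly like the following (illustrative
# field values; the real entries come from the MAISI data preparation step):
#   {"image": "case_0001_embedding.nii.gz", "label": "case_0001_mask.nii.gz", "fold": 0,
#    "top_region_index": [...], "bottom_region_index": [...], "spacing": [...]}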


def prepare_maisi_controlnet_json_dataloader(
    args,
    json_data_list,
    data_base_dir,
    batch_size=1,
    fold=0,
    cache_rate=0.0,
    rank=0,
    world_size=1,
):
    ddp_bool = world_size > 1
    if isinstance(json_data_list, list):
        assert isinstance(data_base_dir, list)
        list_train = []
        list_valid = []
        for data_list, base_dir in zip(json_data_list, data_base_dir):
            with open(data_list, "r") as f:
                json_data = json.load(f)
            train, val = add_data_dir2path(json_data, base_dir, fold)
            list_train += train
            list_valid += val
    else:
        with open(json_data_list, "r") as f:
            json_data = json.load(f)
        list_train, list_valid = add_data_dir2path(json_data["training"], data_base_dir, fold)

    common_transform = [
        LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True),
        Orientationd(keys=["label"], axcodes="RAS"),
        EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True),
        Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)),
        Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)),
        Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
        Lambdad(keys="top_region_index", func=lambda x: x * 1e2),
        Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2),
        Lambdad(keys="spacing", func=lambda x: x * 1e2),
    ]
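    # The region indices and spacing are scaled by 1e2, presumably to bring these
    # small conditioning values into a numeric range the network handles well; the
    # factor is taken from the source as-is.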
    train_transforms, val_transforms = Compose(common_transform), Compose(common_transform)

    if ddp_bool:
        list_train = partition_dataset(
            data=list_train,
            shuffle=True,
            num_partitions=world_size,
            even_divisible=True,
        )[rank]
    train_ds = CacheDataset(data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

    if ddp_bool:
        list_valid = partition_dataset(
            data=list_valid,
            shuffle=True,
            num_partitions=world_size,
            even_divisible=False,
        )[rank]
    val_ds = CacheDataset(data=list_valid, transform=val_transforms, cache_rate=cache_rate, num_workers=8)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False)
    return train_loader, val_loader


def main():
    parser = argparse.ArgumentParser(description="MAISI ControlNet training")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./config/environment_maisi_controlnet_train.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./config/config_maisi_controlnet_train.json",
        help="config json file that stores hyper-parameters",
    )
    parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
    parser.add_argument(
        "-w",
        "--weighted_loss_label",
        nargs="+",
        type=int,
        default=[],
        help="list of labels that use weighted loss",
    )
    parser.add_argument("-l", "--weighted_loss", default=100, type=int, help="loss weight for ROI labels")
    args = parser.parse_args()

    # Step 0: configuration
    ddp_bool = args.gpus > 1  # whether to use distributed data parallel
    if ddp_bool:
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        dist, device = setup_ddp(rank, world_size)
    else:
        rank = 0
        world_size = 1
        device = torch.device("cuda:0")

    torch.cuda.set_device(device)
    print(f"Using device {device}")

    with open(args.environment_file, "r") as f:
        env_dict = json.load(f)
    with open(args.config_file, "r") as f:
        config_dict = json.load(f)

    for k, v in env_dict.items():
        setattr(args, k, v)
    for k, v in config_dict.items():
        setattr(args, k, v)

    # initialize tensorboard writer
    if rank == 0:
        Path(args.tfevent_path).mkdir(parents=True, exist_ok=True)
        tensorboard_path = os.path.join(args.tfevent_path, "controlnet_tfevent")
        tensorboard_writer = SummaryWriter(tensorboard_path)

    # Step 1: set data loader
    train_loader, val_loader = prepare_maisi_controlnet_json_dataloader(
        args,
        json_data_list=args.json_data_list,
        data_base_dir=args.data_base_dir,
        rank=rank,
        world_size=world_size,
        batch_size=args.controlnet_train["batch_size"],
        cache_rate=args.controlnet_train["cache_rate"],
        fold=args.controlnet_train["fold"],
    )

    # Step 2: define diffusion model and controlnet

    # define diffusion model
    unet = define_instance(args, "diffusion_unet_def").to(device)
    # load trained diffusion model
    map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
    unet.load_state_dict(torch.load(args.trained_diffusion_path, map_location=map_location))
    if rank == 0:
        print(f"Loaded trained diffusion model from {args.trained_diffusion_path}.")
    # define ControlNet
    controlnet = define_instance(args, "controlnet_def").to(device)
    # initialize the ControlNet with weights copied from the trained diffusion model
    controlnet.load_state_dict(unet.state_dict(), strict=False)
    if args.trained_controlnet_path is not None:
        controlnet.load_state_dict(torch.load(args.trained_controlnet_path, map_location=map_location))
        if rank == 0:
            print(f"Loaded trained controlnet model from {args.trained_controlnet_path}.")
    else:
        if rank == 0:
            print("Training controlnet model from scratch.")
    # freeze the parameters of the diffusion model; only the ControlNet is trained
    for p in unet.parameters():
        p.requires_grad = False

    noise_scheduler = define_instance(args, "noise_scheduler")

    if ddp_bool:
        controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)

    # Step 3: training config
    optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=args.controlnet_train["lr"])
    total_steps = (args.controlnet_train["n_epochs"] * len(train_loader.dataset)) / args.controlnet_train["batch_size"]
    if rank == 0:
        print(f"total number of training steps: {total_steps}.")

    lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)
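    # PolynomialLR decays the learning rate from its initial value toward zero as
    #   lr_t = lr_0 * (1 - t / total_iters) ** power,
    # so with power=2.0 the decay is quadratic over the course of training.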

    # Step 4: training
    n_epochs = args.controlnet_train["n_epochs"]
    scaler = GradScaler()
    total_step = 0

    if args.weighted_loss > 1.0 and rank == 0:
        print(f"apply weighted loss = {args.weighted_loss} on labels: {args.weighted_loss_label}")

    prev_time = time.time()
    for epoch in range(n_epochs):
        # train only the ControlNet; the diffusion unet stays frozen in eval mode
        unet.eval()
        controlnet.train()
        epoch_loss_ = 0

        for step, batch in enumerate(train_loader):
            # get image embedding and label mask
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)
            # get corresponding conditions
            top_region_index_tensor = batch["top_region_index"].to(device)
            bottom_region_index_tensor = batch["bottom_region_index"].to(device)
            spacing_tensor = batch["spacing"].to(device)

            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=True):
                # generate random noise
                noise_shape = list(inputs.shape)
                noise = torch.randn(noise_shape, dtype=inputs.dtype).to(inputs.device)

                # use binary encoding to encode segmentation mask
                controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()
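                # binarize_labels (from the local utils module) is assumed to map each
                # integer label to a multi-channel binary (bit) encoding, so the
                # condition tensor has a fixed channel count regardless of how many
                # label values appear in the mask.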

                # create timesteps
                timesteps = torch.randint(
                    0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=inputs.device
                ).long()

                # create noisy latent
                noisy_latent = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
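                # for a DDPM-style scheduler, add_noise implements the forward process
                #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
                # where alpha_bar_t is the cumulative product of (1 - beta_s) for s <= t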

                # get controlnet output
                down_block_res_samples, mid_block_res_sample = controlnet(
                    x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond
                )
                # get noise prediction from diffusion unet
                noise_pred = unet(
                    x=noisy_latent,
                    timesteps=timesteps,
                    top_region_index_tensor=top_region_index_tensor,
                    bottom_region_index_tensor=bottom_region_index_tensor,
                    spacing_tensor=spacing_tensor,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                )

                if args.weighted_loss > 1.0:
                    weights = torch.ones_like(inputs).to(inputs.device)
                    roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(inputs.device)
                    interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest")
                    # assign larger weights to the ROI (e.g., tumor) voxels
                    for label in args.weighted_loss_label:
                        roi[interpolate_label == label] = 1
                    weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = args.weighted_loss
                    loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean()
                else:
                    loss = F.l1_loss(noise_pred.float(), noise.float())

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()

            if rank == 0:
                # write train loss for each batch into tensorboard
                total_step += 1
                tensorboard_writer.add_scalar(
                    "train/train_diffusion_loss_iter", loss.detach().cpu().item(), total_step
                )
                batches_done = step + 1
                batches_left = len(train_loader) - batches_done
                time_left = timedelta(seconds=batches_left * (time.time() - prev_time))
                prev_time = time.time()
                sys.stdout.write(
                    "\r[Epoch %d/%d] [Batch %d/%d] [LR: %f] [loss: %.4f] ETA: %s "
                    % (
                        epoch + 1,
                        n_epochs,
                        step + 1,
                        len(train_loader),
                        lr_scheduler.get_last_lr()[0],
                        loss.detach().cpu().item(),
                        time_left,
                    )
                )
            epoch_loss_ += loss.detach()

        epoch_loss = epoch_loss_ / (step + 1)

        if ddp_bool:
            dist.barrier()
            dist.all_reduce(epoch_loss, op=torch.distributed.ReduceOp.AVG)

        if rank == 0:
            tensorboard_writer.add_scalar("train/train_diffusion_loss_epoch", epoch_loss.cpu().item(), total_step)

        torch.cuda.empty_cache()


if __name__ == "__main__":
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()
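
# Illustrative launch commands (the script name is an assumption; adjust to your setup):
#   single GPU:
#     python train_controlnet.py -g 1
#   multi-GPU, single node (torchrun sets the env vars that setup_ddp reads):
#     torchrun --nproc_per_node=8 train_controlnet.py -g 8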