Skip to content

Commit 1cd41fc

Browse files
Merge branch 'main' into add-metrics-notebook
2 parents 769e1bc + 4488f47 commit 1cd41fc

17 files changed

+2093
-67
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,8 @@ deployment/ray/mednist_classifier_start.py
152152
3d_segmentation/out
153153
*.nsys-rep
154154
auto3dseg/notebooks/datalist.json
155+
156+
*.jpeg
157+
*.png
158+
*.np*
159+
*.pt

generation/2d_ldm/2d_ldm_tutorial.ipynb

Lines changed: 1012 additions & 0 deletions
Large diffs are not rendered by default.

generation/2d_ldm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# 2D Latent Diffusion Example
2-
This folder contains an example for training and validating a 2D Latent Diffusion Model on Brats axial slices. The example includes support for multi-GPU training with distributed data parallelism.
2+
This folder contains examples for training and validating a 2D Latent Diffusion Model on MedNIST and Brats axial slice data. The notebook [2d_ldm_tutorial.ipynb](./2d_ldm_tutorial.ipynb) demonstrates these concepts with the MedNIST dataset. The larger example given in Python files and explained here uses Brats and includes support for multi-GPU training with distributed data parallelism.
33

44
The workflow of the Latent Diffusion Model is depicted in the figure below. It begins by training an autoencoder in pixel space to encode images into latent features. Following that, it trains a diffusion model in the latent space to denoise the noisy latent features. During inference, it first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Finally, it decodes the denoised latent features into images using the trained autoencoder.
55
<p align="center">

generation/2d_ldm/config/config_train_16g.json

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
"latent_channels": 1,
66
"sample_axis": 2,
77
"autoencoder_def": {
8-
"_target_": "generative.networks.nets.AutoencoderKL",
8+
"_target_": "monai.networks.nets.AutoencoderKL",
99
"spatial_dims": "@spatial_dims",
1010
"in_channels": "$@image_channels",
1111
"out_channels": "@image_channels",
1212
"latent_channels": "@latent_channels",
13-
"num_channels": [
13+
"channels": [
1414
64,
1515
128,
1616
256
@@ -33,15 +33,15 @@
3333
"perceptual_weight": 1.0,
3434
"kl_weight": 1e-6,
3535
"recon_loss": "l1",
36-
"n_epochs": 1000,
36+
"max_epochs": 1000,
3737
"val_interval": 1
3838
},
3939
"diffusion_def": {
40-
"_target_": "generative.networks.nets.DiffusionModelUNet",
40+
"_target_": "monai.networks.nets.DiffusionModelUNet",
4141
"spatial_dims": "@spatial_dims",
4242
"in_channels": "@latent_channels",
4343
"out_channels": "@latent_channels",
44-
"num_channels":[32, 64, 128, 256],
44+
"channels":[32, 64, 128, 256],
4545
"attention_levels":[false, true, true, true],
4646
"num_head_channels":[0, 32, 32, 32],
4747
"num_res_blocks": 2
@@ -50,7 +50,7 @@
5050
"batch_size": 50,
5151
"patch_size": [256,256],
5252
"lr": 1e-5,
53-
"n_epochs": 1500,
53+
"max_epochs": 1500,
5454
"val_interval": 2,
5555
"lr_scheduler_milestones": [1000]
5656
},

generation/2d_ldm/config/config_train_32g.json

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
"latent_channels": 1,
66
"sample_axis": 2,
77
"autoencoder_def": {
8-
"_target_": "generative.networks.nets.AutoencoderKL",
8+
"_target_": "monai.networks.nets.AutoencoderKL",
99
"spatial_dims": "@spatial_dims",
1010
"in_channels": "$@image_channels",
1111
"out_channels": "@image_channels",
1212
"latent_channels": "@latent_channels",
13-
"num_channels": [
13+
"channels": [
1414
64,
1515
128,
1616
256
@@ -33,15 +33,15 @@
3333
"perceptual_weight": 1.0,
3434
"kl_weight": 1e-6,
3535
"recon_loss": "l1",
36-
"n_epochs": 1000,
36+
"max_epochs": 1000,
3737
"val_interval": 1
3838
},
3939
"diffusion_def": {
40-
"_target_": "generative.networks.nets.DiffusionModelUNet",
40+
"_target_": "monai.networks.nets.DiffusionModelUNet",
4141
"spatial_dims": "@spatial_dims",
4242
"in_channels": "@latent_channels",
4343
"out_channels": "@latent_channels",
44-
"num_channels":[32, 64, 128, 256],
44+
"channels":[32, 64, 128, 256],
4545
"attention_levels":[false, true, true, true],
4646
"num_head_channels":[0, 32, 32, 32],
4747
"num_res_blocks": 2
@@ -50,7 +50,7 @@
5050
"batch_size": 80,
5151
"patch_size": [256,256],
5252
"lr": 1e-5,
53-
"n_epochs": 1500,
53+
"max_epochs": 1500,
5454
"val_interval": 2,
5555
"lr_scheduler_milestones": [1000]
5656
},

generation/2d_ldm/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
import numpy as np
2121
import torch
22-
from generative.inferers import LatentDiffusionInferer
23-
from generative.networks.schedulers import DDPMScheduler
22+
from monai.inferers import LatentDiffusionInferer
23+
from monai.networks.schedulers import DDPMScheduler
2424
from monai.config import print_config
2525
from monai.utils import set_determinism
2626
from PIL import Image

generation/2d_ldm/train_autoencoder.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from pathlib import Path
1818

1919
import torch
20-
from generative.losses import PatchAdversarialLoss, PerceptualLoss
21-
from generative.networks.nets import PatchDiscriminator
20+
from monai.losses import PatchAdversarialLoss, PerceptualLoss
21+
from monai.networks.nets import PatchDiscriminator
2222
from monai.config import print_config
2323
from monai.utils import set_determinism
2424
from torch.nn import L1Loss, MSELoss
@@ -75,7 +75,7 @@ def main():
7575
set_determinism(42)
7676

7777
# Step 1: set data loader
78-
size_divisible = 2 ** (len(args.autoencoder_def["num_channels"]) - 1)
78+
size_divisible = 2 ** (len(args.autoencoder_def["channels"]) - 1)
7979
train_loader, val_loader = prepare_brats2d_dataloader(
8080
args,
8181
args.autoencoder_train["batch_size"],
@@ -95,7 +95,7 @@ def main():
9595
discriminator = PatchDiscriminator(
9696
spatial_dims=args.spatial_dims,
9797
num_layers_d=3,
98-
num_channels=32,
98+
channels=32,
9999
in_channels=1,
100100
out_channels=1,
101101
norm=discriminator_norm,
@@ -172,12 +172,12 @@ def main():
172172

173173
# Step 4: training
174174
autoencoder_warm_up_n_epochs = 5
175-
n_epochs = args.autoencoder_train["n_epochs"]
175+
max_epochs = args.autoencoder_train["max_epochs"]
176176
val_interval = args.autoencoder_train["val_interval"]
177177
best_val_recon_epoch_loss = 100.0
178178
total_step = 0
179179

180-
for epoch in range(n_epochs):
180+
for epoch in range(max_epochs):
181181
# train
182182
autoencoder.train()
183183
discriminator.train()

generation/2d_ldm/train_diffusion.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818

1919
import torch
2020
import torch.nn.functional as F
21-
from generative.inferers import LatentDiffusionInferer
22-
from generative.networks.schedulers import DDPMScheduler
21+
from monai.inferers import LatentDiffusionInferer
22+
from monai.networks.schedulers import DDPMScheduler
2323
from monai.config import print_config
2424
from monai.utils import first, set_determinism
25-
from torch.cuda.amp import GradScaler, autocast
25+
from torch.amp import GradScaler, autocast
2626
from torch.nn.parallel import DistributedDataParallel as DDP
2727
from torch.utils.tensorboard import SummaryWriter
2828
from utils import define_instance, prepare_brats2d_dataloader, setup_ddp
@@ -75,7 +75,7 @@ def main():
7575
set_determinism(42)
7676

7777
# Step 1: set data loader
78-
size_divisible = 2 ** (len(args.autoencoder_def["num_channels"]) + len(args.diffusion_def["num_channels"]) - 2)
78+
size_divisible = 2 ** (len(args.autoencoder_def["channels"]) + len(args.diffusion_def["channels"]) - 2)
7979
train_loader, val_loader = prepare_brats2d_dataloader(
8080
args,
8181
args.diffusion_train["batch_size"],
@@ -114,7 +114,7 @@ def main():
114114
# and the results will not differ from those obtained when it is not used._
115115

116116
with torch.no_grad():
117-
with autocast(enabled=True):
117+
with autocast("cuda", enabled=True):
118118
check_data = first(train_loader)
119119
z = autoencoder.encode_stage_2_inputs(check_data["image"].to(device))
120120
if rank == 0:
@@ -179,14 +179,14 @@ def main():
179179
)
180180

181181
# Step 4: training
182-
n_epochs = args.diffusion_train["n_epochs"]
182+
max_epochs = args.diffusion_train["max_epochs"]
183183
val_interval = args.diffusion_train["val_interval"]
184184
autoencoder.eval()
185185
scaler = GradScaler()
186186
total_step = 0
187187
best_val_recon_epoch_loss = 100.0
188188

189-
for epoch in range(start_epoch, n_epochs):
189+
for epoch in range(start_epoch, max_epochs):
190190
unet.train()
191191
lr_scheduler.step()
192192
if ddp_bool:
@@ -196,7 +196,7 @@ def main():
196196
images = batch["image"].to(device)
197197
optimizer_diff.zero_grad(set_to_none=True)
198198

199-
with autocast(enabled=True):
199+
with autocast("cuda", enabled=True):
200200
# Generate random noise
201201
noise_shape = [images.shape[0]] + list(z.shape[1:])
202202
noise = torch.randn(noise_shape, dtype=images.dtype).to(device)
@@ -239,7 +239,7 @@ def main():
239239
unet.eval()
240240
val_recon_epoch_loss = 0
241241
with torch.no_grad():
242-
with autocast(enabled=True):
242+
with autocast("cuda", enabled=True):
243243
# compute val loss
244244
for step, batch in enumerate(val_loader):
245245
images = batch["image"].to(device)

0 commit comments

Comments
 (0)