Skip to content

Commit 96302a4

Browse files
pre-commit-ci[bot] and guopengf
authored and committed
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ca76573 commit 96302a4

File tree

3 files changed

+72
-65
lines changed

3 files changed

+72
-65
lines changed

generative/maisi/configs/config_maisi_controlnet_train.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,4 @@
8282
"n_epochs": 10000,
8383
"val_interval": 5
8484
}
85-
}
85+
}

generative/maisi/configs/environment_maisi_controlnet_train.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
"trained_autoencoder_path": "../models/autoencoder_epoch273.pt",
55
"trained_diffusion_path": "../models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt",
66
"trained_controlnet_path": "..models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt"
7-
}
7+
}

generative/maisi/scripts/train_controlnet.py

Lines changed: 70 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from datetime import timedelta
1818
import warnings
1919

20-
warnings.simplefilter('ignore', UserWarning)
20+
warnings.simplefilter("ignore", UserWarning)
2121

2222
import os
2323
import sys
@@ -41,53 +41,56 @@
4141
from monai.bundle import ConfigParser
4242
from utils import binarize_labels
4343

44+
4445
def setup_ddp(rank, world_size):
4546
print(f"Running DDP diffusion example on rank {rank}/world_size {world_size}.")
4647
print(f"Initing to IP {os.environ['MASTER_ADDR']}")
4748
dist.init_process_group(
4849
backend="nccl", init_method="env://", timeout=timedelta(seconds=36000), rank=rank, world_size=world_size
49-
)
50+
)
5051
dist.barrier()
5152
device = torch.device(f"cuda:{rank}")
5253
return dist, device
5354

55+
5456
def define_instance(args, instance_def_key):
5557
parser = ConfigParser(vars(args))
5658
parser.parse(True)
5759
return parser.get_parsed_content(instance_def_key, instantiate=True)
5860

61+
5962
def add_data_dir2path(list_files, data_dir, fold=None):
6063
new_list_files = copy.deepcopy(list_files)
6164
if fold is not None:
6265
new_list_files_train = []
6366
new_list_files_val = []
6467
for d in new_list_files:
6568
d["image"] = os.path.join(data_dir, d["image"])
66-
69+
6770
if "label" in d:
6871
d["label"] = os.path.join(data_dir, d["label"])
69-
72+
7073
if fold is not None:
7174
if d["fold"] == fold:
7275
new_list_files_val.append(copy.deepcopy(d))
7376
else:
7477
new_list_files_train.append(copy.deepcopy(d))
75-
78+
7679
if fold is not None:
7780
return new_list_files_train, new_list_files_val
7881
else:
7982
return new_list_files, []
8083

8184

8285
def prepare_maisi_controlnet_json_dataloader(
83-
args,
84-
json_data_list,
85-
data_base_dir,
86-
batch_size=1,
87-
fold=0,
88-
cache_rate=0.0,
89-
rank=0,
90-
world_size=1,
86+
args,
87+
json_data_list,
88+
data_base_dir,
89+
batch_size=1,
90+
fold=0,
91+
cache_rate=0.0,
92+
rank=0,
93+
world_size=1,
9194
):
9295
ddp_bool = world_size > 1
9396
if isinstance(json_data_list, list):
@@ -103,37 +106,32 @@ def prepare_maisi_controlnet_json_dataloader(
103106
else:
104107
with open(json_data_list, "r") as f:
105108
json_data = json.load(f)
106-
list_train, list_valid = add_data_dir2path(json_data['training'], data_base_dir, fold)
109+
list_train, list_valid = add_data_dir2path(json_data["training"], data_base_dir, fold)
107110

108111
common_transform = [
109112
LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True),
110113
Orientationd(keys=["label"], axcodes="RAS"),
111114
EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True),
112-
Lambdad(keys='top_region_index', func=lambda x: torch.FloatTensor(x)),
113-
Lambdad(keys='bottom_region_index', func=lambda x: torch.FloatTensor(x)),
114-
Lambdad(keys='spacing', func=lambda x: torch.FloatTensor(x)),
115-
Lambdad(keys='top_region_index', func=lambda x: x * 1e2),
116-
Lambdad(keys='bottom_region_index', func=lambda x: x * 1e2),
117-
Lambdad(keys='spacing', func=lambda x: x * 1e2),
115+
Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)),
116+
Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)),
117+
Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
118+
Lambdad(keys="top_region_index", func=lambda x: x * 1e2),
119+
Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2),
120+
Lambdad(keys="spacing", func=lambda x: x * 1e2),
118121
]
119122
train_transforms, val_transforms = Compose(common_transform), Compose(common_transform)
120123

121124
train_loader = None
122-
125+
123126
if ddp_bool:
124127
list_train = partition_dataset(
125128
data=list_train,
126129
shuffle=True,
127130
num_partitions=world_size,
128131
even_divisible=True,
129132
)[rank]
130-
train_ds = CacheDataset(
131-
data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8
132-
)
133-
train_loader = DataLoader(
134-
train_ds, batch_size=batch_size, shuffle=True,
135-
num_workers=8, pin_memory=True
136-
)
133+
train_ds = CacheDataset(data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8)
134+
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
137135
if ddp_bool:
138136
list_valid = partition_dataset(
139137
data=list_valid,
@@ -142,12 +140,14 @@ def prepare_maisi_controlnet_json_dataloader(
142140
even_divisible=False,
143141
)[rank]
144142
val_ds = CacheDataset(
145-
data=list_valid, transform=val_transforms, cache_rate=cache_rate, num_workers=8,
146-
)
147-
val_loader = DataLoader(
148-
val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False
143+
data=list_valid,
144+
transform=val_transforms,
145+
cache_rate=cache_rate,
146+
num_workers=8,
149147
)
150-
return train_loader, val_loader
148+
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False)
149+
return train_loader, val_loader
150+
151151

152152
def main():
153153
parser = argparse.ArgumentParser(description="PyTorch VAE-GAN training")
@@ -164,7 +164,14 @@ def main():
164164
help="config json file that stores hyper-parameters",
165165
)
166166
parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
167-
parser.add_argument("-w", "--weighted_loss_label", nargs='+', default=[], action="store_true", help="list of lables that use weighted loss")
167+
parser.add_argument(
168+
"-w",
169+
"--weighted_loss_label",
170+
nargs="+",
171+
default=[],
172+
action="store_true",
173+
help="list of lables that use weighted loss",
174+
)
168175
parser.add_argument("-l", "--weighted_loss", default=100, type=int, help="loss weight loss for ROI labels")
169176
args = parser.parse_args()
170177

@@ -189,7 +196,7 @@ def main():
189196
setattr(args, k, v)
190197
for k, v in config_dict.items():
191198
setattr(args, k, v)
192-
199+
193200
# initialize tensorboard writer
194201
if rank == 0:
195202
Path(args.tfevent_path).mkdir(parents=True, exist_ok=True)
@@ -199,13 +206,13 @@ def main():
199206
# Step 1: set data loader
200207
train_loader, val_loader = prepare_maisi_controlnet_json_dataloader(
201208
args,
202-
json_data_list = args.json_data_list,
203-
data_base_dir = args.data_base_dir,
209+
json_data_list=args.json_data_list,
210+
data_base_dir=args.data_base_dir,
204211
rank=rank,
205212
world_size=world_size,
206213
batch_size=args.controlnet_train["batch_size"],
207214
cache_rate=args.controlnet_train["cache_rate"],
208-
fold=args.controlnet_train["fold"]
215+
fold=args.controlnet_train["fold"],
209216
)
210217

211218
# Step 2: define diffusion model and controlnet
@@ -235,16 +242,16 @@ def main():
235242
noise_scheduler = define_instance(args, "noise_scheduler")
236243

237244
if ddp_bool:
238-
controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)
245+
controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)
239246

240247
# Step 3: training config
241248
optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=args.controlnet_train["lr"])
242249
total_steps = (args.controlnet_train["n_epochs"] * len(train_loader.dataset)) / args.controlnet_train["batch_size"]
243-
if rank ==0:
250+
if rank == 0:
244251
print(f"total number of training steps: {total_steps}.")
245252

246253
lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)
247-
254+
248255
# Step 4: training
249256
n_epochs = args.controlnet_train["n_epochs"]
250257
scaler = GradScaler()
@@ -259,52 +266,54 @@ def main():
259266
epoch_loss_ = 0
260267

261268
for step, batch in enumerate(train_loader):
262-
# get image embedding and label mask
269+
# get image embedding and label mask
263270
inputs = batch["image"].to(device)
264271
labels = batch["label"].to(device)
265272
# get coresponding condtions
266-
top_region_index_tensor = batch['top_region_index'].to(device)
267-
bottom_region_index_tensor = batch['bottom_region_index'].to(device)
268-
spacing_tensor = batch['spacing'].to(device)
269-
273+
top_region_index_tensor = batch["top_region_index"].to(device)
274+
bottom_region_index_tensor = batch["bottom_region_index"].to(device)
275+
spacing_tensor = batch["spacing"].to(device)
276+
270277
optimizer.zero_grad(set_to_none=True)
271278

272279
with autocast(enabled=True):
273280
# generate random noise
274281
noise_shape = list(inputs.shape)
275282
noise = torch.randn(noise_shape, dtype=inputs.dtype).to(inputs.device)
276-
283+
277284
# use binary encoding to encode segmentation mask
278285
controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()
279-
286+
280287
# create timesteps
281288
timesteps = torch.randint(
282-
0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=inputs.device
283-
).long()
284-
289+
0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=inputs.device
290+
).long()
291+
285292
# create noisy latent
286293
noisy_latent = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
287-
294+
288295
# get controlnet output
289296
down_block_res_samples, mid_block_res_sample = controlnet(
290297
x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond
291298
)
292299
# get noise prediction from diffusion unet
293-
noise_pred = unet(x=noisy_latent,
294-
timesteps=timesteps,
295-
top_region_index_tensor=top_region_index_tensor,
296-
bottom_region_index_tensor=bottom_region_index_tensor,
297-
spacing_tensor=spacing_tensor,
298-
down_block_additional_residuals=down_block_res_samples,
299-
mid_block_additional_residual=mid_block_res_sample)
300+
noise_pred = unet(
301+
x=noisy_latent,
302+
timesteps=timesteps,
303+
top_region_index_tensor=top_region_index_tensor,
304+
bottom_region_index_tensor=bottom_region_index_tensor,
305+
spacing_tensor=spacing_tensor,
306+
down_block_additional_residuals=down_block_res_samples,
307+
mid_block_additional_residual=mid_block_res_sample,
308+
)
300309

301310
if args.weighted_loss > 1.0:
302311
weights = torch.ones_like(inputs).to(inputs.device)
303312
roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(inputs.device)
304313
interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest")
305314
# assign larger weights for ROI (tumor)
306315
for label in args.weighted_loss_label:
307-
roi[interpolate_label==label] = 1
316+
roi[interpolate_label == label] = 1
308317
weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = args.weighted_loss
309318
loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean()
310319
else:
@@ -330,8 +339,7 @@ def main():
330339
n_epochs,
331340
step,
332341
len(train_loader),
333-
lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else optimizer.param_groups[0][
334-
'lr'],
342+
lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else optimizer.param_groups[0]["lr"],
335343
loss.detach().cpu().item(),
336344
time_left,
337345
)
@@ -348,7 +356,6 @@ def main():
348356
tensorboard_writer.add_scalar("train/train_diffusion_loss_epoch", epoch_loss.cpu().item(), total_step)
349357

350358
torch.cuda.empty_cache()
351-
352359

353360

354361
if __name__ == "__main__":

0 commit comments

Comments (0)