Commit d0de28e

Reduce scale factor and val loss across ranks for DDP (#1461)
Fixes #1458.

### Description

Reduce the scale factor and validation loss across ranks for DDP.

### Checks

- [ ] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, check the contents and remove any sensitive info such as user names and private keys.
- [ ] Ensure (1) hyperlinks and markdown anchors are working, (2) relative paths are used for tutorial repo files, and (3) figures and graphs are placed in the `./figure` folder.
- [ ] Notebook runs automatically: `./runner.sh -t <path to .ipynb file>`

Signed-off-by: Can-Zhao <[email protected]>
1 parent 6d4182f commit d0de28e

File tree

2 files changed: +24 −2 lines


generative/2d_ldm/train_diffusion.py

12 additions, 1 deletion

```diff
@@ -126,6 +126,11 @@ def main():
         )
         print(f"Scaling factor set to {1/torch.std(z)}")
         scale_factor = 1 / torch.std(z)
+        print(f"Rank {rank}: local scale_factor: {scale_factor}")
+        if ddp_bool:
+            dist.barrier()
+            dist.all_reduce(scale_factor, op=torch.distributed.ReduceOp.AVG)
+        print(f"Rank {rank}: final scale_factor -> {scale_factor}")

     # Define Diffusion Model
     unet = define_instance(args, "diffusion_def").to(device)
@@ -261,9 +266,15 @@ def main():
                     timesteps=timesteps,
                 )
                 val_loss = F.mse_loss(noise_pred.float(), noise.float())
-                val_recon_epoch_loss += val_loss.item()
+                val_recon_epoch_loss += val_loss
         val_recon_epoch_loss = val_recon_epoch_loss / (step + 1)

+        if ddp_bool:
+            dist.barrier()
+            dist.all_reduce(val_recon_epoch_loss, op=torch.distributed.ReduceOp.AVG)
+
+        val_recon_epoch_loss = val_recon_epoch_loss.item()
+
         # write val loss and save best model
         if rank == 0:
             tensorboard_writer.add_scalar("val_diffusion_loss", val_recon_epoch_loss, epoch + 1)
```
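Because each DDP rank computes `1 / torch.std(z)` from only its own shard of the data, the local scale factors can differ; the `all_reduce` with `ReduceOp.AVG` leaves every rank holding the same mean value. A minimal single-process sketch of that semantics (the helper `simulate_all_reduce_avg` is hypothetical, not part of the patched code, and stands in for the collective without needing a process group):

```python
import torch

def simulate_all_reduce_avg(per_rank_tensors):
    """Simulate what dist.all_reduce(t, op=ReduceOp.AVG) leaves on each rank:
    every rank ends up holding the element-wise mean across ranks."""
    avg = torch.stack(per_rank_tensors).mean(dim=0)
    # after the collective, all ranks hold identical copies of the average
    return [avg.clone() for _ in per_rank_tensors]

# e.g. two ranks derive different local scale factors from their data shards
local_scale_factors = [torch.tensor(0.8), torch.tensor(1.2)]
reduced = simulate_all_reduce_avg(local_scale_factors)
# every rank now holds the same averaged scale factor, 1.0
```

With a shared scale factor, all ranks train the diffusion model in the same latent scaling, which is the point of this change.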

generative/3d_ldm/train_diffusion.py

12 additions, 1 deletion

```diff
@@ -127,6 +127,11 @@ def main():
         )
         print(f"Scaling factor set to {1/torch.std(z)}")
         scale_factor = 1 / torch.std(z)
+        print(f"Rank {rank}: local scale_factor: {scale_factor}")
+        if ddp_bool:
+            dist.barrier()
+            dist.all_reduce(scale_factor, op=torch.distributed.ReduceOp.AVG)
+        print(f"Rank {rank}: final scale_factor -> {scale_factor}")

     # Define Diffusion Model
     unet = define_instance(args, "diffusion_def").to(device)
@@ -243,9 +248,15 @@ def main():
                     timesteps=timesteps,
                 )
                 val_loss = F.mse_loss(noise_pred.float(), noise.float())
-                val_recon_epoch_loss += val_loss.item()
+                val_recon_epoch_loss += val_loss
         val_recon_epoch_loss = val_recon_epoch_loss / (step + 1)

+        if ddp_bool:
+            dist.barrier()
+            dist.all_reduce(val_recon_epoch_loss, op=torch.distributed.ReduceOp.AVG)
+
+        val_recon_epoch_loss = val_recon_epoch_loss.item()
+
         # write val loss and save best model
         if rank == 0:
             tensorboard_writer.add_scalar("val_diffusion_loss", val_recon_epoch_loss, epoch)
```
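Note why the diffs also change `val_loss.item()` to `val_loss`: `all_reduce` operates in place on tensors, so the accumulated loss must stay a tensor until after the reduction, and `.item()` is called only once at the end. A runnable single-process sketch of that pattern (the helper name `average_across_ranks` is illustrative; `ReduceOp.AVG` is generally available only on the NCCL backend, so this gloo sketch sums and divides by world size, which is equivalent):

```python
import os
import torch
import torch.distributed as dist

def average_across_ranks(t: torch.Tensor) -> torch.Tensor:
    """Average a tensor across all ranks in place (AVG via SUM / world_size)."""
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    t /= dist.get_world_size()
    return t

# single-process group for demonstration; real DDP runs launch one per GPU
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

val_recon_epoch_loss = torch.tensor(0.5)        # stays a tensor for all_reduce
avg = average_across_ranks(val_recon_epoch_loss)
val_recon_epoch_loss = avg.item()               # convert to float only now

dist.destroy_process_group()
```

With one rank the value is unchanged; with N ranks each holding a different per-rank epoch loss, every rank (including rank 0, which logs to TensorBoard) would end up with the global mean rather than its local value.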
