Skip to content

Commit c7580d2

Browse files
fix part of type error
Signed-off-by: Yiheng Wang <[email protected]>
1 parent 7b23d85 commit c7580d2

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

generation/image_to_image_translation/tutorial_segmentation_with_ddpm.ipynb

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@
727727
" model.train()\n",
728728
" epoch_loss = 0\n",
729729
"\n",
730-
" for step, data in enumerate(train_loader):\n",
730+
" for _step, data in enumerate(train_loader):\n",
731731
" images = data[\"image\"].to(device)\n",
732732
" seg = data[\"label\"].to(device) # this is the ground truth segmentation\n",
733733
" optimizer.zero_grad(set_to_none=True)\n",
@@ -741,7 +741,8 @@
741741
" ) # we only add noise to the segmentation mask\n",
742742
" combined = torch.cat(\n",
743743
" (images, noisy_seg), dim=1\n",
744-
" ) # we concatenate the brain MR image with the noisy segmenatation mask, to condition the generation process\n",
744+
" # we concatenate the brain MR image with the noisy segmenatation mask, to condition the generation process\n",
745+
" )\n",
745746
" prediction = model(x=combined, timesteps=timesteps)\n",
746747
" # Get model prediction\n",
747748
" loss = F.mse_loss(prediction.float(), noise.float())\n",
@@ -750,11 +751,11 @@
750751
" scaler.update()\n",
751752
" epoch_loss += loss.item()\n",
752753
"\n",
753-
" epoch_loss_list.append(epoch_loss / (step + 1))\n",
754+
" epoch_loss_list.append(epoch_loss / (_step + 1))\n",
754755
" if (epoch) % val_interval == 0:\n",
755756
" model.eval()\n",
756757
" val_epoch_loss = 0\n",
757-
" for step, data_val in enumerate(val_loader):\n",
758+
" for _step, data_val in enumerate(val_loader):\n",
758759
" images = data_val[\"image\"].to(device)\n",
759760
" seg = data_val[\"label\"].to(device) # this is the ground truth segmentation\n",
760761
" timesteps = torch.randint(0, 1000, (len(images),)).to(device)\n",
@@ -766,8 +767,8 @@
766767
" prediction = model(x=combined, timesteps=timesteps)\n",
767768
" val_loss = F.mse_loss(prediction.float(), noise.float())\n",
768769
" val_epoch_loss += val_loss.item()\n",
769-
" print(\"Epoch\", epoch, \"Validation loss\", val_epoch_loss / (step + 1))\n",
770-
" val_epoch_loss_list.append(val_epoch_loss / (step + 1))\n",
770+
" print(\"Epoch\", epoch, \"Validation loss\", val_epoch_loss / (_step + 1))\n",
771+
" val_epoch_loss_list.append(val_epoch_loss / (_step + 1))\n",
771772
"\n",
772773
"torch.save(model.state_dict(), \"./segmodel.pt\")\n",
773774
"total_time = time.time() - total_start\n",
@@ -1222,7 +1223,7 @@
12221223
"n = 5\n",
12231224
"input_img = inputimg[None, None, ...].to(device)\n",
12241225
"ensemble = []\n",
1225-
"for k in range(5):\n",
1226+
"for _ in range(5):\n",
12261227
" noise = torch.randn_like(input_img).to(device)\n",
12271228
" current_img = noise # for the segmentation mask, we start from random noise.\n",
12281229
" combined = torch.cat(\n",
@@ -1243,7 +1244,8 @@
12431244
" chain = torch.cat((chain, current_img.cpu()), dim=-1)\n",
12441245
" combined = torch.cat(\n",
12451246
" (input_img, current_img), dim=1\n",
1246-
" ) # in every step during the denoising process, the brain MR image is concatenated to add anatomical information\n",
1247+
" # in every step during the denoising process, the brain MR image is concatenated to add anatomical information\n",
1248+
" )\n",
12471249
"\n",
12481250
" plt.style.use(\"default\")\n",
12491251
" plt.imshow(chain[0, 0, ..., 64:].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n",

0 commit comments

Comments
 (0)