|
727 | 727 | " model.train()\n",
|
728 | 728 | " epoch_loss = 0\n",
|
729 | 729 | "\n",
|
730 |
| - " for step, data in enumerate(train_loader):\n", |
| 730 | + " for _step, data in enumerate(train_loader):\n", |
731 | 731 | " images = data[\"image\"].to(device)\n",
|
732 | 732 | " seg = data[\"label\"].to(device) # this is the ground truth segmentation\n",
|
733 | 733 | " optimizer.zero_grad(set_to_none=True)\n",
|
|
741 | 741 | " ) # we only add noise to the segmentation mask\n",
|
742 | 742 | " combined = torch.cat(\n",
|
743 | 743 | " (images, noisy_seg), dim=1\n",
|
744 |
| - " ) # we concatenate the brain MR image with the noisy segmenatation mask, to condition the generation process\n", |
| 744 | + " # we concatenate the brain MR image with the noisy segmenatation mask, to condition the generation process\n", |
| 745 | + " )\n", |
745 | 746 | " prediction = model(x=combined, timesteps=timesteps)\n",
|
746 | 747 | " # Get model prediction\n",
|
747 | 748 | " loss = F.mse_loss(prediction.float(), noise.float())\n",
|
|
750 | 751 | " scaler.update()\n",
|
751 | 752 | " epoch_loss += loss.item()\n",
|
752 | 753 | "\n",
|
753 |
| - " epoch_loss_list.append(epoch_loss / (step + 1))\n", |
| 754 | + " epoch_loss_list.append(epoch_loss / (_step + 1))\n", |
754 | 755 | " if (epoch) % val_interval == 0:\n",
|
755 | 756 | " model.eval()\n",
|
756 | 757 | " val_epoch_loss = 0\n",
|
757 |
| - " for step, data_val in enumerate(val_loader):\n", |
| 758 | + " for _step, data_val in enumerate(val_loader):\n", |
758 | 759 | " images = data_val[\"image\"].to(device)\n",
|
759 | 760 | " seg = data_val[\"label\"].to(device) # this is the ground truth segmentation\n",
|
760 | 761 | " timesteps = torch.randint(0, 1000, (len(images),)).to(device)\n",
|
|
766 | 767 | " prediction = model(x=combined, timesteps=timesteps)\n",
|
767 | 768 | " val_loss = F.mse_loss(prediction.float(), noise.float())\n",
|
768 | 769 | " val_epoch_loss += val_loss.item()\n",
|
769 |
| - " print(\"Epoch\", epoch, \"Validation loss\", val_epoch_loss / (step + 1))\n", |
770 |
| - " val_epoch_loss_list.append(val_epoch_loss / (step + 1))\n", |
| 770 | + " print(\"Epoch\", epoch, \"Validation loss\", val_epoch_loss / (_step + 1))\n", |
| 771 | + " val_epoch_loss_list.append(val_epoch_loss / (_step + 1))\n", |
771 | 772 | "\n",
|
772 | 773 | "torch.save(model.state_dict(), \"./segmodel.pt\")\n",
|
773 | 774 | "total_time = time.time() - total_start\n",
|
|
1222 | 1223 | "n = 5\n",
|
1223 | 1224 | "input_img = inputimg[None, None, ...].to(device)\n",
|
1224 | 1225 | "ensemble = []\n",
|
1225 |
| - "for k in range(5):\n", |
| 1226 | + "for _ in range(5):\n", |
1226 | 1227 | " noise = torch.randn_like(input_img).to(device)\n",
|
1227 | 1228 | " current_img = noise # for the segmentation mask, we start from random noise.\n",
|
1228 | 1229 | " combined = torch.cat(\n",
|
|
1243 | 1244 | " chain = torch.cat((chain, current_img.cpu()), dim=-1)\n",
|
1244 | 1245 | " combined = torch.cat(\n",
|
1245 | 1246 | " (input_img, current_img), dim=1\n",
|
1246 |
| - " ) # in every step during the denoising process, the brain MR image is concatenated to add anatomical information\n", |
| 1247 | + " # in every step during the denoising process, the brain MR image is concatenated to add anatomical information\n", |
| 1248 | + " )\n", |
1247 | 1249 | "\n",
|
1248 | 1250 | " plt.style.use(\"default\")\n",
|
1249 | 1251 | " plt.imshow(chain[0, 0, ..., 64:].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n",
|
|
0 commit comments