|
1 | 1 | {
|
2 | 2 | "cells": [
|
3 | 3 | {
|
4 |
| - "cell_type": "code", |
5 |
| - "execution_count": null, |
| 4 | + "cell_type": "markdown", |
6 | 5 | "metadata": {},
|
7 |
| - "outputs": [], |
8 | 6 | "source": [
|
9 |
| - "# Copyright (c) MONAI Consortium\n", |
10 |
| - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", |
11 |
| - "# you may not use this file except in compliance with the License.\n", |
12 |
| - "# You may obtain a copy of the License at\n", |
13 |
| - "# http://www.apache.org/licenses/LICENSE-2.0\n", |
14 |
| - "# Unless required by applicable law or agreed to in writing, software\n", |
15 |
| - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", |
16 |
| - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", |
17 |
| - "# See the License for the specific language governing permissions and\n", |
18 |
| - "# limitations under the License." |
| 7 | + "Copyright (c) MONAI Consortium \n", |
| 8 | + "Licensed under the Apache License, Version 2.0 (the \"License\"); \n", |
| 9 | + "you may not use this file except in compliance with the License. \n", |
| 10 | + "You may obtain a copy of the License at \n", |
| 11 | + " http://www.apache.org/licenses/LICENSE-2.0 \n", |
| 12 | + "Unless required by applicable law or agreed to in writing, software \n", |
| 13 | + "distributed under the License is distributed on an \"AS IS\" BASIS, \n", |
| 14 | + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n", |
| 15 | + "See the License for the specific language governing permissions and \n", |
| 16 | + "limitations under the License." |
19 | 17 | ]
|
20 | 18 | },
|
21 | 19 | {
|
|
210 | 208 | " section=\"training\", # validation\n",
|
211 | 209 | " cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n",
|
212 | 210 | " num_workers=4,\n",
|
213 |
| - " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", |
| 211 | + " download=True, # Set download to True if the dataset hasnt been downloaded yet\n", |
214 | 212 | " seed=0,\n",
|
215 | 213 | " transform=train_transforms,\n",
|
216 | 214 | ")\n",
|
|
612 | 610 | "metadata": {},
|
613 | 611 | "outputs": [],
|
614 | 612 | "source": [
|
615 |
| - "n_epochs = 4000\n", |
| 613 | + "max_epochs = 4000\n", |
616 | 614 | "val_interval = 50\n",
|
617 | 615 | "epoch_loss_list = []\n",
|
618 | 616 | "val_epoch_loss_list = []"
|
|
725 | 723 | "scaler = GradScaler()\n",
|
726 | 724 | "total_start = time.time()\n",
|
727 | 725 | "\n",
|
728 |
| - "for epoch in range(n_epochs):\n", |
| 726 | + "for epoch in range(max_epochs):\n", |
729 | 727 | " model.train()\n",
|
730 | 728 | " epoch_loss = 0\n",
|
731 | 729 | "\n",
|
|
776 | 774 | "print(f\"train diffusion completed, total time: {total_time}.\")\n",
|
777 | 775 | "\n",
|
778 | 776 | "plt.title(\"Learning Curves Diffusion Model\", fontsize=20)\n",
|
779 |
| - "plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color=\"C0\", linewidth=2.0, label=\"Train\")\n", |
| 777 | + "plt.plot(np.linspace(1, max_epochs, max_epochs), epoch_loss_list, color=\"C0\", linewidth=2.0, label=\"Train\")\n", |
780 | 778 | "plt.plot(\n",
|
781 |
| - " np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),\n", |
| 779 | + " np.linspace(val_interval, max_epochs, int(max_epochs / val_interval)),\n", |
782 | 780 | " val_epoch_loss_list,\n",
|
783 | 781 | " color=\"C1\",\n",
|
784 | 782 | " linewidth=2.0,\n",
|
|
0 commit comments