Skip to content

Update NNUnet to dynUNet in the tutorial #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ This tutorial shows how to leverage `EnsembleEvaluator`, `MeanEnsemble` and `Vot
This notebook is a quick demo for devices, run the Ignite trainer engine on CPU, GPU and multiple GPUs.
#### [nifti_read_example](./nifti_read_example.ipynb)
Illustrate reading NIfTI files and iterating over image patches of the volumes loaded from them.
#### [nnunet_tutorial](./nnunet_tutorial.ipynb)
This tutorial shows how to train 3D segmentation tasks on all the 10 decathlon datasets with the reimplementation of NNUnet in MONAI.
#### [dynunet_tutorial](./dynunet_tutorial.ipynb)
This tutorial shows how to train 3D segmentation tasks on all the 10 decathlon datasets with the reimplementation of dynUNet in MONAI.
#### [post_transforms](./post_transforms.ipynb)
This notebook shows the usage of several post transforms based on the model output of spleen segmentation task.
#### [public_datasets](./public_datasets.ipynb)
Expand Down
18 changes: 9 additions & 9 deletions nnunet_tutorial.ipynb → dynunet_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train NNUNet on Decathlon datasets\n",
"This tutorial shows how to train 3D segmentation tasks on all the 10 decathlon datasets.\n",
"# Train DynUNet on Decathlon datasets\n",
"This tutorial shows how to train 3D segmentation tasks on all the 10 decathlon datasets with `DynUNet`.\n",
"\n",
"Refer to papers:\n",
"\n",
Expand Down Expand Up @@ -90,7 +90,7 @@
"from monai.handlers import MeanDice, StatsHandler\n",
"from monai.inferers import SimpleInferer\n",
"from monai.losses import DiceLoss\n",
"from monai.networks.nets import NNUnet\n",
"from monai.networks.nets import DynUNet\n",
"from monai.transforms import (\n",
" Compose,\n",
" LoadNiftid,\n",
Expand Down Expand Up @@ -118,7 +118,7 @@
"metadata": {},
"source": [
"## Select Decathlon task\n",
"The Decathlon dataset contains 10 tasks, this NNUNet tutorial can support all of them.\n",
"The Decathlon dataset contains 10 tasks, this dynUNet tutorial can support all of them.\n",
"\n",
"Just need to select task ID and other parameters will be automatically selected.\n",
"\n",
Expand Down Expand Up @@ -448,7 +448,7 @@
" strides.append(stride)\n",
"strides.insert(0, len(spacings) * [1])\n",
"kernels.append(len(spacings) * [3])\n",
"net = NNUnet(\n",
"net = DynUNet(\n",
" spatial_dims=3,\n",
" in_channels=in_channels,\n",
" out_channels=n_class,\n",
Expand Down Expand Up @@ -491,7 +491,7 @@
"]\n",
"\n",
"# Define customized evaluator\n",
"class nnUNetEvaluator(SupervisedEvaluator):\n",
"class DynUNetEvaluator(SupervisedEvaluator):\n",
" def _iteration(self, engine, batchdata):\n",
" inputs, targets = self.prepare_batch(batchdata)\n",
" inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)\n",
Expand All @@ -513,7 +513,7 @@
" return {\"image\": inputs, \"label\": targets, \"pred\": predictions}\n",
"\n",
"\n",
"evaluator = nnUNetEvaluator(\n",
"evaluator = DynUNetEvaluator(\n",
" device=device,\n",
" val_data_loader=val_loader,\n",
" network=net,\n",
Expand Down Expand Up @@ -559,7 +559,7 @@
"]\n",
"\n",
"# define customized trainer\n",
"class nnUNetTrainer(SupervisedTrainer):\n",
"class DynUNetTrainer(SupervisedTrainer):\n",
" def _iteration(self, engine, batchdata):\n",
" inputs, targets = self.prepare_batch(batchdata)\n",
" inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)\n",
Expand All @@ -585,7 +585,7 @@
" return {\"image\": inputs, \"label\": targets, \"pred\": predictions, \"loss\": loss.item()}\n",
"\n",
"\n",
"trainer = nnUNetTrainer(\n",
"trainer = DynUNetTrainer(\n",
" device=device,\n",
" max_epochs=max_epochs,\n",
" train_data_loader=train_loader,\n",
Expand Down