|
4 | 4 | "cell_type": "markdown",
|
5 | 5 | "metadata": {},
|
6 | 6 | "source": [
|
7 |
| - "# Train NNUNet on Decathlon datasets\n", |
8 |
| - "This tutorial shows how to train 3D segmentation tasks on all the 10 decathlon datasets.\n", |
| 7 | + "# Train DynUNet on Decathlon datasets\n", |
| 8 | + "This tutorial shows how to train 3D segmentation tasks on all the 10 decathlon datasets with `DynUNet`.\n", |
9 | 9 | "\n",
|
10 | 10 | "Refer to papers:\n",
|
11 | 11 | "\n",
|
|
90 | 90 | "from monai.handlers import MeanDice, StatsHandler\n",
|
91 | 91 | "from monai.inferers import SimpleInferer\n",
|
92 | 92 | "from monai.losses import DiceLoss\n",
|
93 |
| - "from monai.networks.nets import NNUnet\n", |
| 93 | + "from monai.networks.nets import DynUNet\n", |
94 | 94 | "from monai.transforms import (\n",
|
95 | 95 | " Compose,\n",
|
96 | 96 | " LoadNiftid,\n",
|
|
118 | 118 | "metadata": {},
|
119 | 119 | "source": [
|
120 | 120 | "## Select Decathlon task\n",
|
121 |
| - "The Decathlon dataset contains 10 tasks, this NNUNet tutorial can support all of them.\n", |
| 121 | + "The Decathlon dataset contains 10 tasks, this dynUNet tutorial can support all of them.\n", |
122 | 122 | "\n",
|
123 | 123 | "Just need to select task ID and other parameters will be automatically selected.\n",
|
124 | 124 | "\n",
|
|
448 | 448 | " strides.append(stride)\n",
|
449 | 449 | "strides.insert(0, len(spacings) * [1])\n",
|
450 | 450 | "kernels.append(len(spacings) * [3])\n",
|
451 |
| - "net = NNUnet(\n", |
| 451 | + "net = DynUNet(\n", |
452 | 452 | " spatial_dims=3,\n",
|
453 | 453 | " in_channels=in_channels,\n",
|
454 | 454 | " out_channels=n_class,\n",
|
|
491 | 491 | "]\n",
|
492 | 492 | "\n",
|
493 | 493 | "# Define customized evaluator\n",
|
494 |
| - "class nnUNetEvaluator(SupervisedEvaluator):\n", |
| 494 | + "class DynUNetEvaluator(SupervisedEvaluator):\n", |
495 | 495 | " def _iteration(self, engine, batchdata):\n",
|
496 | 496 | " inputs, targets = self.prepare_batch(batchdata)\n",
|
497 | 497 | " inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)\n",
|
|
513 | 513 | " return {\"image\": inputs, \"label\": targets, \"pred\": predictions}\n",
|
514 | 514 | "\n",
|
515 | 515 | "\n",
|
516 |
| - "evaluator = nnUNetEvaluator(\n", |
| 516 | + "evaluator = DynUNetEvaluator(\n", |
517 | 517 | " device=device,\n",
|
518 | 518 | " val_data_loader=val_loader,\n",
|
519 | 519 | " network=net,\n",
|
|
559 | 559 | "]\n",
|
560 | 560 | "\n",
|
561 | 561 | "# define customized trainer\n",
|
562 |
| - "class nnUNetTrainer(SupervisedTrainer):\n", |
| 562 | + "class DynUNetTrainer(SupervisedTrainer):\n", |
563 | 563 | " def _iteration(self, engine, batchdata):\n",
|
564 | 564 | " inputs, targets = self.prepare_batch(batchdata)\n",
|
565 | 565 | " inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)\n",
|
|
585 | 585 | " return {\"image\": inputs, \"label\": targets, \"pred\": predictions, \"loss\": loss.item()}\n",
|
586 | 586 | "\n",
|
587 | 587 | "\n",
|
588 |
| - "trainer = nnUNetTrainer(\n", |
| 588 | + "trainer = DynUNetTrainer(\n", |
589 | 589 | " device=device,\n",
|
590 | 590 | " max_epochs=max_epochs,\n",
|
591 | 591 | " train_data_loader=train_loader,\n",
|
|
0 commit comments