|
71 | 71 | "import torch\n",
|
72 | 72 | "from monai.apps import download_and_extract\n",
|
73 | 73 | "from monai.config import print_config\n",
|
74 |
| - "from monai.data import CacheDataset, list_data_collate, decollate_batch\n", |
| 74 | + "from monai.data import DataLoader, CacheDataset, list_data_collate, decollate_batch\n", |
75 | 75 | "from monai.inferers import sliding_window_inference\n",
|
76 | 76 | "from monai.losses import DiceLoss\n",
|
77 | 77 | "from monai.metrics import DiceMetric\n",
|
|
90 | 90 | " Spacingd,\n",
|
91 | 91 | ")\n",
|
92 | 92 | "from monai.utils import get_torch_version_tuple, set_determinism\n",
|
93 |
| - "from torch.utils.data import DataLoader\n", |
94 | 93 | "\n",
|
95 | 94 | "print_config()\n",
|
96 | 95 | "\n",
|
|
286 | 285 | " batch_size=2,\n",
|
287 | 286 | " shuffle=True,\n",
|
288 | 287 | " num_workers=1,\n",
|
289 |
| - " collate_fn=list_data_collate,\n", |
290 | 288 | " )\n",
|
291 | 289 | " val_loader = DataLoader(val_ds, batch_size=1, num_workers=1)\n",
|
292 | 290 | " device = torch.device(\"cuda:0\")\n",
|
|
373 | 371 | " val_outputs = sliding_window_inference(\n",
|
374 | 372 | " val_inputs, roi_size, sw_batch_size, model\n",
|
375 | 373 | " )\n",
|
376 |
| - " val_outputs = [post_pred(i) for i in decollate_batch(val_outputs.as_tensor())]\n", |
377 |
| - " val_labels = [post_label(i) for i in decollate_batch(val_labels.as_tensor())]\n", |
| 374 | + " val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]\n", |
| 375 | + " val_labels = [post_label(i) for i in decollate_batch(val_labels)]\n", |
378 | 376 | " dice_metric(y_pred=val_outputs, y=val_labels)\n",
|
379 | 377 | "\n",
|
380 | 378 | " metric = dice_metric.aggregate().item()\n",
|
|
0 commit comments