Skip to content

Commit 1ba81a3

Browse files
author
Yuchen Xu
committed
modification done, pending final checks
1 parent 362354a commit 1ba81a3

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

acceleration/automatic_mixed_precision.ipynb

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
"import torch\n",
7272
"from monai.apps import download_and_extract\n",
7373
"from monai.config import print_config\n",
74-
"from monai.data import CacheDataset, list_data_collate, decollate_batch\n",
74+
"from monai.data import DataLoader, CacheDataset, list_data_collate, decollate_batch\n",
7575
"from monai.inferers import sliding_window_inference\n",
7676
"from monai.losses import DiceLoss\n",
7777
"from monai.metrics import DiceMetric\n",
@@ -90,7 +90,6 @@
9090
" Spacingd,\n",
9191
")\n",
9292
"from monai.utils import get_torch_version_tuple, set_determinism\n",
93-
"from torch.utils.data import DataLoader\n",
9493
"\n",
9594
"print_config()\n",
9695
"\n",
@@ -286,7 +285,6 @@
286285
" batch_size=2,\n",
287286
" shuffle=True,\n",
288287
" num_workers=1,\n",
289-
" collate_fn=list_data_collate,\n",
290288
" )\n",
291289
" val_loader = DataLoader(val_ds, batch_size=1, num_workers=1)\n",
292290
" device = torch.device(\"cuda:0\")\n",
@@ -373,8 +371,8 @@
373371
" val_outputs = sliding_window_inference(\n",
374372
" val_inputs, roi_size, sw_batch_size, model\n",
375373
" )\n",
376-
" val_outputs = [post_pred(i) for i in decollate_batch(val_outputs.as_tensor())]\n",
377-
" val_labels = [post_label(i) for i in decollate_batch(val_labels.as_tensor())]\n",
374+
" val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]\n",
375+
" val_labels = [post_label(i) for i in decollate_batch(val_labels)]\n",
378376
" dice_metric(y_pred=val_outputs, y=val_labels)\n",
379377
"\n",
380378
" metric = dice_metric.aggregate().item()\n",

acceleration/fast_training_tutorial.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,8 +636,8 @@
636636
"\n",
637637
" # profiling: decollate batch\n",
638638
" with nvtx.annotate(\"decollate batch\", color=\"blue\") if profiling else no_profiling:\n",
639-
" val_outputs = [post_pred(i) for i in decollate_batch(val_outputs.as_tensor())]\n",
640-
" val_labels = [post_label(i) for i in decollate_batch(val_labels.as_tensor())]\n",
639+
" val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]\n",
640+
" val_labels = [post_label(i) for i in decollate_batch(val_labels)]\n",
641641
"\n",
642642
" # profiling: compute metric\n",
643643
" with nvtx.annotate(\"compute metric\", color=\"yellow\") if profiling else no_profiling:\n",

0 commit comments

Comments
 (0)