Skip to content

Commit e7b317a

Browse files
authored
update 3d_registration tutorial (#806)
1 parent f329028 commit e7b317a

File tree

1 file changed

+11
-23
lines changed

1 file changed

+11
-23
lines changed

3d_registration/paired_lung_ct.ipynb

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,11 @@
129129
"from monai.networks.blocks import Warp\n",
130130
"from monai.networks.nets import LocalNet\n",
131131
"from monai.transforms import (\n",
132-
" AddChanneld,\n",
133132
" Compose,\n",
134133
" LoadImaged,\n",
135134
" RandAffined,\n",
136135
" Resized,\n",
137136
" ScaleIntensityRanged,\n",
138-
" EnsureTyped,\n",
139137
")\n",
140138
"from monai.utils import set_determinism, first\n",
141139
"\n",
@@ -302,12 +300,10 @@
302300
"source": [
303301
"## Setup transforms for training and validation\n",
304302
"Here we use several transforms to augment the dataset:\n",
305-
"1. LoadImaged loads the lung CT images and labels from NIfTI format files.\n",
306-
"2. AddChanneld as the original data doesn't have channel dim, add 1 dim to construct \"channel first\" shape.\n",
307-
"3. ScaleIntensityRanged extracts intensity range [-285, 3770] and scales to [0, 1].\n",
308-
"4. RandAffined efficiently performs rotate, scale, shear, translate, etc. together based on PyTorch affine transform.\n",
309-
"5. Resized resize images to the same size.\n",
310-
"6. EnsureTyped converts the numpy array to PyTorch Tensor for further steps."
303+
"1. LoadImaged loads the lung CT images and labels from NIfTI format files. \"ensure_channel_first=True\" ensure that the first dim is channel.\n",
304+
"2. ScaleIntensityRanged extracts intensity range [-285, 3770] and scales to [0, 1].\n",
305+
"3. RandAffined efficiently performs rotate, scale, shear, translate, etc. together based on PyTorch affine transform.\n",
306+
"4. Resized resize images to the same size."
311307
]
312308
},
313309
{
@@ -324,10 +320,8 @@
324320
"train_transforms = Compose(\n",
325321
" [\n",
326322
" LoadImaged(\n",
327-
" keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n",
328-
" ),\n",
329-
" AddChanneld(\n",
330-
" keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n",
323+
" keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"],\n",
324+
" ensure_channel_first=True\n",
331325
" ),\n",
332326
" ScaleIntensityRanged(\n",
333327
" keys=[\"fixed_image\", \"moving_image\"],\n",
@@ -345,18 +339,13 @@
345339
" align_corners=(True, True, None, None),\n",
346340
" spatial_size=(96, 96, 104)\n",
347341
" ),\n",
348-
" EnsureTyped(\n",
349-
" keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n",
350-
" ),\n",
351342
" ]\n",
352343
")\n",
353344
"val_transforms = Compose(\n",
354345
" [\n",
355346
" LoadImaged(\n",
356-
" keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n",
357-
" ),\n",
358-
" AddChanneld(\n",
359-
" keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n",
347+
" keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"],\n",
348+
" ensure_channel_first=True\n",
360349
" ),\n",
361350
" ScaleIntensityRanged(\n",
362351
" keys=[\"fixed_image\", \"moving_image\"],\n",
@@ -369,9 +358,6 @@
369358
" align_corners=(True, True, None, None),\n",
370359
" spatial_size=(96, 96, 104)\n",
371360
" ),\n",
372-
" EnsureTyped(\n",
373-
" keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n",
374-
" ),\n",
375361
" ]\n",
376362
")"
377363
]
@@ -693,6 +679,7 @@
693679
"\n",
694680
" val_ddf, val_pred_image, val_pred_label = forward(\n",
695681
" val_data, model)\n",
682+
" val_pred_label[val_pred_label > 1] = 1\n",
696683
"\n",
697684
" val_fixed_image = val_data[\"fixed_image\"].to(device)\n",
698685
" val_fixed_label = val_data[\"fixed_label\"].to(device)\n",
@@ -723,6 +710,7 @@
723710
" optimizer.zero_grad()\n",
724711
"\n",
725712
" ddf, pred_image, pred_label = forward(batch_data, model)\n",
713+
" pred_label[pred_label > 1] = 1\n",
726714
"\n",
727715
" fixed_image = batch_data[\"fixed_image\"].to(device)\n",
728716
" fixed_label = batch_data[\"fixed_label\"].to(device)\n",
@@ -1281,7 +1269,7 @@
12811269
"name": "python",
12821270
"nbconvert_exporter": "python",
12831271
"pygments_lexer": "ipython3",
1284-
"version": "3.8.12"
1272+
"version": "3.8.13"
12851273
}
12861274
},
12871275
"nbformat": 4,

0 commit comments

Comments
 (0)