|
129 | 129 | "from monai.networks.blocks import Warp\n",
|
130 | 130 | "from monai.networks.nets import LocalNet\n",
|
131 | 131 | "from monai.transforms import (\n",
|
132 |
| - " AddChanneld,\n", |
133 | 132 | " Compose,\n",
|
134 | 133 | " LoadImaged,\n",
|
135 | 134 | " RandAffined,\n",
|
136 | 135 | " Resized,\n",
|
137 | 136 | " ScaleIntensityRanged,\n",
|
138 |
| - " EnsureTyped,\n", |
139 | 137 | ")\n",
|
140 | 138 | "from monai.utils import set_determinism, first\n",
|
141 | 139 | "\n",
|
|
302 | 300 | "source": [
|
303 | 301 | "## Setup transforms for training and validation\n",
|
304 | 302 | "Here we use several transforms to augment the dataset:\n",
|
305 |
| - "1. LoadImaged loads the lung CT images and labels from NIfTI format files.\n", |
306 |
| - "2. AddChanneld as the original data doesn't have channel dim, add 1 dim to construct \"channel first\" shape.\n", |
307 |
| - "3. ScaleIntensityRanged extracts intensity range [-285, 3770] and scales to [0, 1].\n", |
308 |
| - "4. RandAffined efficiently performs rotate, scale, shear, translate, etc. together based on PyTorch affine transform.\n", |
309 |
| - "5. Resized resize images to the same size.\n", |
310 |
| - "6. EnsureTyped converts the numpy array to PyTorch Tensor for further steps." |
| 303 | + "1. LoadImaged loads the lung CT images and labels from NIfTI format files. \"ensure_channel_first=True\" ensure that the first dim is channel.\n", |
| 304 | + "2. ScaleIntensityRanged extracts intensity range [-285, 3770] and scales to [0, 1].\n", |
| 305 | + "3. RandAffined efficiently performs rotate, scale, shear, translate, etc. together based on PyTorch affine transform.\n", |
| 306 | + "4. Resized resize images to the same size." |
311 | 307 | ]
|
312 | 308 | },
|
313 | 309 | {
|
|
324 | 320 | "train_transforms = Compose(\n",
|
325 | 321 | " [\n",
|
326 | 322 | " LoadImaged(\n",
|
327 |
| - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
328 |
| - " ),\n", |
329 |
| - " AddChanneld(\n", |
330 |
| - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
| 323 | + " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"],\n", |
| 324 | + " ensure_channel_first=True\n", |
331 | 325 | " ),\n",
|
332 | 326 | " ScaleIntensityRanged(\n",
|
333 | 327 | " keys=[\"fixed_image\", \"moving_image\"],\n",
|
|
345 | 339 | " align_corners=(True, True, None, None),\n",
|
346 | 340 | " spatial_size=(96, 96, 104)\n",
|
347 | 341 | " ),\n",
|
348 |
| - " EnsureTyped(\n", |
349 |
| - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
350 |
| - " ),\n", |
351 | 342 | " ]\n",
|
352 | 343 | ")\n",
|
353 | 344 | "val_transforms = Compose(\n",
|
354 | 345 | " [\n",
|
355 | 346 | " LoadImaged(\n",
|
356 |
| - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
357 |
| - " ),\n", |
358 |
| - " AddChanneld(\n", |
359 |
| - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
| 347 | + " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"],\n", |
| 348 | + " ensure_channel_first=True\n", |
360 | 349 | " ),\n",
|
361 | 350 | " ScaleIntensityRanged(\n",
|
362 | 351 | " keys=[\"fixed_image\", \"moving_image\"],\n",
|
|
369 | 358 | " align_corners=(True, True, None, None),\n",
|
370 | 359 | " spatial_size=(96, 96, 104)\n",
|
371 | 360 | " ),\n",
|
372 |
| - " EnsureTyped(\n", |
373 |
| - " keys=[\"fixed_image\", \"moving_image\", \"fixed_label\", \"moving_label\"]\n", |
374 |
| - " ),\n", |
375 | 361 | " ]\n",
|
376 | 362 | ")"
|
377 | 363 | ]
|
|
693 | 679 | "\n",
|
694 | 680 | " val_ddf, val_pred_image, val_pred_label = forward(\n",
|
695 | 681 | " val_data, model)\n",
|
| 682 | + " val_pred_label[val_pred_label > 1] = 1\n", |
696 | 683 | "\n",
|
697 | 684 | " val_fixed_image = val_data[\"fixed_image\"].to(device)\n",
|
698 | 685 | " val_fixed_label = val_data[\"fixed_label\"].to(device)\n",
|
|
723 | 710 | " optimizer.zero_grad()\n",
|
724 | 711 | "\n",
|
725 | 712 | " ddf, pred_image, pred_label = forward(batch_data, model)\n",
|
| 713 | + " pred_label[pred_label > 1] = 1\n", |
726 | 714 | "\n",
|
727 | 715 | " fixed_image = batch_data[\"fixed_image\"].to(device)\n",
|
728 | 716 | " fixed_label = batch_data[\"fixed_label\"].to(device)\n",
|
|
1281 | 1269 | "name": "python",
|
1282 | 1270 | "nbconvert_exporter": "python",
|
1283 | 1271 | "pygments_lexer": "ipython3",
|
1284 |
| - "version": "3.8.12" |
| 1272 | + "version": "3.8.13" |
1285 | 1273 | }
|
1286 | 1274 | },
|
1287 | 1275 | "nbformat": 4,
|
|
0 commit comments