Skip to content

Commit ddec44f

Browse files
committed
TTA
1 parent ba4c709 commit ddec44f

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

modules/inverse_transforms_and_test_time_augmentations.ipynb

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"\n",
1212
"### What are transforms?\n",
1313
"\n",
14-
"- We use transforms to modify data. In MONAI, we use them to (for exampl) load images from file, add a channel component, normalise the intensities and reshape the image.\n",
14+
"- We use transforms to modify data. In MONAI, we use them to (for example) load images from file, add a channel component, normalise the intensities and reshape the image.\n",
1515
"- We can also use transforms as a method of data augmentation – we have a finite amount of data so to avoid overfitting, we can apply random transforms to modify our data each epoch.\n",
1616
"- Examples of random transformations might be randomly flipping, rotating, cropping, padding, zooming, as well as applying non-rigid deformations.\n",
1717
"\n",
@@ -229,8 +229,8 @@
229229
" def __call__(self, data):\n",
230230
" d = dict(data)\n",
231231
" im = d[self.label_key]\n",
232-
" q = np.sum((im > 0).reshape(-1, im.shape[-1]), axis=0)\n",
233-
" _slice = np.where(q == np.max(q))[0][0]\n",
232+
" q = (im > 0).reshape(-1, im.shape[-1]).sum(dim=0)\n",
233+
" _slice = q.argmax(dim=0)\n",
234234
" for key in self.keys:\n",
235235
" d[key] = d[key][..., _slice]\n",
236236
" return d\n",
@@ -244,10 +244,9 @@
244244
" def __call__(self, data):\n",
245245
" d = {}\n",
246246
" for key in self.keys:\n",
247-
" fname = os.path.basename(\n",
248-
" data[key + \"_meta_dict\"][\"filename_or_obj\"])\n",
247+
" fname = os.path.basename(data[key].meta[\"filename_or_obj\"])\n",
249248
" path = os.path.join(self.path, key, fname)\n",
250-
" nib.save(nib.Nifti1Image(data[key], np.eye(4)), path)\n",
249+
" nib.save(nib.Nifti1Image(data[key].numpy(), np.eye(4)), path)\n",
251250
" d[key] = path\n",
252251
" return d\n",
253252
"\n",
@@ -258,12 +257,12 @@
258257
" os.makedirs(os.path.join(data_dir, key), exist_ok=True)\n",
259258
"transform_2d_slice = Compose([\n",
260259
" LoadImaged(keys),\n",
261-
" FromMetaTensord(keys),\n",
262260
" AsChannelFirstd(\"image\"),\n",
263261
" AddChanneld(\"label\"),\n",
264262
" SliceWithMaxNumLabelsd(keys, \"label\"),\n",
265263
" SaveSliced(keys, data_dir),\n",
266264
"])\n",
265+
"\n",
267266
"# Running the whole way through the dataset will create the 2D slices and save to file\n",
268267
"ds_2d = Dataset(data_dicts, transform_2d_slice)\n",
269268
"dl_2d = DataLoader(ds_2d, batch_size=1, num_workers=10)\n",
@@ -306,8 +305,8 @@
306305
" [\n",
307306
" LoadImaged(keys),\n",
308307
" FromMetaTensord(keys),\n",
309-
" Lambdad(\"label\", lambda x: (x > 0).astype(\n",
310-
" np.float64)), # make label binary\n",
308+
" Lambdad(\"label\", lambda x: (x > 0).to(\n",
309+
" torch.float32)), # make label binary\n",
311310
" RandAffined(\n",
312311
" keys,\n",
313312
" prob=1.0,\n",
@@ -1627,8 +1626,8 @@
16271626
"minimal_transforms = Compose([\n",
16281627
" LoadImaged(keys),\n",
16291628
" FromMetaTensord(keys),\n",
1630-
" Lambdad(\"label\", lambda x: (x > 0).astype(\n",
1631-
" np.float64)), # make label binary\n",
1629+
" Lambdad(\"label\", lambda x: (x > 0).to(\n",
1630+
" torch.float32)), # make label binary\n",
16321631
" ScaleIntensityd(\"image\"),\n",
16331632
" EnsureTyped(keys),\n",
16341633
"])\n",

0 commit comments

Comments
 (0)