Skip to content

Commit 590461a

Browse files
authored
A tutorial to use randomised permutation during training (#807)
* A tutorial to use randomised permutation during training
1 parent e7b317a commit 590461a

File tree

7 files changed

+308
-0
lines changed

7 files changed

+308
-0
lines changed

modules/generate_random_permutations/Creating dataset with randomized transform chain.ipynb

Lines changed: 197 additions & 0 deletions
Large diffs are not rendered by default.
Loading
Loading
Loading
Loading
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import numpy as np
2+
import os
3+
import torch
4+
import glob
5+
import random
6+
from monai.data import DataLoader
7+
from monai.transforms.transform import Transform
8+
from monai.transforms import (Affine, LoadImage, Rotate, NormalizeIntensity, Transpose, Compose, Resize, AsChannelFirst, AsChannelLast, ScaleIntensity, RandFlip, Rotate90, AddChannel, GaussianSmooth, AdjustContrast)
9+
from random import shuffle
10+
11+
class Dataset(torch.utils.data.Dataset):
12+
def __init__(self, image_file_list, transforms, shuffle_transforms=1):
13+
self.image_file_list = image_file_list
14+
if shuffle_transforms:
15+
transform_list = [LoadImage(image_only=True), AddChannel(), Resize((299, 299))] + shuffle(transforms)
16+
self.transform = Compose(transpose_list)
17+
else:
18+
self.transform = Compose([LoadImage(image_only=True), AddChannel(), Resize((299, 299))] + transforms)
19+
20+
def __len__(self):
21+
return len(self.image_file_list)
22+
23+
def __getitem__(self, index):
24+
return self.transform(self.image_file_list[index])
25+
26+
27+
class AugmentData(object):
28+
def __init__(self, image_loading_transforms = [LoadImage(image_only=True)], augmentation_dict = {}, num_augmentations=5, output_size=(200, 200), batch_size=3):
29+
self.output_size = output_size
30+
self.batch_size = batch_size
31+
self.augmentation_dict = augmentation_dict
32+
self.aug_seq = self.create_augmentation_sequence()
33+
self.image_loading_transforms = image_loading_transforms
34+
self.num_augmentations = num_augmentations
35+
36+
def create_augmentation_sequence(self):
37+
augmentation_transforms = []
38+
for aug, num_aug in self.augmentation_dict.items():
39+
_x = [aug]*num_aug
40+
augmentation_transforms = augmentation_transforms + _x
41+
return augmentation_transforms
42+
43+
44+
def create_transform_list(self, augmentation_sequence):
45+
transform_list = self.image_loading_transforms
46+
for _aug in augmentation_sequence:
47+
if _aug == 'rotate':
48+
transform_list.append(Rotate(random.randint(0, 100)))
49+
if _aug == 'flip':
50+
transform_list.append(RandFlip())
51+
if _aug == 'rotate90':
52+
transform_list.append(Rotate90())
53+
if _aug == 'intensityGaussian':
54+
transform_list.append(GaussianSmooth(sigma=random.randint(0, 10)))
55+
if _aug == 'adjustContrast':
56+
transform_list.append(AdjustContrast(gamma=random.randint(0, 10)))
57+
58+
transform_list.append(ScaleIntensity())
59+
transform_list.append(Resize(self.output_size))
60+
return transform_list
61+
62+
63+
def create_native_transform_list(self):
64+
transform_list = Compose(self.image_loading_transforms + [ScaleIntensity(), Resize(self.output_size)])
65+
return transform_list
66+
67+
68+
def __call__(self, image_file_list, *args, **kwargs):
69+
image_file_list = image_file_list
70+
71+
IMG = []
72+
for img in zip(image_file_list):
73+
native_transform_list = self.create_native_transform_list()
74+
native_img = native_transform_list(img)
75+
IMG = IMG + native_img
76+
for i in range(self.num_augmentations):
77+
shuffle(self.aug_seq)
78+
transform_list = self.create_transform_list(self.aug_seq)
79+
img_augmentated = Compose(transform_list)(img)
80+
IMG = IMG + img_augmentated
81+
82+
random.shuffle(IMG)
83+
ALLIMG_NP = np.stack(IMG, axis=0)
84+
OUT_IMAGE_NP = ALLIMG_NP[0:self.batch_size, :]
85+
return OUT_IMAGE_NP
86+
87+
88+
89+
def main():
90+
91+
image_dir='./exampleImages'
92+
image_file_list = glob.glob(image_dir + '/*.png')
93+
output_size = (400, 400)
94+
transform_list = [RandFlip(), Rotate(20), NormalizeIntensity(), Rotate90()]
95+
96+
#print(LoadImage(image_only=True)(image_file_list[0]).shape)
97+
#train_dataset=Dataset(image_file_list, transform_list, shuffle_transforms=0)
98+
#train_dataloader = DataLoader(train_dataset, batch_size=4, num_workers=2)
99+
#for _batch_data in train_dataloader:
100+
# img = _batch_data[0]
101+
102+
image_loading_transforms = [LoadImage(image_only=True), AddChannel()]
103+
augmentation_dict = {'rotate': 3, 'flip': 2, 'rotate90': 1, 'intensityGaussian': 2, 'adjustContrast' : 2}
104+
105+
img = AugmentData(image_loading_transforms=image_loading_transforms, augmentation_dict = augmentation_dict)(image_file_list)
106+
print(img.shape)
107+
108+
109+
if __name__ == '__main__':
110+
main()

runner.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ pattern="-and -name '*' -and ! -wholename '*federated_learning*'\
7474
-and ! -wholename '*nuclick_infer*'\
7575
-and ! -wholename '*nuclick_training_notebook*'\
7676
-and ! -wholename '*full_gpu_inference_pipeline*'\
77+
-and ! -wholename '*generate_random_permutations*'\
7778
-and ! -wholename '*get_started*'"
7879
kernelspec="python3"
7980

0 commit comments

Comments
 (0)