import json
import logging
import copy
import math
import os
import random

import cv2
import numpy as np
import skimage
import torch
from tqdm import tqdm
from skimage.measure import regionprops

#from lib.handlers import TensorBoardImageHandler
#from lib.transforms import FilterImaged
#from lib.utils import split_dataset, split_nuclei_dataset

from monai.config import KeysCollection
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from monai.handlers import (
    CheckpointSaver,
    EarlyStopHandler,
    LrScheduleHandler,
    MeanDice,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
    ValidationHandler,
    from_engine,
)
from monai.inferers import SimpleInferer
from monai.losses import DiceLoss
from monai.networks.nets import BasicUNet
from monai.data import (
    CacheDataset,
    Dataset,
    DataLoader,
)
from monai.transforms import (
    Activationsd,
    AddChanneld,
    AsChannelFirstd,
    AsDiscreted,
    Compose,
    EnsureTyped,
    LoadImaged,
    LoadImage,
    MapTransform,
    RandomizableTransform,
    RandRotate90d,
    ScaleIntensityRangeD,
    ToNumpyd,
    TorchVisiond,
    ToTensord,
    Transform,
)

#from monai.apps.nuclick.dataset_prep import split_pannuke_dataset, split_nuclei_dataset
from monai.apps.nuclick.transforms import (
    FlattenLabeld,
    ExtractPatchd,
    SplitLabeld,
    AddPointGuidanceSignald,
    FilterImaged,
)

#from monailabel.interfaces.datastore import Datastore
#from monailabel.tasks.train.basic_train import BasicTrainTask, Context

def split_pannuke_dataset(image, label, output_dir, groups):
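    """Split the PanNuke ``images.npy``/``masks.npy`` arrays into one image .npy
    file and one flattened label .npy file per tile, returning a list of
    {"image": ..., "label": ...} entries.

    Each mask channel listed in ``label_channels`` is mapped to the class index
    given by ``groups`` (defaulting to 1); where channels overlap, later
    channels overwrite earlier ones.
    """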
    groups = groups if groups else dict()
    groups = [groups] if isinstance(groups, str) else groups
    if not isinstance(groups, dict):
        groups = {v: k + 1 for k, v in enumerate(groups)}

    label_channels = {
        0: "Neoplastic cells",
        1: "Inflammatory",
        2: "Connective/Soft tissue cells",
        3: "Dead Cells",
        4: "Epithelial",
    }

    print(f"++ Using Groups: {groups}")
    print(f"++ Using Label Channels: {label_channels}")
    #logger.info(f"++ Using Groups: {groups}")
    #logger.info(f"++ Using Label Channels: {label_channels}")

    images = np.load(image)
    labels = np.load(label)
    print(f"Image Shape: {images.shape}")
    print(f"Labels Shape: {labels.shape}")
    #logger.info(f"Image Shape: {images.shape}")
    #logger.info(f"Labels Shape: {labels.shape}")

    images_dir = output_dir
    labels_dir = os.path.join(output_dir, "labels", "final")
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    dataset_json = []
    for i in tqdm(range(images.shape[0])):
        name = f"img_{str(i).zfill(4)}.npy"
        image_file = os.path.join(images_dir, name)
        label_file = os.path.join(labels_dir, name)

        image_np = images[i]
        mask = labels[i]
        label_np = np.zeros(shape=mask.shape[:2])

        # use a distinct loop variable so the file name above is not shadowed
        for idx, channel_name in label_channels.items():
            if idx < mask.shape[2]:
                m = mask[:, :, idx]
                if np.count_nonzero(m):
                    m[m > 0] = groups.get(channel_name, 1)
                    label_np = np.where(m > 0, m, label_np)

        np.save(image_file, image_np)
        np.save(label_file, label_np)
        dataset_json.append({"image": image_file, "label": label_file})

    return dataset_json

def split_nuclei_dataset(d, centroid_key="centroid", mask_value_key="mask_value", min_area=5):
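    """Expand a single {"image", "label"} entry into one entry per nucleus.

    Connected components are computed on the label image; components smaller
    than ``min_area`` are skipped. Each returned item additionally stores the
    component centroid (``centroid_key``) and its label value
    (``mask_value_key``) for the downstream patch-extraction transforms.
    """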
    dataset_json = []

    mask = LoadImage(image_only=True, dtype=np.uint8)(d["label"])
    _, labels, _, _ = cv2.connectedComponentsWithStats(mask, 4, cv2.CV_32S)

    stats = regionprops(labels)
    for stat in stats:
        if stat.area < min_area:
            #logger.debug(f"++++ Ignored label with smaller area => ({stat.area} < {min_area})")
            print(f"++++ Ignored label with smaller area => ({stat.area} < {min_area})")
            continue

        x, y = stat.centroid
        x = int(math.floor(x))
        y = int(math.floor(y))

        item = copy.deepcopy(d)
        item[centroid_key] = (x, y)
        item[mask_value_key] = stat.label

        # logger.info(f"{d['label']} => {len(stats)} => {mask.shape} => {stat.label}")
        dataset_json.append(item)
    return dataset_json

147
+ def main ():
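    """Train a NuClick-style interactive nucleus segmentation model (2D BasicUNet)
    on PanNuke fold-1 data using MONAI's SupervisedTrainer/SupervisedEvaluator.
    """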

    # Paths
    img_data_path = os.path.normpath('/scratch/pan_nuke_data/fold_1/Fold_1/images/fold1/images.npy')
    label_data_path = os.path.normpath('/scratch/pan_nuke_data/fold_1/Fold_1/masks/fold1/masks.npy')
    dataset_path = os.path.normpath('/home/vishwesh/nuclick_experiments/try_1/data')
    json_path = os.path.normpath('/home/vishwesh/nuclick_experiments/try_1/data_list.json')
    logging_dir = os.path.normpath('/home/vishwesh/nuclick_experiments/try_6/')
    groups = [
        "Neoplastic cells",
        "Inflammatory",
        "Connective/Soft tissue cells",
        "Dead Cells",
        "Epithelial",
    ]

    # Hyper-params
    patch_size = 128
    min_area = 5

    # Create Dataset (build once, then reuse the cached JSON data list)
    if not os.path.isfile(json_path):
        dataset_json = split_pannuke_dataset(image=img_data_path,
                                             label=label_data_path,
                                             output_dir=dataset_path,
                                             groups=groups)
        with open(json_path, 'w') as j_file:
            json.dump(dataset_json, j_file)
    else:
        with open(json_path, 'r') as j_file:
            dataset_json = json.load(j_file)

    # Expand to one entry per nucleus
    ds_json_new = []
    for d in tqdm(dataset_json):
        ds_json_new.extend(split_nuclei_dataset(d, min_area=min_area))

    print(f"Total dataset size is {len(ds_json_new)}")
    split_idx = round(len(ds_json_new) * 0.8)
    train_ds_json_new = ds_json_new[:split_idx]
    val_ds_json_new = ds_json_new[split_idx:]
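    # NOTE: this is a simple sequential 80/20 split of the per-nucleus list;
    # the items are not shuffled before splitting.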

    # Transforms
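    # Training pre-transforms: filter/normalize the RGB image, flatten the
    # instance label, crop a patch_size x patch_size window around the nucleus
    # centroid, split the target nucleus from its neighbours ("others"), apply
    # color jitter and random 90-degree rotations, scale intensities to [-1, 1],
    # and append the point-guidance signal channels.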
    train_pre_transforms = Compose(
        [
            LoadImaged(keys=("image", "label"), dtype=np.uint8),
            FilterImaged(keys="image", min_size=5),
            FlattenLabeld(keys="label"),
            AsChannelFirstd(keys="image"),
            AddChanneld(keys="label"),
            ExtractPatchd(keys=("image", "label"), patch_size=patch_size),
            SplitLabeld(label="label", others="others", mask_value="mask_value", min_area=min_area),
            ToTensord(keys="image"),
            TorchVisiond(
                keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04
            ),
            ToNumpyd(keys="image"),
            RandRotate90d(keys=("image", "label", "others"), prob=0.5, spatial_axes=(0, 1)),
            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
            AddPointGuidanceSignald(image="image", label="label", others="others"),
            EnsureTyped(keys=("image", "label")),
        ]
    )

    train_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5),
        ]
    )

    val_transforms = Compose(
        [
            LoadImaged(keys=("image", "label"), dtype=np.uint8),
            FilterImaged(keys="image", min_size=5),
            FlattenLabeld(keys="label"),
            AsChannelFirstd(keys="image"),
            AddChanneld(keys="label"),
            ExtractPatchd(keys=("image", "label"), patch_size=patch_size),
            SplitLabeld(label="label", others="others", mask_value="mask_value", min_area=min_area),
            ToTensord(keys="image"),
            TorchVisiond(
                keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04
            ),
            ToNumpyd(keys="image"),
            RandRotate90d(keys=("image", "label", "others"), prob=0.5, spatial_axes=(0, 1)),
            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
            AddPointGuidanceSignald(image="image", label="label", others="others", drop_rate=1.0),
            EnsureTyped(keys=("image", "label")),
        ]
    )
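    # NOTE: the validation transforms mirror the training ones, except that
    # AddPointGuidanceSignald uses drop_rate=1.0, so the exclusion-point signal
    # from neighbouring nuclei is effectively dropped at validation time.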

    train_key_metric = {"train_dice": MeanDice(include_background=False, output_transform=from_engine(["pred", "label"]))}
    val_key_metric = {"val_dice": MeanDice(include_background=False, output_transform=from_engine(["pred", "label"]))}
    val_inferer = SimpleInferer()

    # Define Dataset & Loading
    train_data_set = Dataset(train_ds_json_new, transform=train_pre_transforms)
    train_data_loader = DataLoader(
        dataset=train_data_set,
        batch_size=32,
        shuffle=True,
        num_workers=2,
    )

    val_data_set = Dataset(val_ds_json_new, transform=val_transforms)
    val_data_loader = DataLoader(
        dataset=val_data_set,
        batch_size=32,
        shuffle=True,
        num_workers=2,
    )

    # Network Definition, Optimizer etc
    device = torch.device("cuda")

    network = BasicUNet(
        spatial_dims=2,
        in_channels=5,
        out_channels=1,
        features=(32, 64, 128, 256, 512, 32),
    )
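    # in_channels=5: 3 RGB channels plus the 2 point-guidance channels
    # (inclusion/exclusion) added by AddPointGuidanceSignald; out_channels=1
    # because the target is a single binary mask of the selected nucleus.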

    network.to(device)
    optimizer = torch.optim.Adam(network.parameters(), 0.0001)
    dice_loss = DiceLoss(sigmoid=True, squared_pred=True)

    # Training Process
    #TODO Consider using the SupervisedTrainer over here from MONAI
    #network.train()
    #TODO Refer here for how to set up validation when using a SupervisedTrainer. In short, a SupervisedEvaluator needs to be created
    # and attached as a training handler
    #TODO https://github.com/Project-MONAI/tutorials/blob/bc342633bd8e50be7b4a67b723006bb03285f6ba/acceleration/distributed_training/unet_training_workflows.py#L187

    val_handlers = [
        # log statistics via a logger named "train_log"
        StatsHandler(name="train_log", output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=logging_dir, output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir=logging_dir,
            batch_transform=from_engine(["image", "label"]),
            output_transform=from_engine(["pred"]),
        ),
        CheckpointSaver(save_dir=logging_dir, save_dict={"network": network}, save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_data_loader,
        network=network,
        inferer=val_inferer,
        postprocessing=train_post_transforms,
        key_val_metric=val_key_metric,
        #additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=False,
    )

    train_handlers = [
        # apply "EarlyStop" logic based on the loss value, use "-" negative value because smaller loss is better
        #EarlyStopHandler(trainer=None, patience=20, score_function=lambda x: -x.state.output[0]["loss"], epoch_level=False),
        #LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        # log statistics via a logger named "train_log"
        StatsHandler(name="train_log",
                     tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
        TensorBoardStatsHandler(log_dir=logging_dir,
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"], first=True)),
        CheckpointSaver(save_dir=logging_dir,
                        save_dict={"net": network, "opt": optimizer},
                        save_interval=1,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=30,
        train_data_loader=train_data_loader,
        network=network,
        optimizer=optimizer,
        loss_function=dice_loss,
        inferer=SimpleInferer(),
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=False,
        postprocessing=train_post_transforms,
        key_train_metric=train_key_metric,
        train_handlers=train_handlers,
    )
    trainer.run()

    # End ...
    return None


if __name__ == "__main__":
    main()