- import os
- import time
- import shutil
import argparse
import collections.abc
- import gdown
+ import os
+ import shutil
+ import time

+ import gdown
import numpy as np
- from sklearn.metrics import cohen_kappa_score
-
import torch
- import torch.nn as nn
- from torch.cuda.amp import GradScaler, autocast
-
- from torch.utils.tensorboard import SummaryWriter
- from torch.utils.data.distributed import DistributedSampler
- from torch.utils.data.dataloader import default_collate
-
import torch.distributed as dist
import torch.multiprocessing as mp
-
+ import torch.nn as nn
+ from monai.config import KeysCollection
from monai.data import Dataset, load_decathlon_datalist
- from monai.data.image_reader import WSIReader
+ from monai.data.wsi_reader import WSIReader
from monai.metrics import Cumulative, CumulativeAverage
- from monai.transforms import Transform, Compose, LoadImageD, RandFlipd, RandRotate90d, ScaleIntensityRangeD, ToTensord
- from monai.apps.pathology.transforms import TileOnGridd
from monai.networks.nets import milmodel
-
-
- def parse_args():
-
-     parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.")
-     parser.add_argument(
-         "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images"
-     )
-     parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
-
-     parser.add_argument("--num_classes", default=5, type=int, help="number of output classes")
-     parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm")
-     parser.add_argument(
-         "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image"
-     )
-     parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels")
-
-     parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
-     parser.add_argument(
-         "--validate",
-         action="store_true",
-         help="run only inference on the validation set, must specify the checkpoint argument",
-     )
-
-     parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")
-
-     parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
-     parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
-     parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
-
-     parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
-     parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
-     parser.add_argument(
-         "--val_every",
-         default=1,
-         type=int,
-         help="run validation after this number of epochs, default 1 to run every epoch",
-     )
-     parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
-
-     ### for multigpu
-     parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
-     parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
-     parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
-     parser.add_argument(
-         "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training"
-     )
-     parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
-
-     parser.add_argument(
-         "--quick", action="store_true", help="use a small subset of data for debugging"
-     )  # for debugging
-
-     args = parser.parse_args()
-
-     print("Argument values:")
-     for k, v in vars(args).items():
-         print(k, "=>", v)
-     print("-----------------")
-
-     return args
+ from monai.transforms import (
+     Compose,
+     GridPatchd,
+     LoadImaged,
+     MapTransform,
+     RandFlipd,
+     RandGridPatchd,
+     RandRotate90d,
+     ScaleIntensityRanged,
+     ToTensord,
+ )
+ from sklearn.metrics import cohen_kappa_score
+ from torch.cuda.amp import GradScaler, autocast
+ from torch.utils.data.dataloader import default_collate
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.utils.tensorboard import SummaryWriter


def train_epoch(model, loader, optimizer, scaler, epoch, args):
@@ -246,22 +191,26 @@ def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0):
    print("Saving checkpoint", filename)


- class LabelEncodeIntegerGraded(Transform):
+ class LabelEncodeIntegerGraded(MapTransform):
    """
    Convert an integer label to encoded array representation of length num_classes,
    with 1 filled in up to label index, and 0 otherwise. For example for num_classes=5,
    embedding of 2 -> (1,1,0,0,0)

    Args:
        num_classes: the number of classes to convert to encoded format.
-         keys: keys of the corresponding items to be transformed
-             Defaults to ``['label']``.
+         keys: keys of the corresponding items to be transformed. Defaults to ``'label'``.
+         allow_missing_keys: don't raise an exception if a key is missing.

    """

-     def __init__(self, num_classes, keys=["label"]):
-         super().__init__()
-         self.keys = keys
+     def __init__(
+         self,
+         num_classes: int,
+         keys: KeysCollection = "label",
+         allow_missing_keys: bool = False,
+     ):
+         super().__init__(keys, allow_missing_keys)
        self.num_classes = num_classes

    def __call__(self, data):
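To see concretely what the docstring promises, here is a minimal standalone sketch of the encoding (illustrative only; `encode_graded` is a hypothetical helper, and the class's actual `__call__` body is elided from this hunk):

    import numpy as np

    def encode_graded(label: int, num_classes: int) -> np.ndarray:
        # cumulative one-hot: 1s up to the label index, 0s after
        lz = np.zeros(num_classes, dtype=np.float32)
        lz[:label] = 1.0
        return lz

    print(encode_graded(2, 5))  # [1. 1. 0. 0. 0.]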
@@ -278,35 +227,12 @@ def __call__(self, data):
        return d


- def main():
-
-     args = parse_args()
-
-     if args.dataset_json is None:
-         # download default json datalist
-         resource = "https://drive.google.com/uc?id=1L6PtKBlHHyUgTE4rVhRuOLTQKgD4tBRK"
-         dst = "./datalist_panda_0.json"
-         if not os.path.exists(dst):
-             gdown.download(resource, dst, quiet=False)
-         args.dataset_json = dst
-
-     if args.distributed:
-         ngpus_per_node = torch.cuda.device_count()
-         args.optim_lr = ngpus_per_node * args.optim_lr / 2  # heuristic to scale up learning rate in multigpu setup
-         args.world_size = ngpus_per_node * args.world_size
-
-         print("Multigpu", ngpus_per_node, "rescaled lr", args.optim_lr)
-         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,))
-     else:
-         main_worker(0, args)
-
-
def list_data_collate(batch: collections.abc.Sequence):
-     '''
-     Combine instances from a list of dicts into a single dict, by stacking them along first dim
-     [{'image': 3xHxW}, {'image': 3xHxW}, {'image': 3xHxW}, ...] -> {'image': Nx3xHxW}
-     followed by the default collate which will form a batch BxNx3xHxW
-     '''
+     """
+     Combine instances from a list of dicts into a single dict, by stacking them along first dim
+     [{'image': 3xHxW}, {'image': 3xHxW}, {'image': 3xHxW}, ...] -> {'image': Nx3xHxW}
+     followed by the default collate which will form a batch BxNx3xHxW
+     """

    for i, item in enumerate(batch):
        data = item[0]
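A standalone sketch of the stacking described in the docstring (patch counts and shapes are illustrative, not taken from the PR):

    import torch
    from torch.utils.data.dataloader import default_collate

    # two WSIs, each yielding a list of patch dicts with 3x256x256 images
    wsi_a = [{"image": torch.zeros(3, 256, 256)} for _ in range(44)]
    wsi_b = [{"image": torch.zeros(3, 256, 256)} for _ in range(44)]

    # stack each WSI's patches along a new first dim: N x 3 x H x W
    stacked = [{"image": torch.stack([d["image"] for d in wsi])} for wsi in (wsi_a, wsi_b)]

    # default_collate then batches the WSIs: B x N x 3 x H x W
    print(default_collate(stacked)["image"].shape)  # torch.Size([2, 44, 3, 256, 256])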
@@ -352,37 +278,36 @@ def main_worker(gpu, args):

    train_transform = Compose(
        [
-             LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True),
+             LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True),
            LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
-             TileOnGridd(
+             RandGridPatchd(
                keys=["image"],
-                 tile_count=args.tile_count,
-                 tile_size=args.tile_size,
-                 random_offset=True,
-                 background_val=255,
-                 return_list_of_dicts=True,
+                 patch_size=(args.tile_size, args.tile_size),
+                 num_patches=args.tile_count,
+                 sort_fn="min",
+                 pad_mode=None,
+                 constant_values=255,
            ),
            RandFlipd(keys=["image"], spatial_axis=0, prob=0.5),
            RandFlipd(keys=["image"], spatial_axis=1, prob=0.5),
            RandRotate90d(keys=["image"], prob=0.5),
-             ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+             ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
            ToTensord(keys=["image", "label"]),
        ]
    )
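Here `RandGridPatchd` stands in for the removed `TileOnGridd`: `num_patches` plays the role of `tile_count`, and `sort_fn="min"` keeps the patches with the smallest intensity sums, i.e. the darkest, most tissue-dense tiles on a white-background slide, replacing the old `background_val=255` filtering. A quick sanity check on a dummy array (shapes assumed, not from the PR):

    import numpy as np
    from monai.transforms import RandGridPatchd

    data = {"image": np.random.randint(0, 256, (3, 1024, 1024), dtype=np.uint8)}
    patcher = RandGridPatchd(
        keys=["image"],
        patch_size=(256, 256),
        num_patches=4,
        sort_fn="min",  # darkest patches first, so white background tiles are dropped
        pad_mode=None,
    )
    print(patcher(data)["image"].shape)  # expected: (4, 3, 256, 256)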

    valid_transform = Compose(
        [
-             LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True),
+             LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True),
            LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
-             TileOnGridd(
+             GridPatchd(
                keys=["image"],
-                 tile_count=None,
-                 tile_size=args.tile_size,
-                 random_offset=False,
-                 background_val=255,
-                 return_list_of_dicts=True,
+                 patch_size=(args.tile_size, args.tile_size),
+                 threshold=0.999 * 3 * 255 * args.tile_size * args.tile_size,
+                 pad_mode=None,
+                 constant_values=255,
            ),
-             ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+             ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
            ToTensord(keys=["image", "label"]),
        ]
    )
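In the validation path the `threshold` argument reproduces the old background filtering arithmetically: an all-white RGB tile sums to 3 * 255 * tile_size^2, so any patch whose total intensity reaches 99.9% of that value is treated as empty background and discarded. For the default tile_size=256:

    tile_size = 256
    white_sum = 3 * 255 * tile_size * tile_size  # all-white patch: 50_135_040
    threshold = 0.999 * white_sum                # ~50_084_905; patches at or above are dropped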
@@ -540,5 +465,85 @@ def main_worker(gpu, args):
    print("ALL DONE")


+ def parse_args():
+
+     parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.")
+     parser.add_argument(
+         "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images"
+     )
+     parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
+
+     parser.add_argument("--num_classes", default=5, type=int, help="number of output classes")
+     parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm")
+     parser.add_argument(
+         "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image"
+     )
+     parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels")
+
+     parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
+     parser.add_argument(
+         "--validate",
+         action="store_true",
+         help="run only inference on the validation set, must specify the checkpoint argument",
+     )
+
+     parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")
+
+     parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
+     parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
+     parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
+
+     parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
+     parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
+     parser.add_argument(
+         "--val_every",
+         default=1,
+         type=int,
+         help="run validation after this number of epochs, default 1 to run every epoch",
+     )
+     parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
+
+     ### for multigpu
+     parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
+     parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
+     parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
+     parser.add_argument(
+         "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training"
+     )
+     parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
+
+     parser.add_argument(
+         "--quick", action="store_true", help="use a small subset of data for debugging"
+     )  # for debugging
+
+     args = parser.parse_args()
+
+     print("Argument values:")
+     for k, v in vars(args).items():
+         print(k, "=>", v)
+     print("-----------------")
+
+     return args
+
+
if __name__ == "__main__":
-     main()
+
+     args = parse_args()
+
+     if args.dataset_json is None:
+         # download default json datalist
+         resource = "https://drive.google.com/uc?id=1L6PtKBlHHyUgTE4rVhRuOLTQKgD4tBRK"
+         dst = "./datalist_panda_0.json"
+         if not os.path.exists(dst):
+             gdown.download(resource, dst, quiet=False)
+         args.dataset_json = dst
+
+     if args.distributed:
+         ngpus_per_node = torch.cuda.device_count()
+         args.optim_lr = ngpus_per_node * args.optim_lr / 2  # heuristic to scale up learning rate in multigpu setup
+         args.world_size = ngpus_per_node * args.world_size
+
+         print("Multigpu", ngpus_per_node, "rescaled lr", args.optim_lr)
+         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,))
+     else:
+         main_worker(0, args)
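For reference, the multi-GPU branch above rescales the base learning rate before spawning one worker per GPU; with the parser default --optim_lr=3e-5 and an assumed 8-GPU node:

    ngpus_per_node = 8                   # assumed single node with 8 GPUs
    base_lr = 3e-5                       # --optim_lr default from parse_args
    print(ngpus_per_node * base_lr / 2)  # 0.00012, the "rescaled lr" that gets printed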