
Commit 0dc8cae

Authored by bhashemian, with pre-commit-ci[bot] and Nic-Ma
Update Multiple Instance Learning Pipeline (#728)
* Update mil pipeline
* Fix a typo
* Tiffile to cucim
* Update to RandGridPatchd
* Fix import
* Add sort_key
* sort key max to min
* Fix patch_size
* Stack patch locations
* Remove location
* Change to fix_num_patches
* Update grid patch
* Add threshold_filter
* Update threshold and add pad_mode=None
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update mil pipeline with new grid patch
* Update GridPatch args
* Update LabelEncodeIntegerGraded and sort imports

Signed-off-by: Behrooz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nic Ma <[email protected]>
1 parent a08b8e0 commit 0dc8cae
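The substance of the change is the move from the `TileOnGridd` tiling transform to MONAI's `RandGridPatchd` (training) and `GridPatchd` (validation). Below is a minimal sketch of the new patch extraction, mirroring the arguments used in the diff; the toy input, `num_patches=16`, and the printed shape are illustrative assumptions, not code from the commit:

import numpy as np
from monai.transforms import RandGridPatchd

# one fake channel-first WSI region, 3 x 1024 x 1024, uint8
sample = {"image": np.random.randint(0, 256, (3, 1024, 1024), dtype=np.uint8)}

patcher = RandGridPatchd(
    keys=["image"],
    patch_size=(256, 256),
    num_patches=16,       # plays the role of the old tile_count
    sort_fn="min",        # prefer patches with the lowest intensity sum, i.e. tissue over white background
    pad_mode=None,        # no padding of partial patches at the border
    constant_values=255,  # fill value, only relevant when padding is enabled
)
out = patcher(sample)
print(out["image"].shape)  # expected: (16, 3, 256, 256), patches stacked along a new first dim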

File tree

1 file changed (+134, -129 lines)

pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py

Lines changed: 134 additions & 129 deletions
@@ -1,91 +1,36 @@
-import os
-import time
-import shutil
 import argparse
 import collections.abc
-import gdown
+import os
+import shutil
+import time
 
+import gdown
 import numpy as np
-from sklearn.metrics import cohen_kappa_score
-
 import torch
-import torch.nn as nn
-from torch.cuda.amp import GradScaler, autocast
-
-from torch.utils.tensorboard import SummaryWriter
-from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.dataloader import default_collate
-
 import torch.distributed as dist
 import torch.multiprocessing as mp
-
+import torch.nn as nn
+from monai.config import KeysCollection
 from monai.data import Dataset, load_decathlon_datalist
-from monai.data.image_reader import WSIReader
+from monai.data.wsi_reader import WSIReader
 from monai.metrics import Cumulative, CumulativeAverage
-from monai.transforms import Transform, Compose, LoadImageD, RandFlipd, RandRotate90d, ScaleIntensityRangeD, ToTensord
-from monai.apps.pathology.transforms import TileOnGridd
 from monai.networks.nets import milmodel
-
-
-def parse_args():
-
-    parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.")
-    parser.add_argument(
-        "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images"
-    )
-    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
-
-    parser.add_argument("--num_classes", default=5, type=int, help="number of output classes")
-    parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm")
-    parser.add_argument(
-        "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image"
-    )
-    parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels")
-
-    parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
-    parser.add_argument(
-        "--validate",
-        action="store_true",
-        help="run only inference on the validation set, must specify the checkpoint argument",
-    )
-
-    parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")
-
-    parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
-    parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
-    parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
-
-    parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
-    parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
-    parser.add_argument(
-        "--val_every",
-        default=1,
-        type=int,
-        help="run validation after this number of epochs, default 1 to run every epoch",
-    )
-    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
-
-    ###for multigpu
-    parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
-    parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
-    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
-    parser.add_argument(
-        "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training"
-    )
-    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
-
-    parser.add_argument(
-        "--quick", action="store_true", help="use a small subset of data for debugging"
-    )  # for debugging
-
-    args = parser.parse_args()
-
-    print("Argument values:")
-    for k, v in vars(args).items():
-        print(k, "=>", v)
-    print("-----------------")
-
-    return args
+from monai.transforms import (
+    Compose,
+    GridPatchd,
+    LoadImaged,
+    MapTransform,
+    RandFlipd,
+    RandGridPatchd,
+    RandRotate90d,
+    ScaleIntensityRanged,
+    ToTensord,
+)
+from sklearn.metrics import cohen_kappa_score
+from torch.cuda.amp import GradScaler, autocast
+from torch.utils.data.dataloader import default_collate
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
 
 
 def train_epoch(model, loader, optimizer, scaler, epoch, args):
@@ -246,22 +191,26 @@ def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0):
     print("Saving checkpoint", filename)
 
 
-class LabelEncodeIntegerGraded(Transform):
+class LabelEncodeIntegerGraded(MapTransform):
     """
     Convert an integer label to encoded array representation of length num_classes,
    with 1 filled in up to label index, and 0 otherwise. For example for num_classes=5,
     embedding of 2 -> (1,1,0,0,0)
 
     Args:
         num_classes: the number of classes to convert to encoded format.
-        keys: keys of the corresponding items to be transformed
-            Defaults to ``['label']``.
+        keys: keys of the corresponding items to be transformed. Defaults to ``'label'``.
+        allow_missing_keys: don't raise exception if key is missing.
 
     """
 
-    def __init__(self, num_classes, keys=["label"]):
-        super().__init__()
-        self.keys = keys
+    def __init__(
+        self,
+        num_classes: int,
+        keys: KeysCollection = "label",
+        allow_missing_keys: bool = False,
+    ):
+        super().__init__(keys, allow_missing_keys)
         self.num_classes = num_classes
 
     def __call__(self, data):
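For reference, the graded encoding the docstring specifies can be reproduced in a couple of lines; this is a toy illustration of the semantics, not code from this file:

import numpy as np

num_classes, label = 5, 2
encoded = (np.arange(num_classes) < label).astype(np.float32)  # ones up to the label index
print(encoded)  # [1. 1. 0. 0. 0.] matches the docstring's embedding of 2 -> (1,1,0,0,0)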
@@ -278,35 +227,12 @@ def __call__(self, data):
         return d
 
 
-def main():
-
-    args = parse_args()
-
-    if args.dataset_json is None:
-        # download default json datalist
-        resource = "https://drive.google.com/uc?id=1L6PtKBlHHyUgTE4rVhRuOLTQKgD4tBRK"
-        dst = "./datalist_panda_0.json"
-        if not os.path.exists(dst):
-            gdown.download(resource, dst, quiet=False)
-        args.dataset_json = dst
-
-    if args.distributed:
-        ngpus_per_node = torch.cuda.device_count()
-        args.optim_lr = ngpus_per_node * args.optim_lr / 2  # heuristic to scale up learning rate in multigpu setup
-        args.world_size = ngpus_per_node * args.world_size
-
-        print("Multigpu", ngpus_per_node, "rescaled lr", args.optim_lr)
-        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,))
-    else:
-        main_worker(0, args)
-
-
 def list_data_collate(batch: collections.abc.Sequence):
-    '''
-    Combine instances from a list of dicts into a single dict, by stacking them along first dim
-    [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW}
-    followed by the default collate which will form a batch BxNx3xHxW
-    '''
+    """
+    Combine instances from a list of dicts into a single dict, by stacking them along first dim
+    [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] -> {'image' : Nx3xHxW}
+    followed by the default collate which will form a batch BxNx3xHxW
+    """
 
     for i, item in enumerate(batch):
         data = item[0]
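A toy run of the stack-then-collate behavior the docstring describes (the shapes are illustrative assumptions, not values from the script):

import torch
from torch.utils.data.dataloader import default_collate

# N=4 patch dicts from one WSI, each 3xHxW
patches = [{"image": torch.zeros(3, 256, 256)} for _ in range(4)]
bag = {"image": torch.stack([p["image"] for p in patches])}  # {'image': 4x3x256x256}

batch = default_collate([bag, bag])  # B=2 bags
print(batch["image"].shape)  # torch.Size([2, 4, 3, 256, 256]), i.e. BxNx3xHxW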
@@ -352,37 +278,36 @@ def main_worker(gpu, args):
 
     train_transform = Compose(
         [
-            LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True),
+            LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True),
             LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
-            TileOnGridd(
+            RandGridPatchd(
                 keys=["image"],
-                tile_count=args.tile_count,
-                tile_size=args.tile_size,
-                random_offset=True,
-                background_val=255,
-                return_list_of_dicts=True,
+                patch_size=(args.tile_size, args.tile_size),
+                num_patches=args.tile_count,
+                sort_fn="min",
+                pad_mode=None,
+                constant_values=255,
             ),
             RandFlipd(keys=["image"], spatial_axis=0, prob=0.5),
             RandFlipd(keys=["image"], spatial_axis=1, prob=0.5),
             RandRotate90d(keys=["image"], prob=0.5),
-            ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+            ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
             ToTensord(keys=["image", "label"]),
         ]
     )
 
     valid_transform = Compose(
         [
-            LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True),
+            LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True),
             LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
-            TileOnGridd(
+            GridPatchd(
                 keys=["image"],
-                tile_count=None,
-                tile_size=args.tile_size,
-                random_offset=False,
-                background_val=255,
-                return_list_of_dicts=True,
+                patch_size=(args.tile_size, args.tile_size),
+                threshold=0.999 * 3 * 255 * args.tile_size * args.tile_size,
+                pad_mode=None,
+                constant_values=255,
             ),
-            ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+            ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
             ToTensord(keys=["image", "label"]),
         ]
     )
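The one non-obvious constant above is the `GridPatchd` threshold. On one reading of the filter, it keeps only patches whose total intensity falls below 99.9% of an all-white patch, so near-empty background tiles are discarded at validation time; the arithmetic, worked for the script's default tile size:

# sum of a pure-white (255) RGB patch of side tile_size
tile_size = 256  # the --tile_size default
white_sum = 3 * 255 * tile_size * tile_size
threshold = 0.999 * white_sum
print(f"{threshold:,.0f}")  # ~50,084,905; patches summing above this are dropped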
@@ -540,5 +465,85 @@ def main_worker(gpu, args):
     print("ALL DONE")
 
 
+def parse_args():
+
+    parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.")
+    parser.add_argument(
+        "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images"
+    )
+    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
+
+    parser.add_argument("--num_classes", default=5, type=int, help="number of output classes")
+    parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm")
+    parser.add_argument(
+        "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image"
+    )
+    parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels")
+
+    parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
+    parser.add_argument(
+        "--validate",
+        action="store_true",
+        help="run only inference on the validation set, must specify the checkpoint argument",
+    )
+
+    parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")
+
+    parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
+    parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
+    parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
+
+    parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
+    parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
+    parser.add_argument(
+        "--val_every",
+        default=1,
+        type=int,
+        help="run validation after this number of epochs, default 1 to run every epoch",
+    )
+    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
+
+    ###for multigpu
+    parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
+    parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
+    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
+    parser.add_argument(
+        "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training"
+    )
+    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
+
+    parser.add_argument(
+        "--quick", action="store_true", help="use a small subset of data for debugging"
+    )  # for debugging
+
+    args = parser.parse_args()
+
+    print("Argument values:")
+    for k, v in vars(args).items():
+        print(k, "=>", v)
+    print("-----------------")
+
+    return args
+
+
 if __name__ == "__main__":
-    main()
+
+    args = parse_args()
+
+    if args.dataset_json is None:
+        # download default json datalist
+        resource = "https://drive.google.com/uc?id=1L6PtKBlHHyUgTE4rVhRuOLTQKgD4tBRK"
+        dst = "./datalist_panda_0.json"
+        if not os.path.exists(dst):
+            gdown.download(resource, dst, quiet=False)
+        args.dataset_json = dst
+
+    if args.distributed:
+        ngpus_per_node = torch.cuda.device_count()
+        args.optim_lr = ngpus_per_node * args.optim_lr / 2  # heuristic to scale up learning rate in multigpu setup
+        args.world_size = ngpus_per_node * args.world_size
+
+        print("Multigpu", ngpus_per_node, "rescaled lr", args.optim_lr)
+        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,))
+    else:
+        main_worker(0, args)
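Finally, the learning-rate heuristic in the relocated `__main__` block, worked through for a hypothetical 8-GPU node (the GPU count is an assumption for illustration):

base_lr = 3e-5  # the --optim_lr default
ngpus_per_node = 8
scaled_lr = ngpus_per_node * base_lr / 2
print(f"{scaled_lr:g}")  # 0.00012, four times the single-GPU rate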
