
Update Multiple Instance Learning Pipeline #728


Merged: 27 commits, Jul 20, 2022

Commits
6eee290
Update mil pipeline
bhashemian May 23, 2022
5325295
Fix a typo
bhashemian May 23, 2022
afb66f7
Tiffile to cucim
bhashemian May 23, 2022
de9f7eb
Update to RandGridPatchd
bhashemian May 23, 2022
a54a17d
Fix import
bhashemian May 23, 2022
554ed71
Add sort_key
bhashemian May 23, 2022
0724dfe
sort key max to min
bhashemian May 23, 2022
7f34968
Fix patch_size
bhashemian May 23, 2022
a99fe58
Stack patch locations
bhashemian May 23, 2022
633f4ec
Merge branch 'main' of github.com:Project-MONAI/tutorials into update…
bhashemian May 26, 2022
8116255
Remove location
bhashemian May 27, 2022
a07e121
Change to fix_num_patches
bhashemian May 27, 2022
9b66a93
Merge branch 'main' of github.com:Project-MONAI/tutorials into update…
bhashemian May 27, 2022
7bd2ecb
Merge branch 'main' of github.com:Project-MONAI/tutorials into update…
bhashemian May 31, 2022
975edc4
Update grid patch
bhashemian May 31, 2022
e21b81d
Add threshold_filter
bhashemian Jun 2, 2022
10ef5fa
Update threshold and add pad_mode=None
bhashemian Jun 2, 2022
ecb1ba1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2022
eeaaf43
Merge branch 'main' into update-mil
bhashemian Jun 2, 2022
f7a956d
Merge branch 'main' into update-mil
bhashemian Jun 6, 2022
a774973
Update mil pipeline with new grid patch
bhashemian Jun 7, 2022
3419580
Update GridPatch args
bhashemian Jun 9, 2022
01a6698
Merge branch 'main' into update-mil
bhashemian Jul 7, 2022
1a45305
Update LabelEncodeIntegerGraded and sort imports
bhashemian Jul 8, 2022
eee3a74
Merge branch 'main' into update-mil
bhashemian Jul 11, 2022
3e7cdd8
Merge branch 'main' into update-mil
bhashemian Jul 18, 2022
7bbd4d3
Merge branch 'main' into update-mil
Nic-Ma Jul 20, 2022
Changes from all commits
@@ -1,91 +1,36 @@
-import os
-import time
-import shutil
 import argparse
 import collections.abc
-import gdown
+import os
+import shutil
+import time
+
+import gdown
 import numpy as np
-from sklearn.metrics import cohen_kappa_score
-
 import torch
-import torch.nn as nn
-from torch.cuda.amp import GradScaler, autocast
-
-from torch.utils.tensorboard import SummaryWriter
-from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.dataloader import default_collate
-
 import torch.distributed as dist
 import torch.multiprocessing as mp
-
+import torch.nn as nn
+from monai.config import KeysCollection
 from monai.data import Dataset, load_decathlon_datalist
-from monai.data.image_reader import WSIReader
+from monai.data.wsi_reader import WSIReader
 from monai.metrics import Cumulative, CumulativeAverage
-from monai.transforms import Transform, Compose, LoadImageD, RandFlipd, RandRotate90d, ScaleIntensityRangeD, ToTensord
-from monai.apps.pathology.transforms import TileOnGridd
 from monai.networks.nets import milmodel
-
-
-def parse_args():
-
-    parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.")
-    parser.add_argument(
-        "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images"
-    )
-    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
-
-    parser.add_argument("--num_classes", default=5, type=int, help="number of output classes")
-    parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm")
-    parser.add_argument(
-        "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image"
-    )
-    parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels")
-
-    parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
-    parser.add_argument(
-        "--validate",
-        action="store_true",
-        help="run only inference on the validation set, must specify the checkpoint argument",
-    )
-
-    parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")
-
-    parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
-    parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
-    parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
-
-    parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
-    parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
-    parser.add_argument(
-        "--val_every",
-        default=1,
-        type=int,
-        help="run validation after this number of epochs, default 1 to run every epoch",
-    )
-    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
-
-    ###for multigpu
-    parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
-    parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
-    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
-    parser.add_argument(
-        "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training"
-    )
-    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
-
-    parser.add_argument(
-        "--quick", action="store_true", help="use a small subset of data for debugging"
-    )  # for debugging
-
-    args = parser.parse_args()
-
-    print("Argument values:")
-    for k, v in vars(args).items():
-        print(k, "=>", v)
-    print("-----------------")
-
-    return args
+from monai.transforms import (
+    Compose,
+    GridPatchd,
+    LoadImaged,
+    MapTransform,
+    RandFlipd,
+    RandGridPatchd,
+    RandRotate90d,
+    ScaleIntensityRanged,
+    ToTensord,
+)
+from sklearn.metrics import cohen_kappa_score
+from torch.cuda.amp import GradScaler, autocast
+from torch.utils.data.dataloader import default_collate
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
 
 
 def train_epoch(model, loader, optimizer, scaler, epoch, args):
@@ -246,22 +191,26 @@ def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0):
     print("Saving checkpoint", filename)
 
 
-class LabelEncodeIntegerGraded(Transform):
+class LabelEncodeIntegerGraded(MapTransform):
     """
     Convert an integer label to encoded array representation of length num_classes,
     with 1 filled in up to label index, and 0 otherwise. For example for num_classes=5,
     embedding of 2 -> (1,1,0,0,0)
 
     Args:
         num_classes: the number of classes to convert to encoded format.
-        keys: keys of the corresponding items to be transformed
-            Defaults to ``['label']``.
+        keys: keys of the corresponding items to be transformed. Defaults to ``'label'``.
+        allow_missing_keys: don't raise exception if key is missing.
 
     """
 
-    def __init__(self, num_classes, keys=["label"]):
-        super().__init__()
-        self.keys = keys
+    def __init__(
+        self,
+        num_classes: int,
+        keys: KeysCollection = "label",
+        allow_missing_keys: bool = False,
+    ):
+        super().__init__(keys, allow_missing_keys)
         self.num_classes = num_classes
 
     def __call__(self, data):
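
As a quick illustration of the graded encoding described in the docstring (a standalone sketch; the helper name is ours, and the class's actual __call__ body is collapsed in the diff above):

import numpy as np

def encode_integer_graded(label, num_classes):
    # num_classes=5, label=2 -> [1, 1, 0, 0, 0]
    encoded = np.zeros(num_classes, dtype=np.float32)
    encoded[: int(label)] = 1.0
    return encoded

print(encode_integer_graded(2, 5))  # [1. 1. 0. 0. 0.]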
@@ -278,35 +227,12 @@ def __call__(self, data):
         return d
 
 
-def main():
-
-    args = parse_args()
-
-    if args.dataset_json is None:
-        # download default json datalist
-        resource = "https://drive.google.com/uc?id=1L6PtKBlHHyUgTE4rVhRuOLTQKgD4tBRK"
-        dst = "./datalist_panda_0.json"
-        if not os.path.exists(dst):
-            gdown.download(resource, dst, quiet=False)
-        args.dataset_json = dst
-
-    if args.distributed:
-        ngpus_per_node = torch.cuda.device_count()
-        args.optim_lr = ngpus_per_node * args.optim_lr / 2  # heuristic to scale up learning rate in multigpu setup
-        args.world_size = ngpus_per_node * args.world_size
-
-        print("Multigpu", ngpus_per_node, "rescaled lr", args.optim_lr)
-        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,))
-    else:
-        main_worker(0, args)
-
-
 def list_data_collate(batch: collections.abc.Sequence):
-    '''
-    Combine instances from a list of dicts into a single dict, by stacking them along first dim
-    [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW}
-    followed by the default collate which will form a batch BxNx3xHxW
-    '''
+    """
+    Combine instances from a list of dicts into a single dict, by stacking them along first dim
+    [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW}
+    followed by the default collate which will form a batch BxNx3xHxW
+    """
 
     for i, item in enumerate(batch):
         data = item[0]
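
A toy example of the stacking that list_data_collate's docstring describes (shapes assumed for illustration):

import torch
from torch.utils.data.dataloader import default_collate

# N=3 patch dicts from one slide, each 3xHxW, stacked into one Nx3xHxW bag
patches = [{"image": torch.zeros(3, 64, 64)} for _ in range(3)]
bag = {"image": torch.stack([p["image"] for p in patches])}  # {'image': 3x3x64x64}
batch = default_collate([bag, bag])  # default collate over B=2 bags
print(batch["image"].shape)  # torch.Size([2, 3, 3, 64, 64])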
@@ -352,37 +278,36 @@ def main_worker(gpu, args):
 
     train_transform = Compose(
         [
-            LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True),
+            LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True),
             LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
-            TileOnGridd(
+            RandGridPatchd(
                 keys=["image"],
-                tile_count=args.tile_count,
-                tile_size=args.tile_size,
-                random_offset=True,
-                background_val=255,
-                return_list_of_dicts=True,
+                patch_size=(args.tile_size, args.tile_size),
+                num_patches=args.tile_count,
+                sort_fn="min",
+                pad_mode=None,
+                constant_values=255,
             ),
             RandFlipd(keys=["image"], spatial_axis=0, prob=0.5),
             RandFlipd(keys=["image"], spatial_axis=1, prob=0.5),
             RandRotate90d(keys=["image"], prob=0.5),
-            ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+            ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
             ToTensord(keys=["image", "label"]),
         ]
     )
 
     valid_transform = Compose(
         [
-            LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True),
+            LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True),
             LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
-            TileOnGridd(
+            GridPatchd(
                 keys=["image"],
-                tile_count=None,
-                tile_size=args.tile_size,
-                random_offset=False,
-                background_val=255,
-                return_list_of_dicts=True,
+                patch_size=(args.tile_size, args.tile_size),
+                threshold=0.999 * 3 * 255 * args.tile_size * args.tile_size,
+                pad_mode=None,
+                constant_values=255,
             ),
-            ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
+            ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)),
             ToTensord(keys=["image", "label"]),
         ]
     )
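
Two of the new patch-sampling arguments deserve a note. In RandGridPatchd, sort_fn="min" together with num_patches keeps the tile_count patches with the smallest intensity sums, i.e. the tiles with the most tissue on a bright background. In GridPatchd, threshold keeps only patches whose intensity sum falls below the given value (per our reading of the MONAI docs), so 0.999 * 3 * 255 * tile_size**2 sits just under the sum of an all-white uint8 RGB tile and rejects pure background. Worked out with the script's default tile size:

tile_size = 256                                      # --tile_size default
white_sum = 3 * 255 * tile_size * tile_size          # all-white RGB tile: 50,135,040
threshold = 0.999 * 3 * 255 * tile_size * tile_size  # ~50,084,905
print(white_sum < threshold)  # False -> an all-white tile is filtered out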
@@ -540,5 +465,85 @@ def main_worker(gpu, args):
     print("ALL DONE")
 
 
+def parse_args():
+
+    parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.")
+    parser.add_argument(
+        "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images"
+    )
+    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
+
+    parser.add_argument("--num_classes", default=5, type=int, help="number of output classes")
+    parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm")
+    parser.add_argument(
+        "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image"
+    )
+    parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels")
+
+    parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
+    parser.add_argument(
+        "--validate",
+        action="store_true",
+        help="run only inference on the validation set, must specify the checkpoint argument",
+    )
+
+    parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")
+
+    parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
+    parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
+    parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
+
+    parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
+    parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
+    parser.add_argument(
+        "--val_every",
+        default=1,
+        type=int,
+        help="run validation after this number of epochs, default 1 to run every epoch",
+    )
+    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
+
+    ###for multigpu
+    parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
+    parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
+    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
+    parser.add_argument(
+        "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training"
+    )
+    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
+
+    parser.add_argument(
+        "--quick", action="store_true", help="use a small subset of data for debugging"
+    )  # for debugging
+
+    args = parser.parse_args()
+
+    print("Argument values:")
+    for k, v in vars(args).items():
+        print(k, "=>", v)
+    print("-----------------")
+
+    return args
+
+
 if __name__ == "__main__":
-    main()
+
+    args = parse_args()
+
+    if args.dataset_json is None:
+        # download default json datalist
+        resource = "https://drive.google.com/uc?id=1L6PtKBlHHyUgTE4rVhRuOLTQKgD4tBRK"
+        dst = "./datalist_panda_0.json"
+        if not os.path.exists(dst):
+            gdown.download(resource, dst, quiet=False)
+        args.dataset_json = dst
+
+    if args.distributed:
+        ngpus_per_node = torch.cuda.device_count()
+        args.optim_lr = ngpus_per_node * args.optim_lr / 2  # heuristic to scale up learning rate in multigpu setup
+        args.world_size = ngpus_per_node * args.world_size
+
+        print("Multigpu", ngpus_per_node, "rescaled lr", args.optim_lr)
+        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,))
+    else:
+        main_worker(0, args)
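
The learning-rate rescaling in the distributed branch above is a heuristic midway between linear scaling and no scaling. With the script's default --optim_lr and an assumed 4-GPU node:

ngpus_per_node = 4            # assumed GPU count for the example
base_lr = 3e-5                # --optim_lr default
scaled_lr = ngpus_per_node * base_lr / 2
print(scaled_lr)  # 6e-05 (the linear-scaling rule would give 1.2e-04)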