mil tutorial improvements #1069

Merged: 5 commits, Nov 28, 2022
README.md (7 changes: 5 additions & 2 deletions)

@@ -166,10 +166,13 @@ Reference implementation used in MICCAI 2022 [ACR-NVIDIA-NCI Breast Density FL c

**Digital Pathology**
#### [Whole Slide Tumor Detection](./pathology/tumor_detection)
-The example show how to train and evaluate a tumor detection model (based on patch classification) on whole-slide histopathology images.
+The example shows how to train and evaluate a tumor detection model (based on patch classification) on whole-slide histopathology images.

#### [Profiling Whole Slide Tumor Detection](./pathology/tumor_detection)
-The example show how to use MONAI NVTX transforms to tag and profile pre- and post-processing transforms in the digital pathology whole slide tumor detection pipeline.
+The example shows how to use MONAI NVTX transforms to tag and profile pre- and post-processing transforms in the digital pathology whole slide tumor detection pipeline.
+
+#### [Multiple Instance Learning WSI classification](./pathology/multiple_instance_learning)
+An example of Multiple Instance Learning (MIL) classification from Whole Slide Images (WSI) of prostate histopathology.

#### [NuClick: Interactive Annotation for Pathology](./pathology/nuclick)
The notebook demonstrates training and inference pipelines with interactive annotation for pathology; NuClick is used for delineating nuclei and cells, with a squiggle for outlining glands.
pathology/multiple_instance_learning/README.md (3 changes: 1 addition & 2 deletions)

@@ -27,8 +27,7 @@ The script is tested with:
Please install the required dependencies

```bash
-pip install tifffile
-pip install imagecodecs
+pip install cucim gdown
```

For more information please check out [the installation guide](https://docs.monai.io/en/latest/installation.html).
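The dependency swap follows the tutorial's data path: whole-slide images are read through MONAI's cuCIM backend rather than tifffile/imagecodecs, and gdown is used to download the dataset from Google Drive. A minimal sketch to verify the cuCIM-backed reader works; the slide path and patch geometry below are placeholders, not files shipped with this tutorial:

```python
# Minimal check of the cuCIM-backed WSI reader (placeholder slide path).
from monai.data import WSIReader

reader = WSIReader(backend="cucim")
wsi = reader.read("slide.tiff")  # hypothetical whole-slide image file
patch, meta = reader.get_data(wsi, location=(0, 0), size=(256, 256), level=0)
print(patch.shape)  # channel-first patch, e.g. (3, 256, 256)
```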
Binary file modified pathology/multiple_instance_learning/mil_train_loss.png
Binary file modified pathology/multiple_instance_learning/mil_val_loss.png
Binary file modified pathology/multiple_instance_learning/mil_val_qwk.png
@@ -48,21 +48,18 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):

    for idx, batch_data in enumerate(loader):

-        data, target = batch_data["image"].cuda(args.rank), batch_data["label"].cuda(args.rank)
+        data = batch_data["image"].as_subclass(torch.Tensor).cuda(args.rank)
+        target = batch_data["label"].as_subclass(torch.Tensor).cuda(args.rank)

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=args.amp):
            logits = model(data)
            loss = criterion(logits, target)

-        if args.amp:
-            scaler.scale(loss).backward()
-            scaler.step(optimizer)
-            scaler.update()
-        else:
-            loss.backward()
-            optimizer.step()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()

        acc = (logits.sigmoid().sum(1).detach().round() == target.sum(1).round()).float().mean()
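Recent MONAI transforms return `MetaTensor` objects that carry image metadata; `as_subclass(torch.Tensor)` casts them back to plain tensors before the GPU transfer, so the training loop does not drag metadata bookkeeping through every op. A minimal sketch of what the cast does, assuming MONAI 0.9+ where `MetaTensor` is available:

```python
import torch
from monai.data import MetaTensor  # available in MONAI 0.9+

x = MetaTensor(torch.randn(2, 3, 64, 64))  # stand-in for a loaded batch
t = x.as_subclass(torch.Tensor)            # same storage, metadata dropped
print(type(x).__name__, "->", type(t).__name__)  # MetaTensor -> Tensor
```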

@@ -108,7 +105,8 @@ def val_epoch(model, loader, epoch, args, max_tiles=None):

    for idx, batch_data in enumerate(loader):

-        data, target = batch_data["image"].cuda(args.rank), batch_data["label"].cuda(args.rank)
+        data = batch_data["image"].as_subclass(torch.Tensor).cuda(args.rank)
+        target = batch_data["label"].as_subclass(torch.Tensor).cuda(args.rank)

        with autocast(enabled=args.amp):

@@ -122,7 +120,7 @@ def val_epoch(model, loader, epoch, args, max_tiles=None):
            logits2 = []

            for i in range(int(np.ceil(data.shape[1] / float(max_tiles)))):
-                data_slice = data[:, i * max_tiles : (i + 1) * max_tiles]
+                data_slice = data[:, i * max_tiles: (i + 1) * max_tiles]
                logits_slice = model(data_slice, no_head=True)
                logits.append(logits_slice)
@@ -329,7 +327,7 @@ def main_worker(gpu, args):
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
-        multiprocessing_context="spawn",
+        multiprocessing_context="spawn" if args.workers > 0 else None,
        sampler=train_sampler,
        collate_fn=list_data_collate,
    )
@@ -339,7 +337,7 @@ def main_worker(gpu, args):
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
-        multiprocessing_context="spawn",
+        multiprocessing_context="spawn" if args.workers > 0 else None,
        sampler=val_sampler,
        collate_fn=list_data_collate,
    )
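The guard matters because PyTorch's DataLoader rejects a `multiprocessing_context` when `num_workers` is 0 (single-process loading), so the ternary keeps `--workers 0` usable for debugging. A self-contained sketch of the pattern, with a placeholder dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

num_workers = 0                             # e.g. taken from --workers
dataset = TensorDataset(torch.zeros(8, 3))  # placeholder dataset

# "spawn" is only legal with worker processes; pass None when loading in-process
loader = DataLoader(
    dataset,
    num_workers=num_workers,
    multiprocessing_context="spawn" if num_workers > 0 else None,
)
```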
@@ -399,13 +397,11 @@ def main_worker(gpu, args):
    else:
        writer = None

-    ###RUN TRAINING
+    #RUN TRAINING
    n_epochs = args.epochs
    val_acc_max = 0.0

-    scaler = None
-    if args.amp: # new native amp
-        scaler = GradScaler()
+    scaler = GradScaler(enabled=args.amp)

    for epoch in range(start_epoch, n_epochs):
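`GradScaler(enabled=args.amp)` replaces the old two-branch update because a disabled scaler degrades gracefully: `scale()` returns the loss unchanged, `step()` reduces to `optimizer.step()`, and `update()` is a no-op, so one code path serves AMP and full precision alike. A condensed sketch of a single step built on that behavior (runs on CPU with `use_amp = False`; set it to True on a CUDA machine, as `--amp` does):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

use_amp = False  # stand-in for args.amp; requires CUDA when True
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=use_amp)

x, y = torch.randn(8, 4), torch.randn(8, 2)
with autocast(enabled=use_amp):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # plain backward when the scaler is disabled
scaler.step(optimizer)         # falls back to optimizer.step() when disabled
scaler.update()                # no-op when disabled
```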

@@ -430,9 +426,6 @@ def main_worker(gpu, args):
            writer.add_scalar("train_loss", train_loss, epoch)
            writer.add_scalar("train_acc", train_acc, epoch)

-        if args.distributed:
-            torch.distributed.barrier()
-
        b_new_best = False
        val_acc = 0
        if (epoch + 1) % args.val_every == 0:
@@ -494,21 +487,21 @@ def parse_args():

parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")

parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
parser.add_argument("--epochs", "--max_epochs", default=50, type=int, help="number of training epochs")
parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")

parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
parser.add_argument(
"--val_every",
parser.add_argument("--val_every",
"--val_interval",
default=1,
type=int,
help="run validation after this number of epochs, default 1 to run every epoch",
)
parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")

###for multigpu
#for multigpu
parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
@@ -519,7 +512,7 @@ def parse_args():

    parser.add_argument(
        "--quick", action="store_true", help="use a small subset of data for debugging"
-    ) # for debugging
+    )

    args = parser.parse_args()

@@ -530,7 +523,6 @@ def parse_args():

    return args

-
if __name__ == "__main__":

    args = parse_args()