
Commit a971dce

mil tutorial improvements (Project-MONAI#1069)
Fixes Project-MONAI/MONAI#5081.

- Fixes an issue where MetaTensor does not work with SyncBatchNorm; as a workaround, MetaTensor is cast back to torch.Tensor.
- Adds several README/doc improvements for the MIL tutorial.

Signed-off-by: myron <[email protected]>
1 parent 93bc555 commit a971dce
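
For illustration only (not part of the commit), a minimal sketch of the described workaround: MONAI's MetaTensor subclasses torch.Tensor, so `as_subclass(torch.Tensor)` reinterprets the same storage as a plain tensor, dropping the metadata wrapper that SyncBatchNorm cannot handle.

```python
# Minimal sketch (not from the commit) of the MetaTensor -> torch.Tensor cast.
# MetaTensor subclasses torch.Tensor, so as_subclass() returns a view of the
# same data as a plain Tensor, without the metadata that trips SyncBatchNorm.
import torch
from monai.data import MetaTensor

meta = MetaTensor(torch.rand(2, 3), meta={"note": "example metadata"})
plain = meta.as_subclass(torch.Tensor)
print(type(meta).__name__, "->", type(plain).__name__)  # MetaTensor -> Tensor
```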

File tree

6 files changed: +23 -29 lines changed


README.md (5 additions, 2 deletions)

```diff
@@ -166,10 +166,13 @@ Reference implementation used in MICCAI 2022 [ACR-NVIDIA-NCI Breast Density FL c
 
 **Digital Pathology**
 #### [Whole Slide Tumor Detection](./pathology/tumor_detection)
-The example show how to train and evaluate a tumor detection model (based on patch classification) on whole-slide histopathology images.
+The example shows how to train and evaluate a tumor detection model (based on patch classification) on whole-slide histopathology images.
 
 #### [Profiling Whole Slide Tumor Detection](./pathology/tumor_detection)
-The example show how to use MONAI NVTX transforms to tag and profile pre- and post-processing transforms in the digital pathology whole slide tumor detection pipeline.
+The example shows how to use MONAI NVTX transforms to tag and profile pre- and post-processing transforms in the digital pathology whole slide tumor detection pipeline.
+
+#### [Multiple Instance Learning WSI classification](./pathology/multiple_instance_learning)
+An example of Multiple Instance Learning (MIL) classification from Whole Slide Images (WSI) of prostate histopathology.
 
 #### [NuClick:Interactive Annotation for Pathology](./pathology/nuclick)
 The notebook demonstrates examples of training and inference pipelines with interactive annotation for pathology, NuClick is used for delineating nuclei, cells and a squiggle for outlining glands.
```

pathology/multiple_instance_learning/README.md (1 addition, 2 deletions)

````diff
@@ -27,8 +27,7 @@ The script is tested with:
 Please install the required dependencies
 
 ```bash
-pip install tifffile
-pip install imagecodecs
+pip install cucim gdown
 ```
 
 For more information please check out [the installation guide](https://docs.monai.io/en/latest/installation.html).
````

pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py (17 additions, 25 deletions)
```diff
@@ -48,21 +48,18 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
 
     for idx, batch_data in enumerate(loader):
 
-        data, target = batch_data["image"].cuda(args.rank), batch_data["label"].cuda(args.rank)
+        data = batch_data["image"].as_subclass(torch.Tensor).cuda(args.rank)
+        target = batch_data["label"].as_subclass(torch.Tensor).cuda(args.rank)
 
         optimizer.zero_grad(set_to_none=True)
 
         with autocast(enabled=args.amp):
             logits = model(data)
             loss = criterion(logits, target)
 
-        if args.amp:
-            scaler.scale(loss).backward()
-            scaler.step(optimizer)
-            scaler.update()
-        else:
-            loss.backward()
-            optimizer.step()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
 
         acc = (logits.sigmoid().sum(1).detach().round() == target.sum(1).round()).float().mean()
 
```
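
The unconditional scale/step/update sequence above works because the scaler is now constructed as `GradScaler(enabled=args.amp)` (see the hunk further down): when disabled, every scaler call degrades to a pass-through. A small standalone sketch, not from the commit:

```python
# Sketch (not from the commit): GradScaler(enabled=False) makes scale(), step()
# and update() transparent no-ops, so one code path serves AMP and non-AMP runs.
import torch
from torch.cuda.amp import GradScaler

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=False)  # as if args.amp were False

loss = model(torch.rand(2, 4)).sum()
scaler.scale(loss).backward()  # returns loss unmodified when disabled
scaler.step(opt)               # plain opt.step() when disabled
scaler.update()                # no-op when disabled
```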

```diff
@@ -108,7 +105,8 @@ def val_epoch(model, loader, epoch, args, max_tiles=None):
 
     for idx, batch_data in enumerate(loader):
 
-        data, target = batch_data["image"].cuda(args.rank), batch_data["label"].cuda(args.rank)
+        data = batch_data["image"].as_subclass(torch.Tensor).cuda(args.rank)
+        target = batch_data["label"].as_subclass(torch.Tensor).cuda(args.rank)
 
         with autocast(enabled=args.amp):
 
```

```diff
@@ -122,7 +120,7 @@ def val_epoch(model, loader, epoch, args, max_tiles=None):
                 logits2 = []
 
                 for i in range(int(np.ceil(data.shape[1] / float(max_tiles)))):
-                    data_slice = data[:, i * max_tiles : (i + 1) * max_tiles]
+                    data_slice = data[:, i * max_tiles: (i + 1) * max_tiles]
                     logits_slice = model(data_slice, no_head=True)
                     logits.append(logits_slice)
 
```
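
For context, the loop in this hunk processes the tile dimension in chunks of at most `max_tiles`, bounding peak memory on slides with many tiles. A standalone sketch of the pattern, with hypothetical shapes:

```python
# Standalone sketch of the chunking pattern above; shapes are hypothetical.
import numpy as np
import torch

data = torch.rand(1, 10, 8)  # (batch, num_tiles, features)
max_tiles = 4
for i in range(int(np.ceil(data.shape[1] / float(max_tiles)))):
    data_slice = data[:, i * max_tiles : (i + 1) * max_tiles]
    print(data_slice.shape)  # last chunk may hold fewer than max_tiles tiles
```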

```diff
@@ -329,7 +327,7 @@ def main_worker(gpu, args):
         shuffle=(train_sampler is None),
         num_workers=args.workers,
         pin_memory=True,
-        multiprocessing_context="spawn",
+        multiprocessing_context="spawn" if args.workers > 0 else None,
         sampler=train_sampler,
         collate_fn=list_data_collate,
     )
@@ -339,7 +337,7 @@ def main_worker(gpu, args):
         shuffle=False,
         num_workers=args.workers,
         pin_memory=True,
-        multiprocessing_context="spawn",
+        multiprocessing_context="spawn" if args.workers > 0 else None,
         sampler=val_sampler,
         collate_fn=list_data_collate,
     )
```
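
The guard matters because `torch.utils.data.DataLoader` raises a ValueError when a `multiprocessing_context` is supplied together with `num_workers=0` (single-process loading). A minimal sketch, not from the commit:

```python
# Minimal sketch (not from the commit): DataLoader rejects a
# multiprocessing_context when num_workers == 0, hence the conditional.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
workers = 0  # e.g. args.workers
loader = DataLoader(
    dataset,
    num_workers=workers,
    multiprocessing_context="spawn" if workers > 0 else None,
)
print(len(list(loader)))  # iterates fine in single-process mode
```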
```diff
@@ -399,13 +397,11 @@ def main_worker(gpu, args):
     else:
         writer = None
 
-    ###RUN TRAINING
+    #RUN TRAINING
     n_epochs = args.epochs
     val_acc_max = 0.0
 
-    scaler = None
-    if args.amp:  # new native amp
-        scaler = GradScaler()
+    scaler = GradScaler(enabled=args.amp)
 
     for epoch in range(start_epoch, n_epochs):
 
```

```diff
@@ -430,9 +426,6 @@ def main_worker(gpu, args):
             writer.add_scalar("train_loss", train_loss, epoch)
             writer.add_scalar("train_acc", train_acc, epoch)
 
-        if args.distributed:
-            torch.distributed.barrier()
-
         b_new_best = False
         val_acc = 0
         if (epoch + 1) % args.val_every == 0:
```
```diff
@@ -494,21 +487,21 @@ def parse_args():
 
     parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs")
 
-    parser.add_argument("--epochs", default=50, type=int, help="number of training epochs")
+    parser.add_argument("--epochs", "--max_epochs", default=50, type=int, help="number of training epochs")
     parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu")
     parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
 
     parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
     parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
-    parser.add_argument(
-        "--val_every",
+    parser.add_argument("--val_every",
+        "--val_interval",
         default=1,
         type=int,
         help="run validation after this number of epochs, default 1 to run every epoch",
     )
     parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
 
-    ###for multigpu
+    #for multigpu
     parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended")
     parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
     parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
```
```diff
@@ -519,7 +512,7 @@ def parse_args():
 
     parser.add_argument(
         "--quick", action="store_true", help="use a small subset of data for debugging"
-    )  # for debugging
+    )
 
     args = parser.parse_args()
 
@@ -530,7 +523,6 @@ def parse_args():
 
     return args
 
-
 if __name__ == "__main__":
 
     args = parse_args()
```
