Skip to content

HoVerNet training pipeline #999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 76 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
04e79ed
Merge remote-tracking branch 'origin/main' into main
KumoLiu Jul 21, 2022
4c2b7b3
Merge remote-tracking branch 'origin/main' into main
KumoLiu Jul 22, 2022
078a770
Merge remote-tracking branch 'origin/main' into main
KumoLiu Jul 22, 2022
38de0d2
Merge remote-tracking branch 'origin/main' into main
KumoLiu Jul 25, 2022
4726117
Merge remote-tracking branch 'origin/main' into main
KumoLiu Jul 26, 2022
79dd41f
Merge remote-tracking branch 'origin/main' into main
KumoLiu Aug 6, 2022
5ac6c2d
Merge remote-tracking branch 'origin/main' into main
KumoLiu Aug 18, 2022
9138bcd
Merge remote-tracking branch 'origin/main' into main
KumoLiu Aug 22, 2022
fd6a6b1
easy-integrate-bundle-v1
KumoLiu Aug 22, 2022
58ce3ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2022
7b69b3a
add multigpu implementation
KumoLiu Aug 23, 2022
567da73
Merge remote-tracking branch 'yliu/easy-integrate-bundle' into main
KumoLiu Aug 23, 2022
a152924
Merge remote-tracking branch 'yliu/easy-integrate-bundle' into easy-i…
KumoLiu Aug 23, 2022
e4fd544
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2022
e40cf40
Merge remote-tracking branch 'yliu/easy-integrate-bundle' into main
KumoLiu Aug 23, 2022
a4f29d1
Merge branch 'main' of https://github.com/Project-MONAI/tutorials int…
KumoLiu Sep 9, 2022
4494091
Merge branch 'main' of https://github.com/Project-MONAI/tutorials int…
KumoLiu Sep 9, 2022
1c88106
Merge branch 'main' of https://github.com/Project-MONAI/tutorials int…
KumoLiu Sep 9, 2022
82590de
Merge branch 'main' of https://github.com/Project-MONAI/tutorials int…
KumoLiu Sep 21, 2022
396edfd
Merge branch 'main' of https://github.com/Project-MONAI/tutorials int…
KumoLiu Sep 21, 2022
38091e8
Merge remote-tracking branch 'origin/main' into hovernet-train
KumoLiu Oct 17, 2022
6449a97
first commit
KumoLiu Oct 17, 2022
96aac94
draft toch training pipeline
KumoLiu Oct 19, 2022
635bcdf
Merge remote-tracking branch 'origin/main' into hovernet-train
KumoLiu Oct 19, 2022
29e3655
minor fix
KumoLiu Oct 19, 2022
23fbdd5
minor fix
KumoLiu Oct 19, 2022
1e83974
draft ignite pipeline
KumoLiu Oct 19, 2022
000f496
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2022
46bb81b
minor fix
KumoLiu Oct 20, 2022
028f301
update based on comments
KumoLiu Oct 20, 2022
dcaa591
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2022
5dd1f8d
enable multigpu
KumoLiu Oct 20, 2022
8adff4f
minor fix
KumoLiu Oct 21, 2022
9f0e4e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2022
41af7e3
add torch version
KumoLiu Oct 24, 2022
85f971f
use original repo split
KumoLiu Oct 26, 2022
a107e8a
pipeline for lizard
KumoLiu Oct 28, 2022
ef48608
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2022
d1295b2
Merge branch 'main' into hovernet-train
bhashemian Nov 3, 2022
9cc944d
update consep pipeline
KumoLiu Nov 8, 2022
8c73563
Merge branch 'hovernet-train' of https://github.com/KumoLiu/tutorials…
KumoLiu Nov 8, 2022
deb2e78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 8, 2022
46efba3
add Infer (torch version)
KumoLiu Nov 8, 2022
175deaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 8, 2022
74c27fc
Merge branch 'main' into hovernet-train
bhashemian Nov 10, 2022
c0e0e69
update ignite version
KumoLiu Nov 18, 2022
147bd4f
Merge branch 'hovernet-train' of https://github.com/KumoLiu/tutorials…
KumoLiu Nov 18, 2022
81892f1
Merge remote-tracking branch 'origin/main' into hovernet-train
KumoLiu Nov 18, 2022
1e81576
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
204c763
minor fix
KumoLiu Nov 18, 2022
6675212
Merge remote-tracking branch 'yliu/hovernet-train' into hovernet-train
KumoLiu Nov 18, 2022
4173cc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
8b84c7a
minor fix
KumoLiu Nov 18, 2022
cb20f2f
add ignite version evaluation
KumoLiu Nov 21, 2022
c5acaef
rm torch version train and infer
KumoLiu Nov 21, 2022
82a3578
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 21, 2022
747ef46
Merge branch 'main' into hovernet-train
bhashemian Nov 21, 2022
282e3db
add mode args
KumoLiu Nov 22, 2022
d2a4466
add README
KumoLiu Nov 22, 2022
0a52995
Update the terms of use
KumoLiu Nov 22, 2022
3f73587
Merge branch 'hovernet-train' of https://github.com/KumoLiu/tutorials…
KumoLiu Nov 22, 2022
44b7bc1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2022
cc1922c
address comments
KumoLiu Nov 23, 2022
30cf617
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2022
0b8336b
update metrics
KumoLiu Nov 23, 2022
8b829fd
Merge branch 'main' into hovernet-train
bhashemian Nov 23, 2022
44bbf1d
add prepare patches
KumoLiu Nov 24, 2022
7fb79f8
Merge branch 'hovernet-train' of https://github.com/KumoLiu/tutorials…
KumoLiu Nov 24, 2022
1a9f7f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2022
74a6acf
minor fix
KumoLiu Nov 24, 2022
75475dd
minor fix
KumoLiu Nov 24, 2022
a880c56
replace cv2 with PIL
KumoLiu Nov 28, 2022
e99e99e
minor fix
KumoLiu Nov 29, 2022
1b9f64d
Merge branch 'main' of https://github.com/Project-MONAI/tutorials int…
KumoLiu Nov 29, 2022
f340451
Merge branch 'main' into hovernet-train
bhashemian Dec 5, 2022
7be403e
Merge branch 'main' into hovernet-train
bhashemian Dec 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions pathology/hovernet/README.MD
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# HoVerNet Examples

This folder contains ignite version examples to run train and validate a HoVerNet model.
It also has torch version notebooks to run training and evaluation.
<p align="center">
<img src="https://ars.els-cdn.com/content/image/1-s2.0-S1361841519301045-fx1_lrg.jpg" alt="hovernet scheme")
</p>
implementation based on:

Simon Graham et al., HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images.' Medical Image Analysis, (2019). https://arxiv.org/abs/1812.06499

### 1. Data

CoNSeP datasets which are used in the examples can be downloaded from https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet/.
- First download CoNSeP dataset to `data_root`.
- Run prepare_patches.py to prepare patches from images.

### 2. Questions and bugs

- For questions relating to the use of MONAI, please us our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).


### 3. List of notebooks and examples
#### [Prepare Your Data](./prepare_patches.py)
This example is used to prepare patches from tiles referring to the implementation from https://github.com/vqdang/hover_net/blob/master/extract_patches.py. Prepared patches will be saved in `data_root`/Prepared.

```bash
# Run to know all possible options
python ./prepare_patches.py -h

# Prepare patches from images
python ./prepare_patches.py \
--root `data_root`
```

#### [HoVerNet Training](./training.py)
This example uses MONAI workflow to train a HoVerNet model on prepared CoNSeP dataset.
Since HoVerNet is training via a two-stage approach. First initialised the model with pre-trained weights on the [ImageNet dataset](https://ieeexplore.ieee.org/document/5206848), trained only the decoders for the first 50 epochs, and then fine-tuned all layers for another 50 epochs. We need to specify `--stage` during training.

Each user is responsible for checking the content of models/datasets and the applicable licenses and determining if suitable for the intended use.
The license for the pre-trained model used in examples is different than MONAI license. Please check the source where these weights are obtained from:
https://github.com/vqdang/hover_net#data-format


```bash
# Run to know all possible options
python ./training.py -h

# Train a hovernet model on single-gpu(replace with your own ckpt path)
export CUDA_VISIBLE_DEVICES=0; python training.py \
--ep 50 \
--stage 0 \
--bs 16 \
--root `save_root`
export CUDA_VISIBLE_DEVICES=0; python training.py \
--ep 50 \
--stage 1 \
--bs 4 \
--root `save_root` \
--ckpt logs/stage0/checkpoint_epoch=50.pt

# Train a hovernet model on multi-gpu (NVIDIA)(replace with your own ckpt path)
torchrun --nnodes=1 --nproc_per_node=2 training.py \
--ep 50 \
--bs 8 \
--root `save_root` \
--stage 0
torchrun --nnodes=1 --nproc_per_node=2 training.py \
--ep 50 \
--bs 2 \
--root `save_root` \
--stage 1 \
--ckpt logs/stage0/checkpoint_epoch=50.pt
```

#### [HoVerNet Validation](./evaluation.py)
This example uses MONAI workflow to evaluate the trained HoVerNet model on prepared test data from CoNSeP dataset.
With their metrics on original mode. We reproduce the results with Dice: 0.82762; PQ: 0.48976; F1d: 0.73592.
```bash
# Run to know all possible options
python ./evaluation.py -h

# Evaluate a HoVerNet model
python ./evaluation.py
--root `save_root` \
--ckpt logs/stage0/checkpoint_epoch=50.pt
```

## Disclaimer

This is an example, not to be used for diagnostic purposes.
150 changes: 150 additions & 0 deletions pathology/hovernet/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import os
import glob
import logging
import torch
from argparse import ArgumentParser
from monai.data import DataLoader, CacheDataset
from monai.networks.nets import HoVerNet
from monai.engines import SupervisedEvaluator
from monai.transforms import (
LoadImaged,
Lambdad,
Activationsd,
Compose,
CastToTyped,
ComputeHoVerMapsd,
ScaleIntensityRanged,
CenterSpatialCropd,
)
from monai.handlers import (
MeanDice,
StatsHandler,
CheckpointLoader,
)
from monai.utils.enums import HoVerNetBranch
from monai.apps.pathology.handlers.utils import from_engine_hovernet
from monai.apps.pathology.engines.utils import PrepareBatchHoVerNet
from skimage import measure


def prepare_data(data_dir, phase):
data_dir = os.path.join(data_dir, phase)

images = list(sorted(
glob.glob(os.path.join(data_dir, "*/*image.npy"))))
inst_maps = list(sorted(
glob.glob(os.path.join(data_dir, "*/*inst_map.npy"))))
type_maps = list(sorted(
glob.glob(os.path.join(data_dir, "*/*type_map.npy"))))

data_dicts = [
{"image": _image, "label_inst": _inst_map, "label_type": _type_map}
for _image, _inst_map, _type_map in zip(images, inst_maps, type_maps)
]

return data_dicts


def run(cfg):
if cfg["mode"].lower() == "original":
cfg["patch_size"] = [270, 270]
cfg["out_size"] = [80, 80]
elif cfg["mode"].lower() == "fast":
cfg["patch_size"] = [256, 256]
cfg["out_size"] = [164, 164]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
val_transforms = Compose(
[
LoadImaged(keys=["image", "label_inst", "label_type"], image_only=True),
Lambdad(keys="label_inst", func=lambda x: measure.label(x)),
CastToTyped(keys=["image", "label_inst"], dtype=torch.int),
CenterSpatialCropd(
keys="image",
roi_size=cfg["patch_size"],
),
ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
ComputeHoVerMapsd(keys="label_inst"),
Lambdad(keys="label_inst", func=lambda x: x > 0, overwrite="label"),
CenterSpatialCropd(
keys=["label", "hover_label_inst", "label_inst", "label_type"],
roi_size=cfg["out_size"],
),
CastToTyped(keys=["image", "label_inst", "label_type"], dtype=torch.float32),
]
)

# Create MONAI DataLoaders
valid_data = prepare_data(cfg["root"], "valid")
valid_ds = CacheDataset(data=valid_data, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(
valid_ds,
batch_size=cfg["batch_size"],
num_workers=cfg["num_workers"],
pin_memory=torch.cuda.is_available()
)

# initialize model
model = HoVerNet(
mode=cfg["mode"],
in_channels=3,
out_classes=cfg["out_classes"],
act=("relu", {"inplace": True}),
norm="batch",
pretrained_url=None,
freeze_encoder=False,
).to(device)

post_process_np = Compose([
Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
Lambdad(keys=HoVerNetBranch.NP.value, func=lambda x: x[1: 2, ...] > 0.5)])
post_process = Lambdad(keys="pred", func=post_process_np)

# Evaluator
val_handlers = [
CheckpointLoader(load_path=cfg["ckpt_path"], load_dict={"net": model}),
StatsHandler(output_transform=lambda x: None),
]
evaluator = SupervisedEvaluator(
device=device,
val_data_loader=val_loader,
prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
network=model,
postprocessing=post_process,
key_val_metric={"val_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
val_handlers=val_handlers,
amp=cfg["amp"],
)

state = evaluator.run()
print(state)


def main():
parser = ArgumentParser(description="Tumor detection on whole slide pathology images.")
parser.add_argument(
"--root",
type=str,
default="/workspace/Data/CoNSeP/Prepared/consep",
help="root data dir",
)

parser.add_argument("--bs", type=int, default=16, dest="batch_size", help="batch size")
parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp")
parser.add_argument("--classes", type=int, default=5, dest="out_classes", help="output classes")
parser.add_argument("--mode", type=str, default="original", help="choose either `original` or `fast`")

parser.add_argument("--cpu", type=int, default=8, dest="num_workers", help="number of workers")
parser.add_argument("--use_gpu", type=bool, default=True, dest="use_gpu", help="whether to use gpu")
parser.add_argument("--ckpt", type=str, dest="ckpt_path", help="checkpoint path")

args = parser.parse_args()
cfg = vars(args)
print(cfg)

logging.basicConfig(level=logging.INFO)
run(cfg)


if __name__ == "__main__":
main()
Loading