ROI Inference pipeline for HoVerNet #1055

Merged · 33 commits · Dec 8, 2022
Changes from 10 commits

Commits
979e05a
Add a draft of inference
bhashemian Nov 10, 2022
d3b77de
Uncomment load weights
bhashemian Nov 14, 2022
fc271ef
Add infer_roi
bhashemian Nov 15, 2022
6582442
Major updates
bhashemian Nov 17, 2022
d9d19b5
Update the pipeline
bhashemian Nov 21, 2022
4d6bd79
Merge branch 'main' into wsi-inference-hovernet
bhashemian Nov 21, 2022
0cdac25
keep rio inference only
bhashemian Nov 21, 2022
e58050c
Remove test-oly lines
bhashemian Nov 21, 2022
123b721
Add sw batch size
bhashemian Nov 21, 2022
04c1f45
Change settings:
bhashemian Nov 21, 2022
6b61c77
Address comments
bhashemian Nov 22, 2022
a534937
Merge branch 'main' into wsi-inference-hovernet
bhashemian Nov 22, 2022
ea99823
clean up
bhashemian Nov 22, 2022
e79f5f8
fix a typo
bhashemian Nov 22, 2022
0f0884f
change logic of few args
bhashemian Nov 22, 2022
b1dac05
Add multi-gpu
bhashemian Nov 22, 2022
ccc62ed
Add/remove prints
bhashemian Nov 22, 2022
a10f7a8
Rename to inference
bhashemian Nov 22, 2022
ec9d8f1
Add output class
bhashemian Nov 22, 2022
73400f2
Update to FalttenSubKeysd
bhashemian Nov 23, 2022
eb23e7d
Merge branch 'main' of https://github.com/Project-MONAI/tutorials int…
bhashemian Nov 30, 2022
aa4a770
Update with the new hovernet postprocessing
bhashemian Dec 1, 2022
5eb776f
Merge branch 'main' into wsi-inference-hovernet
bhashemian Dec 1, 2022
a4aa476
change to nuclear type
bhashemian Dec 2, 2022
4deb7c5
to png
bhashemian Dec 5, 2022
16ad09e
remove test transform
bhashemian Dec 5, 2022
f24b2c2
Merge branch 'main' into wsi-inference-hovernet
bhashemian Dec 8, 2022
7b06bb5
Merge branch 'main' into wsi-inference-hovernet
bhashemian Dec 8, 2022
9c4a2bf
Some updates
bhashemian Dec 8, 2022
d03067a
Remove device
bhashemian Dec 8, 2022
0450963
few improvements
bhashemian Dec 8, 2022
56d2e98
improvments and bug fix
bhashemian Dec 8, 2022
8df9008
Update default run
bhashemian Dec 8, 2022
217 changes: 217 additions & 0 deletions pathology/hovernet/infer_roi.py
@@ -0,0 +1,217 @@
import logging
import os
import time
from argparse import ArgumentParser
from glob import glob

import torch
import torch.distributed as dist

from monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer
from monai.apps.pathology.transforms import (
    GenerateDistanceMapd,
    GenerateInstanceBorderd,
    GenerateWatershedMarkersd,
    GenerateWatershedMaskd,
    HoVerNetNuclearTypePostProcessingd,
    Watershedd,
)
from monai.data import DataLoader, Dataset, PILReader
from monai.engines import SupervisedEvaluator
from monai.networks.nets import HoVerNet
from monai.transforms import (
    Activationsd,
    AsDiscreted,
    CastToTyped,
    Compose,
    EnsureChannelFirstd,
    FillHoles,
    FromMetaTensord,
    GaussianSmooth,
    LoadImaged,
    PromoteChildItemsd,
    SaveImaged,
    ScaleIntensityRanged,
)
from monai.utils import HoVerNetBranch, first


def create_output_dir(cfg):
    if cfg["mode"].lower() == "original":
        cfg["patch_size"] = 270
        cfg["out_size"] = 80
    elif cfg["mode"].lower() == "fast":
        cfg["patch_size"] = 256
        cfg["out_size"] = 164

    timestamp = time.strftime("%y%m%d-%H%M%S")
    run_folder_name = f"{timestamp}_inference_hovernet_ps{cfg['patch_size']}"
    log_dir = os.path.join(cfg["output"], run_folder_name)
    print(f"Logs and outputs are saved at '{log_dir}'.")
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir


def run(cfg):
    # --------------------------------------------------------------------------
    # Set Directory and Device
    # --------------------------------------------------------------------------
    output_dir = create_output_dir(cfg)
    multi_gpu = True if cfg["use_gpu"] and torch.cuda.device_count() > 1 else False
    if multi_gpu:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device("cuda:{}".format(dist.get_rank()))
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if cfg["use_gpu"] else "cpu")

    # --------------------------------------------------------------------------
    # Transforms
    # --------------------------------------------------------------------------
    # Preprocessing transforms
    pre_transforms = Compose(
        [
            LoadImaged(keys=["image"], reader=PILReader, converter=lambda x: x.convert("RGB")),
            EnsureChannelFirstd(keys=["image"]),
            CastToTyped(keys=["image"], dtype=torch.float32),
            ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
        ]
    )
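    # HoVerNet returns a dictionary with three branches: nucleus prediction (NP),
    # horizontal-vertical distance maps (HV), and nuclear type classification (NC).
    # The chain below promotes these branches to top-level keys, runs watershed-based
    # instance segmentation from the NP/HV outputs, assigns a nuclear type to each
    # instance from the NC output, and saves the binary prediction mask as a PNG.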
    # Postprocessing transforms
    post_transforms = Compose(
        [
            PromoteChildItemsd(
                keys="pred",
                child_keys=[HoVerNetBranch.NC.value, HoVerNetBranch.NP.value, HoVerNetBranch.HV.value],
                delete_keys=True,
            ),
            Activationsd(keys=HoVerNetBranch.NC.value, softmax=True),
            AsDiscreted(keys=HoVerNetBranch.NC.value, argmax=True),
            GenerateWatershedMaskd(keys=HoVerNetBranch.NP.value, softmax=True),
            GenerateInstanceBorderd(keys="mask", hover_map_key=HoVerNetBranch.HV.value, kernel_size=3),
            GenerateDistanceMapd(keys="mask", border_key="border", smooth_fn=GaussianSmooth()),
            GenerateWatershedMarkersd(
                keys="mask",
                border_key="border",
                threshold=0.7,
                radius=2,
                postprocess_fn=FillHoles(),
            ),
            Watershedd(keys="dist", mask_key="mask", markers_key="markers"),
            HoVerNetNuclearTypePostProcessingd(type_pred_key=HoVerNetBranch.NC.value, instance_pred_key="dist"),
            FromMetaTensord(keys=["image", "pred_binary"]),
            SaveImaged(
                keys="pred_binary",
                meta_keys="image_meta_dict",
                output_ext="png",
                output_dir=output_dir,
                output_postfix="pred",
                output_dtype="uint8",
                separate_folder=False,
                scale=255,
            ),
        ]
    )
    # --------------------------------------------------------------------------
    # Data and Data Loading
    # --------------------------------------------------------------------------
    # List of ROI images
    data_list = [{"image": image} for image in glob(os.path.join(cfg["root"], "*.png"))]

    # Dataset
    dataset = Dataset(data_list, transform=pre_transforms)

    # Dataloader
    data_loader = DataLoader(
        dataset,
        num_workers=cfg["num_workers"],
        batch_size=cfg["batch_size"],
        pin_memory=True,
    )

    # --------------------------------------------------------------------------
    # Run some sanity checks
    # --------------------------------------------------------------------------
    # Check first sample
    first_sample = first(data_loader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    print("image: ")
    print("  shape", first_sample["image"].shape)
    print("  type: ", type(first_sample["image"]))
    print("  dtype: ", first_sample["image"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"number of batches: {len(data_loader)}")

    # --------------------------------------------------------------------------
    # Model
    # --------------------------------------------------------------------------
    # Create model and load weights
    model = HoVerNet(
        mode=cfg["mode"],
        in_channels=3,
        out_classes=5,
        act=("relu", {"inplace": True}),
        norm="batch",
    ).to(device)
    model.load_state_dict(torch.load(cfg["ckpt"], map_location=device))
    model.eval()

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[dist.get_rank()], output_device=dist.get_rank()
        )

    # --------------------------------------------
    # Inference
    # --------------------------------------------
    # Inference engine
    sliding_inferer = SlidingWindowHoVerNetInferer(
        roi_size=cfg["patch_size"],
        sw_batch_size=cfg["sw_batch_size"],
        overlap=1.0 - float(cfg["out_size"]) / float(cfg["patch_size"]),
        padding_mode="constant",
        cval=0,
        sw_device=device,
        device=device,
        progress=True,
        extra_input_padding=((cfg["patch_size"] - cfg["out_size"]) // 2,) * 4,
    )
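    # Note on the window settings: HoVerNet produces an output tile smaller than its
    # input patch (80 vs. 270 pixels in "original" mode, 164 vs. 256 in "fast" mode),
    # so overlap is set to 1 - out_size / patch_size to make the valid output regions
    # of neighboring windows tile the image without gaps, and extra_input_padding adds
    # (patch_size - out_size) // 2 pixels on each side so border pixels are covered too.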

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=data_loader,
        network=model,
        postprocessing=post_transforms,
        inferer=sliding_inferer,
        amp=cfg["amp"],
    )
    evaluator.run()

    if multi_gpu:
        dist.destroy_process_group()


def main():
    logging.basicConfig(level=logging.INFO)

    parser = ArgumentParser(description="HoVerNet nuclei segmentation inference on ROI pathology images.")
    parser.add_argument("--root", type=str, default="./CoNSeP/Test/Images", help="image root dir")
    parser.add_argument("--output", type=str, default="./logs/", dest="output", help="log directory")
    parser.add_argument("--ckpt", type=str, default="./model_CoNSeP_new.pth", help="path to the PyTorch checkpoint")
    parser.add_argument("--mode", type=str, default="original", help="HoVerNet mode (original/fast)")
    parser.add_argument("--bs", type=int, default=1, dest="batch_size", help="batch size")
    parser.add_argument("--swbs", type=int, default=8, dest="sw_batch_size", help="sliding window batch size")
    parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp")
    parser.add_argument("--cpu", type=int, default=0, dest="num_workers", help="number of workers")
    parser.add_argument("--use-gpu", action="store_true", help="whether to use gpu")
    args = parser.parse_args()

    config_dict = vars(args)
    print(config_dict)
    run(config_dict)


if __name__ == "__main__":
    main()
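
A minimal way to exercise the script, assuming the CoNSeP test images and checkpoint sit at the argparse defaults above (paths are illustrative):

python infer_roi.py --root ./CoNSeP/Test/Images --ckpt ./model_CoNSeP_new.pth --use-gpu

Because the multi-GPU path initializes torch.distributed with init_method="env://", multi-GPU inference would be launched with one process per GPU, e.g. torchrun --nproc_per_node=2 infer_roi.py --use-gpu.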