-
Notifications
You must be signed in to change notification settings - Fork 740
ROI Inference pipeline for HoVerNet #1055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
bhashemian
merged 33 commits into
Project-MONAI:main
from
bhashemian:wsi-inference-hovernet
Dec 8, 2022
Merged
Changes from 10 commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
979e05a
Add a draft of inference
bhashemian d3b77de
Uncomment load weights
bhashemian fc271ef
Add infer_roi
bhashemian 6582442
Major updates
bhashemian d9d19b5
Update the pipeline
bhashemian 4d6bd79
Merge branch 'main' into wsi-inference-hovernet
bhashemian 0cdac25
keep roi inference only
bhashemian e58050c
Remove test-only lines
bhashemian 123b721
Add sw batch size
bhashemian 04c1f45
Change settings:
bhashemian 6b61c77
Address comments
bhashemian a534937
Merge branch 'main' into wsi-inference-hovernet
bhashemian ea99823
clean up
bhashemian e79f5f8
fix a typo
bhashemian 0f0884f
change logic of few args
bhashemian b1dac05
Add multi-gpu
bhashemian ccc62ed
Add/remove prints
bhashemian a10f7a8
Rename to inference
bhashemian ec9d8f1
Add output class
bhashemian 73400f2
Update to FlattenSubKeysd
bhashemian eb23e7d
Merge branch 'main' of https://github.com/Project-MONAI/tutorials int…
bhashemian aa4a770
Update with the new hovernet postprocessing
bhashemian 5eb776f
Merge branch 'main' into wsi-inference-hovernet
bhashemian a4aa476
change to nuclear type
bhashemian 4deb7c5
to png
bhashemian 16ad09e
remove test transform
bhashemian f24b2c2
Merge branch 'main' into wsi-inference-hovernet
bhashemian 7b06bb5
Merge branch 'main' into wsi-inference-hovernet
bhashemian 9c4a2bf
Some updates
bhashemian d03067a
Remove device
bhashemian 0450963
few improvements
bhashemian 56d2e98
improvements and bug fix
bhashemian 8df9008
Update default run
bhashemian File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
import logging | ||
import os | ||
import time | ||
from argparse import ArgumentParser | ||
from glob import glob | ||
|
||
import torch | ||
import torch.distributed as dist | ||
|
||
from monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer | ||
from monai.apps.pathology.transforms import ( | ||
GenerateDistanceMapd, | ||
GenerateInstanceBorderd, | ||
GenerateWatershedMarkersd, | ||
GenerateWatershedMaskd, | ||
HoVerNetNuclearTypePostProcessingd, | ||
Watershedd, | ||
) | ||
from monai.data import DataLoader, Dataset, PILReader | ||
from monai.engines import SupervisedEvaluator | ||
from monai.networks.nets import HoVerNet | ||
from monai.transforms import ( | ||
Activationsd, | ||
AsDiscreted, | ||
CastToTyped, | ||
Compose, | ||
EnsureChannelFirstd, | ||
FillHoles, | ||
FromMetaTensord, | ||
GaussianSmooth, | ||
LoadImaged, | ||
PromoteChildItemsd, | ||
SaveImaged, | ||
ScaleIntensityRanged, | ||
) | ||
from monai.utils import HoVerNetBranch, first | ||
|
||
|
||
def create_output_dir(cfg):
    """Resolve model patch/output sizes from the HoVerNet mode and create a
    timestamped output directory.

    Note: this mutates ``cfg`` in place, setting ``cfg["patch_size"]`` and
    ``cfg["out_size"]`` — the inference pipeline relies on these keys.

    Args:
        cfg: configuration dict with at least the keys "mode" and "output".

    Returns:
        Path of the created log/output directory.

    Raises:
        ValueError: if ``cfg["mode"]`` is neither "original" nor "fast".
    """
    mode = cfg["mode"].lower()
    if mode == "original":
        cfg["patch_size"] = 270
        cfg["out_size"] = 80
    elif mode == "fast":
        cfg["patch_size"] = 256
        cfg["out_size"] = 164
    else:
        # Fail early with a clear message instead of a confusing KeyError on
        # cfg["patch_size"] further down the pipeline.
        raise ValueError(f"Unsupported HoVerNet mode: {cfg['mode']!r}. Expected 'original' or 'fast'.")

    timestamp = time.strftime("%y%m%d-%H%M%S")
    run_folder_name = f"{timestamp}_inference_hovernet_ps{cfg['patch_size']}"
    log_dir = os.path.join(cfg["output"], run_folder_name)
    print(f"Logs and outputs are saved at '{log_dir}'.")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    return log_dir
|
||
|
||
def run(cfg):
    """Run HoVerNet ROI (patch) inference over a folder of PNG images.

    Builds the pre/post-processing pipelines, loads the checkpoint from
    ``cfg["ckpt"]``, and runs sliding-window inference through a MONAI
    ``SupervisedEvaluator``, saving binary nuclei predictions as PNG files
    in a timestamped output directory.

    Args:
        cfg: configuration dict built by ``main``; ``create_output_dir`` also
            adds "patch_size" and "out_size" to it based on ``cfg["mode"]``.
    """
    # --------------------------------------------------------------------------
    # Set Directory and Device
    # --------------------------------------------------------------------------
    # NOTE: create_output_dir also sets cfg["patch_size"] and cfg["out_size"].
    output_dir = create_output_dir(cfg)
    # Use DistributedDataParallel only when a GPU run is requested and more
    # than one CUDA device is visible.
    multi_gpu = True if cfg["use_gpu"] and torch.cuda.device_count() > 1 else False
    if multi_gpu:
        # assumes the script was launched via torchrun/torch.distributed so
        # the env:// rendezvous variables are set — TODO confirm
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device("cuda:{}".format(dist.get_rank()))
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if cfg["use_gpu"] else "cpu")

    # --------------------------------------------------------------------------
    # Transforms
    # --------------------------------------------------------------------------
    # Preprocessing transforms: load PNGs as RGB, move channels first, and
    # rescale intensities from [0, 255] to [0, 1].
    pre_transforms = Compose(
        [
            LoadImaged(keys=["image"], reader=PILReader, converter=lambda x: x.convert("RGB")),
            EnsureChannelFirstd(keys=["image"]),
            CastToTyped(keys=["image"], dtype=torch.float32),
            ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
        ]
    )
    # Postprocessing transforms: split the model's dict output into its three
    # branches, then run watershed-based instance post-processing and save.
    post_transforms = Compose(
        [
            # Lift the NC/NP/HV branch outputs out of the nested "pred" dict
            # so downstream transforms can address them by key.
            PromoteChildItemsd(
                keys="pred",
                child_keys=[HoVerNetBranch.NC.value, HoVerNetBranch.NP.value, HoVerNetBranch.HV.value],
                delete_keys=True,
            ),
            # Nuclear-type branch: softmax probabilities, then argmax class map.
            Activationsd(keys=HoVerNetBranch.NC.value, softmax=True),
            AsDiscreted(keys=HoVerNetBranch.NC.value, argmax=True),
            # Watershed pipeline over the NP (nuclei presence) and HV (hover)
            # branches, used to split touching nuclei into instances.
            GenerateWatershedMaskd(keys=HoVerNetBranch.NP.value, softmax=True),
            GenerateInstanceBorderd(keys="mask", hover_map_key=HoVerNetBranch.HV.value, kernel_size=3),
            GenerateDistanceMapd(keys="mask", border_key="border", smooth_fn=GaussianSmooth()),
            GenerateWatershedMarkersd(
                keys="mask",
                border_key="border",
                threshold=0.7,
                radius=2,
                postprocess_fn=FillHoles(),
            ),
            Watershedd(keys="dist", mask_key="mask", markers_key="markers"),
            HoVerNetNuclearTypePostProcessingd(type_pred_key=HoVerNetBranch.NC.value, instance_pred_key="dist"),
            FromMetaTensord(keys=["image", "pred_binary"]),
            # Save the prediction next to the logs; scale=255 presumably maps
            # a 0/1 mask onto the full uint8 range for viewing — verify.
            SaveImaged(
                keys="pred_binary",
                meta_keys="image_meta_dict",
                output_ext="png",
                output_dir=output_dir,
                output_postfix="pred",
                output_dtype="uint8",
                separate_folder=False,
                scale=255,
            ),
        ]
    )
    # --------------------------------------------------------------------------
    # Data and Data Loading
    # --------------------------------------------------------------------------
    # List of whole slide images (all PNG files directly under cfg["root"])
    data_list = [{"image": image} for image in glob(os.path.join(cfg["root"], "*.png"))]

    # Dataset
    dataset = Dataset(data_list, transform=pre_transforms)

    # Dataloader
    data_loader = DataLoader(
        dataset,
        num_workers=cfg["num_workers"],
        batch_size=cfg["batch_size"],
        pin_memory=True,
    )

    # --------------------------------------------------------------------------
    # Run some sanity checks
    # --------------------------------------------------------------------------
    # Check first sample (fails fast if the input folder has no PNG images)
    first_sample = first(data_loader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    print("image: ")
    print(" shape", first_sample["image"].shape)
    print(" type: ", type(first_sample["image"]))
    print(" dtype: ", first_sample["image"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"number of batches: {len(data_loader)}")

    # --------------------------------------------------------------------------
    # Model
    # --------------------------------------------------------------------------
    # Create model and load weights
    model = HoVerNet(
        mode=cfg["mode"],
        in_channels=3,
        out_classes=5,  # NOTE(review): matches the CoNSeP checkpoint — confirm for other datasets
        act=("relu", {"inplace": True}),
        norm="batch",
    ).to(device)
    model.load_state_dict(torch.load(cfg["ckpt"], map_location=device))
    model.eval()

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[dist.get_rank()], output_device=dist.get_rank()
        )

    # --------------------------------------------
    # Inference
    # --------------------------------------------
    # Inference engine:
    # overlap = 1 - out_size/patch_size so consecutive windows advance by the
    # valid output size; extra_input_padding pads all four sides by half the
    # input/output size difference so border pixels receive predictions.
    sliding_inferer = SlidingWindowHoVerNetInferer(
        roi_size=cfg["patch_size"],
        sw_batch_size=cfg["sw_batch_size"],
        overlap=1.0 - float(cfg["out_size"]) / float(cfg["patch_size"]),
        padding_mode="constant",
        cval=0,
        sw_device=device,
        device=device,
        progress=True,
        extra_input_padding=((cfg["patch_size"] - cfg["out_size"]) // 2,) * 4,
    )

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=data_loader,
        network=model,
        postprocessing=post_transforms,
        inferer=sliding_inferer,
        amp=cfg["amp"],
    )
    evaluator.run()

    if multi_gpu:
        dist.destroy_process_group()
|
||
|
||
def main():
    """Command-line entry point: parse arguments and launch ROI inference."""
    logging.basicConfig(level=logging.INFO)

    arg_parser = ArgumentParser(description="Tumor detection on whole slide pathology images.")
    # (flags, keyword arguments) specs for every CLI option.
    option_specs = [
        (("--root",), dict(type=str, default="./CoNSeP/Test/Images", help="image root dir")),
        (("--output",), dict(type=str, default="./logs/", dest="output", help="log directory")),
        (("--ckpt",), dict(type=str, default="./model_CoNSeP_new.pth", help="Path to the pytorch checkpoint")),
        (("--mode",), dict(type=str, default="original", help="HoVerNet mode (original/fast)")),
        (("--bs",), dict(type=int, default=1, dest="batch_size", help="batch size")),
        (("--swbs",), dict(type=int, default=8, dest="sw_batch_size", help="sliding window batch size")),
        (("--no-amp",), dict(action="store_false", dest="amp", help="deactivate amp")),
        (("--cpu",), dict(type=int, default=0, dest="num_workers", help="number of workers")),
        (("--use-gpu",), dict(action="store_true", help="whether to use gpu")),
    ]
    for flags, kwargs in option_specs:
        arg_parser.add_argument(*flags, **kwargs)
    parsed_args = arg_parser.parse_args()

    config_dict = vars(parsed_args)
    print(config_dict)
    run(config_dict)
|
||
|
||
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.