Update tumor detection pipline with new components #697

Merged on Jul 20, 2022 (34 commits)

Commits
6d49bba
Update tumor detection pipline with new components
bhashemian May 11, 2022
3745cf9
formatting
bhashemian May 11, 2022
414fea7
formatting
bhashemian May 11, 2022
96c9b46
Add sub dataset
bhashemian May 12, 2022
27717a1
Merge branch 'master' into update-tumor-det
bhashemian May 12, 2022
bbf2720
Convert xxxD transforms to xxxd
bhashemian May 12, 2022
b6bfd46
Merge branch 'update-tumor-det' of github.com:drbeh/tutorials into up…
bhashemian May 12, 2022
041ede0
Update pipline with nvtx annotation
bhashemian May 13, 2022
cb830be
Merge branch 'master' into update-tumor-det
bhashemian May 13, 2022
13acc50
Add transforms to validation
bhashemian May 13, 2022
dd918f1
Update tumor notebook
bhashemian May 13, 2022
358e809
Update torch-based pipeline
bhashemian May 13, 2022
e8bcced
Update perfomance profiling code
bhashemian May 13, 2022
6a2c468
Update grid shape
bhashemian May 16, 2022
edc6a02
Merge branch 'master' into update-tumor-det
bhashemian May 16, 2022
1e94103
Update
bhashemian May 16, 2022
6a624f2
Update a comment
bhashemian May 16, 2022
2514271
Merge branch 'update-tumor-det' of github.com:drbeh/tutorials into up…
bhashemian May 16, 2022
0e59498
Update training/validation dataset and their links
bhashemian May 16, 2022
87cae7b
Merge branch 'master' into update-tumor-det
bhashemian May 18, 2022
936ecd7
Fix image path
bhashemian May 26, 2022
cdeef33
Change to validate by default
bhashemian May 31, 2022
6a0aa32
Merge branch 'main' of github.com:Project-MONAI/tutorials into update…
bhashemian May 31, 2022
be9c6e4
Flip back the location dims
bhashemian Jun 2, 2022
0a64d80
Add filtered datasets
bhashemian Jun 2, 2022
819b45d
Update csv loading
bhashemian Jun 7, 2022
35ad2f8
Merge branch 'main' into update-tumor-det
bhashemian Jun 8, 2022
f648faa
Update validation defaults
bhashemian Jun 8, 2022
b71fa5e
Merge branch 'update-tumor-det' of github.com:drbeh/tutorials into up…
bhashemian Jun 8, 2022
4cddebf
Merge branch 'main' into update-tumor-det
bhashemian Jul 7, 2022
f8c04c3
Remove csv files
bhashemian Jul 8, 2022
bb91848
Merge branch 'main' into update-tumor-det
bhashemian Jul 11, 2022
d777ab4
Merge branch 'main' into update-tumor-det
bhashemian Jul 18, 2022
f3e8fe8
Merge branch 'main' into update-tumor-det
Nic-Ma Jul 20, 2022
7 changes: 3 additions & 4 deletions pathology/tumor_detection/README.MD
@@ -12,12 +12,11 @@ The model is based on ResNet18 with the last fully connected layer replaced by a

All the data used to train and validate this model is from [Camelyon-16 Challenge](https://camelyon16.grand-challenge.org/). You can download all the images for "CAMELYON16" data set from various sources listed [here](https://camelyon17.grand-challenge.org/Data/).

Location information for training/validation patches (the location on the whole slide image where patches are extracted) are adopted from [NCRF/coords](https://github.com/baidu-research/NCRF/tree/master/coords). The reformatted coordinations and labels are stored in a json file (`dataset_0.json`), and can be downloaded from [here](https://drive.google.com/file/d/1m2pwko6hxwsxeDWZY2oSOV-_KT97Ol0o/view?usp=sharing)
Location information for training/validation patches (the location on the whole slide image where patches are extracted) is adopted from [NCRF/coords](https://github.com/baidu-research/NCRF/tree/master/coords). The reformatted coordinates and labels are provided in CSV format: the training set (`training.csv`) can be found [here](https://drive.google.com/file/d/1httIjgji6U6rMIb0P8pE0F-hXFAuvQEf/view?usp=sharing) and the validation set (`validation.csv`) [here](https://drive.google.com/file/d/1tJulzl9m5LUm16IeFbOCoFnaSWoB6i5L/view?usp=sharing).

This pipeline expects the training/validation data (whole slide images) reside in `cfg["data_root"]/training/images`. By default `data_root` is pointing to `/workspace/data/medical/pathology/` You can easily modify it to point to a different directory by passing the following argument in the runtime: `--data-root /other/data/root/dir/`.
This pipeline expects the training/validation data (whole slide images) to reside in `cfg["data_root"]/training/images`. By default, `data_root` points to the code folder `./`; however, you can point it to a different directory by passing the following argument at runtime: `--data-root /other/data/root/dir/`.
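
As a rough illustration of the directory layout described above, the sketch below resolves a slide name to the path where the pipeline would look for it. The helper name `slide_path` is hypothetical and not part of the repository. The `.tif` extension is taken from the training script in this PR, which appends `.tif` to the slide names read from the CSV files.

```python
import os


def slide_path(data_root: str, slide_name: str) -> str:
    """Return the expected location of a whole slide image under data_root.

    Hypothetical helper: assumes slides live in <data_root>/training/images
    and are stored as .tif files.
    """
    return os.path.join(data_root, "training", "images", slide_name + ".tif")


# Example with the alternative root used in the text above.
print(slide_path("/other/data/root/dir/", "tumor_001"))
```

A quick check like this, run before training, can confirm that the chosen data root actually contains the slides listed in the CSV files.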

> `dataset_0_subset_0.json` is also provided [here](https://drive.google.com/file/d/1NCd0y4FR42maQpfZjzKlFSIX4oeKgysg/view?usp=sharing) to check the functionality of the pipeline using only two of the whole slide images: `tumor_001` and `tumor_101`. <br/>
> This dataset should not be used for the real training or any perfomance evaluation.
> [`training_sub.csv`](https://drive.google.com/file/d/1rO8ZY-TrU9nrOsx-Udn1q5PmUYrLG3Mv/view?usp=sharing) and [`validation_sub.csv`](https://drive.google.com/file/d/130pqsrc2e9wiHIImL8w4fT_5NktEGel7/view?usp=sharing) are also provided to check the functionality of the pipeline using only two of the whole slide images: `tumor_001` (for training) and `tumor_101` (for validation). This dataset should not be used for real training or any performance evaluation.

### Input and output formats

157 changes: 77 additions & 80 deletions pathology/tumor_detection/ignite/camelyon_train_evaluate.py
@@ -1,34 +1,17 @@
import os

import logging
import os
import time
from argparse import ArgumentParser

import numpy as np

import pandas as pd
import torch
from torch.optim import SGD, lr_scheduler

from ignite.metrics import Accuracy
from torch.optim import SGD, lr_scheduler

import monai
from monai.data import DataLoader, load_decathlon_datalist
from monai.transforms import (
ActivationsD,
AsDiscreteD,
CastToTypeD,
Compose,
RandFlipD,
RandRotate90D,
RandZoomD,
ScaleIntensityRangeD,
ToNumpyD,
TorchVisionD,
ToTensorD,
)
from monai.utils import first, set_determinism
from monai.optimizers import Novograd
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from monai.data import DataLoader, PatchWSIDataset, CSVDataset
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
CheckpointSaver,
LrScheduleHandler,
@@ -37,10 +20,24 @@
ValidationHandler,
from_engine,
)

from monai.apps.pathology.data import PatchWSIDataset
from monai.networks.nets import TorchVisionFCModel

from monai.optimizers import Novograd
from monai.transforms import (
Activationsd,
AsDiscreted,
CastToTyped,
Compose,
GridSplitd,
Lambdad,
RandFlipd,
RandRotate90d,
RandZoomd,
ScaleIntensityRanged,
ToNumpyd,
TorchVisiond,
ToTensord,
)
from monai.utils import first, set_determinism

torch.backends.cudnn.enabled = True
set_determinism(seed=0, additional_settings=None)
@@ -65,7 +62,7 @@ def set_device(cfg):
if gpus and torch.cuda.is_available():
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(n) for n in gpus])
device = torch.device("cuda")
print(f'CUDA is being used with GPU ID(s): {os.environ["CUDA_VISIBLE_DEVICES"]}')
print(f'CUDA is being used with GPU Id(s): {os.environ["CUDA_VISIBLE_DEVICES"]}')
else:
device = torch.device("cpu")
print("CPU only!")
@@ -82,54 +79,66 @@ def train(cfg):
# Build MONAI preprocessing
train_preprocess = Compose(
[
ToTensorD(keys="image"),
TorchVisionD(
Lambdad(keys="label", func=lambda x: x.reshape((1, cfg["grid_shape"], cfg["grid_shape"]))),
GridSplitd(
keys=("image", "label"),
grid=(cfg["grid_shape"], cfg["grid_shape"]),
size={"image": cfg["patch_size"], "label": 1},
),
ToTensord(keys=("image")),
TorchVisiond(
keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04
),
ToNumpyD(keys="image"),
RandFlipD(keys="image", prob=0.5),
RandRotate90D(keys="image", prob=0.5),
CastToTypeD(keys="image", dtype=np.float32),
RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
ToTensorD(keys=("image", "label")),
ToNumpyd(keys="image"),
RandFlipd(keys="image", prob=0.5),
RandRotate90d(keys="image", prob=0.5, max_k=3, spatial_axes=(-2, -1)),
CastToTyped(keys="image", dtype=np.float32),
RandZoomd(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
ScaleIntensityRanged(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
ToTensord(keys=("image", "label")),
]
)
valid_preprocess = Compose(
[
CastToTypeD(keys="image", dtype=np.float32),
ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
ToTensorD(keys=("image", "label")),
Lambdad(keys="label", func=lambda x: x.reshape((1, cfg["grid_shape"], cfg["grid_shape"]))),
GridSplitd(
keys=("image", "label"),
grid=(cfg["grid_shape"], cfg["grid_shape"]),
size={"image": cfg["patch_size"], "label": 1},
),
CastToTyped(keys="image", dtype=np.float32),
ScaleIntensityRanged(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
ToTensord(keys=("image", "label")),
]
)
# __________________________________________________________________________
# Create MONAI dataset
train_json_info_list = load_decathlon_datalist(
data_list_file_path=cfg["dataset_json"],
data_list_key="training",
base_dir=cfg["data_root"],
train_data_list = CSVDataset(
cfg["train_file"],
col_groups={"image": 0, "patch_location": [2, 1], "label": [3, 6, 9, 4, 7, 10, 5, 8, 11]},
kwargs_read_csv={"header": None},
transform=Lambdad("image", lambda x: os.path.join(cfg["root"], "training/images", x + ".tif")),
)
valid_json_info_list = load_decathlon_datalist(
data_list_file_path=cfg["dataset_json"],
data_list_key="validation",
base_dir=cfg["data_root"],
train_dataset = PatchWSIDataset(
data=train_data_list,
patch_size=cfg["region_size"],
patch_level=0,
transform=train_preprocess,
reader="openslide" if cfg["use_openslide"] else "cuCIM",
)

train_dataset = PatchWSIDataset(
train_json_info_list,
cfg["region_size"],
cfg["grid_shape"],
cfg["patch_size"],
train_preprocess,
image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
valid_data_list = CSVDataset(
cfg["valid_file"],
col_groups={"image": 0, "patch_location": [2, 1], "label": [3, 6, 9, 4, 7, 10, 5, 8, 11]},
kwargs_read_csv={"header": None},
transform=Lambdad("image", lambda x: os.path.join(cfg["root"], "training/images", x + ".tif")),
)
valid_dataset = PatchWSIDataset(
valid_json_info_list,
cfg["region_size"],
cfg["grid_shape"],
cfg["patch_size"],
valid_preprocess,
image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
data=valid_data_list,
patch_size=cfg["region_size"],
patch_level=0,
transform=valid_preprocess,
reader="openslide" if cfg["use_openslide"] else "cuCIM",
)

# __________________________________________________________________________
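
The `col_groups` mapping in the hunk above decides how each header-less CSV row is split into an image name, a patch location, and nine labels. The sketch below reproduces that grouping with the standard library only, so the column reordering is easy to see. It is an illustration under assumptions, not MONAI's `CSVDataset` implementation. The sample row is made up, and its label values 1 to 9 are chosen only to make the reordering visible (they are not realistic labels).

```python
import csv
import io

# Column groups copied from the CSVDataset call in this diff.
COL_GROUPS = {
    "image": [0],
    "patch_location": [2, 1],
    "label": [3, 6, 9, 4, 7, 10, 5, 8, 11],
}


def parse_row(row):
    """Split one header-less CSV row into image name, location, and 3x3 labels."""
    name = row[COL_GROUPS["image"][0]]
    # Columns 1 and 2 are read in swapped order.
    location = [int(row[i]) for i in COL_GROUPS["patch_location"]]
    # The nine label columns are picked with a stride of 3, i.e. transposed.
    flat = [int(row[i]) for i in COL_GROUPS["label"]]
    grid = [flat[r * 3:(r + 1) * 3] for r in range(3)]
    return {"image": name, "patch_location": location, "label": grid}


# Made-up row for illustration only: name, two coordinates, nine labels.
sample = "tumor_001,100,200,1,2,3,4,5,6,7,8,9\n"
row = next(csv.reader(io.StringIO(sample)))
record = parse_row(row)
print(record)
```

With this sample, the location comes out as `[200, 100]` and the label grid as `[[1, 4, 7], [2, 5, 8], [3, 6, 9]]`, the transpose of the order in which the labels appear in the file.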
@@ -141,12 +150,10 @@ def train(cfg):
valid_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=True
)

# __________________________________________________________________________
# Get sample batch and some info
# Check first sample
first_sample = first(train_dataloader)
if first_sample is None:
raise ValueError("Fist sample is None!")

raise ValueError("First sample is None!")
print("image: ")
print(" shape", first_sample["image"].shape)
print(" type: ", type(first_sample["image"]))
@@ -194,9 +201,7 @@ def train(cfg):
StatsHandler(output_transform=lambda x: None),
TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
]
val_postprocessing = Compose(
[ActivationsD(keys="pred", sigmoid=True), AsDiscreteD(keys="pred", threshold=0.5)]
)
val_postprocessing = Compose([Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold=0.5)])
evaluator = SupervisedEvaluator(
device=device,
val_data_loader=valid_dataloader,
@@ -219,9 +224,7 @@ def train(cfg):
log_dir=cfg["logdir"], tag_name="train_loss", output_transform=from_engine(["loss"], first=True)
),
]
train_postprocessing = Compose(
[ActivationsD(keys="pred", sigmoid=True), AsDiscreteD(keys="pred", threshold=0.5)]
)
train_postprocessing = Compose([Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold=0.5)])

trainer = SupervisedTrainer(
device=device,
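
Both postprocessing chains in this diff apply a sigmoid to the predictions and then threshold them at 0.5. The sketch below shows those two steps on plain Python floats. It is a minimal illustration, not the MONAI transforms themselves. Treating values at or above the threshold as positive is an assumption, and the function names are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def discretize(logits, threshold=0.5):
    """Apply a sigmoid to each logit, then binarize at the threshold."""
    return [1 if sigmoid(v) >= threshold else 0 for v in logits]


# Example logits: clearly negative, exactly zero, clearly positive.
print(discretize([-2.0, 0.0, 3.0]))
```

Because the sigmoid of 0 is exactly 0.5, thresholding at 0.5 after the sigmoid is equivalent to checking the sign of the raw logit.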
@@ -241,24 +244,18 @@ def train(cfg):
def main():
logging.basicConfig(level=logging.INFO)
parser = ArgumentParser(description="Tumor detection on whole slide pathology images.")
parser.add_argument(
"--dataset",
type=str,
default="../dataset_0.json",
dest="dataset_json",
help="path to dataset json file",
)
parser.add_argument(
"--root",
type=str,
default="/workspace/data/medical/pathology/",
dest="data_root",
help="path to root folder of images containing training folder",
default="/workspace/data/medical/pathology",
help="path to image folder containing training/validation",
)
parser.add_argument("--train-file", type=str, default="training.csv", help="path to training data file")
parser.add_argument("--valid-file", type=str, default="validation.csv", help="path to validation data file")
parser.add_argument("--logdir", type=str, default="./logs/", dest="logdir", help="log directory")

parser.add_argument("--rs", type=int, default=256 * 3, dest="region_size", help="region size")
parser.add_argument("--gs", type=int, default=3, dest="grid_shape", help="image grid shape (3x3)")
parser.add_argument("--gs", type=int, default=3, dest="grid_shape", help="image grid shape, e.g. 3 means 3x3")
parser.add_argument("--ps", type=int, default=224, dest="patch_size", help="patch size")
parser.add_argument("--bs", type=int, default=64, dest="batch_size", help="batch size")
parser.add_argument("--ep", type=int, default=10, dest="n_epochs", help="number of epochs")
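
With the defaults above, each 768-pixel region (`--rs`) is split into a 3x3 grid (`--gs`) of 224-pixel patches (`--ps`). The sketch below works out where such patches could sit, assuming they are spaced evenly so that the first and last patch touch the region borders. That spacing rule is an assumption for illustration and `grid_offsets` is a hypothetical helper. MONAI's `GridSplitd` may place patches differently.

```python
def grid_offsets(region_size: int, grid: int, patch_size: int):
    """Return top-left (row, col) offsets for a grid x grid set of patches.

    Assumption: patches are spread evenly across the region, so the step
    between neighbours is (region_size - patch_size) // (grid - 1).
    """
    if patch_size > region_size:
        raise ValueError("patch_size must not exceed region_size")
    step = (region_size - patch_size) // (grid - 1) if grid > 1 else region_size
    return [(r * step, c * step) for r in range(grid) for c in range(grid)]


# Defaults from the argument parser: --rs 768, --gs 3, --ps 224.
offsets = grid_offsets(region_size=256 * 3, grid=3, patch_size=224)
print(len(offsets), offsets[0], offsets[-1])
```

Under this assumption the step is 272 pixels, so the nine patches start at offsets 0, 272, and 544 along each axis, and the last patch ends exactly at pixel 768.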