
Commit a08b8e0

bhashemian and Nic-Ma authored
Update tumor detection pipeline with new components (#697)
* Update tumor detection pipeline with new components
* Formatting
* Formatting
* Add sub dataset
* Convert xxxD transforms to xxxd
* Update pipeline with NVTX annotation
* Add transforms to validation
* Update tumor notebook
* Update torch-based pipeline
* Update performance profiling code
* Update grid shape
* Update
* Update a comment
* Update training/validation dataset and their links
* Fix image path
* Change to validate by default
* Flip back the location dims
* Add filtered datasets
* Update CSV loading
* Update validation defaults
* Remove CSV files

Signed-off-by: Behrooz <[email protected]>
Co-authored-by: Nic Ma <[email protected]>
1 parent 590461a commit a08b8e0

File tree

6 files changed, +678 −523 lines changed


pathology/tumor_detection/README.MD

Lines changed: 3 additions & 4 deletions
@@ -12,12 +12,11 @@ The model is based on ResNet18 with the last fully connected layer replaced by a

 All the data used to train and validate this model is from [Camelyon-16 Challenge](https://camelyon16.grand-challenge.org/). You can download all the images for "CAMELYON16" data set from various sources listed [here](https://camelyon17.grand-challenge.org/Data/).

-Location information for training/validation patches (the location on the whole slide image where patches are extracted) are adopted from [NCRF/coords](https://github.com/baidu-research/NCRF/tree/master/coords). The reformatted coordinations and labels are stored in a json file (`dataset_0.json`), and can be downloaded from [here](https://drive.google.com/file/d/1m2pwko6hxwsxeDWZY2oSOV-_KT97Ol0o/view?usp=sharing)
+Location information for training/validation patches (the location on the whole slide image where patches are extracted) are adopted from [NCRF/coords](https://github.com/baidu-research/NCRF/tree/master/coords). The reformatted coordinations and labels in CSV format for training (`training.csv`) can be found [here](https://drive.google.com/file/d/1httIjgji6U6rMIb0P8pE0F-hXFAuvQEf/view?usp=sharing) and for validation (`validation.csv`) can be found [here](https://drive.google.com/file/d/1tJulzl9m5LUm16IeFbOCoFnaSWoB6i5L/view?usp=sharing).

-This pipeline expects the training/validation data (whole slide images) reside in `cfg["data_root"]/training/images`. By default `data_root` is pointing to `/workspace/data/medical/pathology/` You can easily modify it to point to a different directory by passing the following argument in the runtime: `--data-root /other/data/root/dir/`.
+This pipeline expects the training/validation data (whole slide images) reside in `cfg["data_root"]/training/images`. By default `data_root` is pointing to the code folder `./`; however, you can easily modify it to point to a different directory by passing the following argument in the runtime: `--data-root /other/data/root/dir/`.

-> `dataset_0_subset_0.json` is also provided [here](https://drive.google.com/file/d/1NCd0y4FR42maQpfZjzKlFSIX4oeKgysg/view?usp=sharing) to check the functionality of the pipeline using only two of the whole slide images: `tumor_001` and `tumor_101`. <br/>
-> This dataset should not be used for the real training or any perfomance evaluation.
+> [`training_sub.csv`](https://drive.google.com/file/d/1rO8ZY-TrU9nrOsx-Udn1q5PmUYrLG3Mv/view?usp=sharing) and [`validation_sub.csv`](https://drive.google.com/file/d/130pqsrc2e9wiHIImL8w4fT_5NktEGel7/view?usp=sharing) is also provided to check the functionality of the pipeline using only two of the whole slide images: `tumor_001` (for training) and `tumor_101` (for validation). This dataset should not be used for the real training or any performance evaluation.

 ### Input and output formats
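Note: the `col_groups` mapping in `camelyon_train_evaluate.py` (below) implies the CSV layout: column 0 holds the slide name (without the `.tif` suffix), columns 1-2 the patch location on the whole slide image, and columns 3-11 the nine labels of the 3x3 patch grid, with no header row. A minimal sketch for inspecting a downloaded datalist, assuming that layout:

import pandas as pd

# Sketch only: inspect the CSV datalist; layout inferred from col_groups below.
df = pd.read_csv("training.csv", header=None)  # the files ship without a header row
print(df.shape)       # expected: (num_patches, 12)
print(df.iloc[0, 0])  # slide name, e.g. "tumor_001" (no ".tif" suffix)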

pathology/tumor_detection/ignite/camelyon_train_evaluate.py

Lines changed: 77 additions & 80 deletions
@@ -1,34 +1,17 @@
-import os
-
 import logging
+import os
 import time
 from argparse import ArgumentParser

 import numpy as np
-
+import pandas as pd
 import torch
-from torch.optim import SGD, lr_scheduler
-
 from ignite.metrics import Accuracy
+from torch.optim import SGD, lr_scheduler

 import monai
-from monai.data import DataLoader, load_decathlon_datalist
-from monai.transforms import (
-    ActivationsD,
-    AsDiscreteD,
-    CastToTypeD,
-    Compose,
-    RandFlipD,
-    RandRotate90D,
-    RandZoomD,
-    ScaleIntensityRangeD,
-    ToNumpyD,
-    TorchVisionD,
-    ToTensorD,
-)
-from monai.utils import first, set_determinism
-from monai.optimizers import Novograd
-from monai.engines import SupervisedTrainer, SupervisedEvaluator
+from monai.data import DataLoader, PatchWSIDataset, CSVDataset
+from monai.engines import SupervisedEvaluator, SupervisedTrainer
 from monai.handlers import (
     CheckpointSaver,
     LrScheduleHandler,
@@ -37,10 +20,24 @@
     ValidationHandler,
     from_engine,
 )
-
-from monai.apps.pathology.data import PatchWSIDataset
 from monai.networks.nets import TorchVisionFCModel
-
+from monai.optimizers import Novograd
+from monai.transforms import (
+    Activationsd,
+    AsDiscreted,
+    CastToTyped,
+    Compose,
+    GridSplitd,
+    Lambdad,
+    RandFlipd,
+    RandRotate90d,
+    RandZoomd,
+    ScaleIntensityRanged,
+    ToNumpyd,
+    TorchVisiond,
+    ToTensord,
+)
+from monai.utils import first, set_determinism

 torch.backends.cudnn.enabled = True
 set_determinism(seed=0, additional_settings=None)
@@ -65,7 +62,7 @@ def set_device(cfg):
     if gpus and torch.cuda.is_available():
         os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(n) for n in gpus])
         device = torch.device("cuda")
-        print(f'CUDA is being used with GPU ID(s): {os.environ["CUDA_VISIBLE_DEVICES"]}')
+        print(f'CUDA is being used with GPU Id(s): {os.environ["CUDA_VISIBLE_DEVICES"]}')
     else:
         device = torch.device("cpu")
         print("CPU only!")
@@ -82,54 +79,66 @@ def train(cfg):
     # Build MONAI preprocessing
     train_preprocess = Compose(
         [
-            ToTensorD(keys="image"),
-            TorchVisionD(
+            Lambdad(keys="label", func=lambda x: x.reshape((1, cfg["grid_shape"], cfg["grid_shape"]))),
+            GridSplitd(
+                keys=("image", "label"),
+                grid=(cfg["grid_shape"], cfg["grid_shape"]),
+                size={"image": cfg["patch_size"], "label": 1},
+            ),
+            ToTensord(keys=("image")),
+            TorchVisiond(
                 keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04
             ),
-            ToNumpyD(keys="image"),
-            RandFlipD(keys="image", prob=0.5),
-            RandRotate90D(keys="image", prob=0.5),
-            CastToTypeD(keys="image", dtype=np.float32),
-            RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
-            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
-            ToTensorD(keys=("image", "label")),
+            ToNumpyd(keys="image"),
+            RandFlipd(keys="image", prob=0.5),
+            RandRotate90d(keys="image", prob=0.5, max_k=3, spatial_axes=(-2, -1)),
+            CastToTyped(keys="image", dtype=np.float32),
+            RandZoomd(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
+            ScaleIntensityRanged(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
+            ToTensord(keys=("image", "label")),
         ]
     )
     valid_preprocess = Compose(
         [
-            CastToTypeD(keys="image", dtype=np.float32),
-            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
-            ToTensorD(keys=("image", "label")),
+            Lambdad(keys="label", func=lambda x: x.reshape((1, cfg["grid_shape"], cfg["grid_shape"]))),
+            GridSplitd(
+                keys=("image", "label"),
+                grid=(cfg["grid_shape"], cfg["grid_shape"]),
+                size={"image": cfg["patch_size"], "label": 1},
+            ),
+            CastToTyped(keys="image", dtype=np.float32),
+            ScaleIntensityRanged(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
+            ToTensord(keys=("image", "label")),
         ]
     )
     # __________________________________________________________________________
     # Create MONAI dataset
-    train_json_info_list = load_decathlon_datalist(
-        data_list_file_path=cfg["dataset_json"],
-        data_list_key="training",
-        base_dir=cfg["data_root"],
+    train_data_list = CSVDataset(
+        cfg["train_file"],
+        col_groups={"image": 0, "patch_location": [2, 1], "label": [3, 6, 9, 4, 7, 10, 5, 8, 11]},
+        kwargs_read_csv={"header": None},
+        transform=Lambdad("image", lambda x: os.path.join(cfg["root"], "training/images", x + ".tif")),
     )
-    valid_json_info_list = load_decathlon_datalist(
-        data_list_file_path=cfg["dataset_json"],
-        data_list_key="validation",
-        base_dir=cfg["data_root"],
+    train_dataset = PatchWSIDataset(
+        data=train_data_list,
+        patch_size=cfg["region_size"],
+        patch_level=0,
+        transform=train_preprocess,
+        reader="openslide" if cfg["use_openslide"] else "cuCIM",
    )

-    train_dataset = PatchWSIDataset(
-        train_json_info_list,
-        cfg["region_size"],
-        cfg["grid_shape"],
-        cfg["patch_size"],
-        train_preprocess,
-        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
+    valid_data_list = CSVDataset(
+        cfg["valid_file"],
+        col_groups={"image": 0, "patch_location": [2, 1], "label": [3, 6, 9, 4, 7, 10, 5, 8, 11]},
+        kwargs_read_csv={"header": None},
+        transform=Lambdad("image", lambda x: os.path.join(cfg["root"], "training/images", x + ".tif")),
     )
     valid_dataset = PatchWSIDataset(
-        valid_json_info_list,
-        cfg["region_size"],
-        cfg["grid_shape"],
-        cfg["patch_size"],
-        valid_preprocess,
-        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
+        data=valid_data_list,
+        patch_size=cfg["region_size"],
+        patch_level=0,
+        transform=valid_preprocess,
+        reader="openslide" if cfg["use_openslide"] else "cuCIM",
     )

     # __________________________________________________________________________
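Note: `GridSplitd` now does in the transform chain what `PatchWSIDataset` previously did internally through its `grid_shape`/`patch_size` arguments: each `region_size` region is split into a `grid_shape` x `grid_shape` grid of `patch_size` patches. A minimal sketch of the array version under the defaults in `main()` below (region 256 * 3, grid 3, patch 224), assuming MONAI's `GridSplit` as shipped with this change:

import numpy as np
from monai.transforms import GridSplit

# Sketch only: split one channel-first 768x768 region into a 3x3 grid
# of 224x224 patches, mirroring grid_shape=3 and patch_size=224.
region = np.zeros((3, 256 * 3, 256 * 3), dtype=np.uint8)
patches = GridSplit(grid=(3, 3), size=224)(region)
print(len(patches), patches[0].shape)  # 9 (3, 224, 224)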
@@ -141,12 +150,10 @@ def train(cfg):
         valid_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=True
     )

-    # __________________________________________________________________________
-    # Get sample batch and some info
+    # Check first sample
     first_sample = first(train_dataloader)
     if first_sample is None:
-        raise ValueError("Fist sample is None!")
-
+        raise ValueError("First sample is None!")
     print("image: ")
     print("    shape", first_sample["image"].shape)
     print("    type: ", type(first_sample["image"]))
@@ -194,9 +201,7 @@ def train(cfg):
         StatsHandler(output_transform=lambda x: None),
         TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
     ]
-    val_postprocessing = Compose(
-        [ActivationsD(keys="pred", sigmoid=True), AsDiscreteD(keys="pred", threshold=0.5)]
-    )
+    val_postprocessing = Compose([Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold=0.5)])
     evaluator = SupervisedEvaluator(
         device=device,
         val_data_loader=valid_dataloader,
@@ -219,9 +224,7 @@ def train(cfg):
             log_dir=cfg["logdir"], tag_name="train_loss", output_transform=from_engine(["loss"], first=True)
         ),
     ]
-    train_postprocessing = Compose(
-        [ActivationsD(keys="pred", sigmoid=True), AsDiscreteD(keys="pred", threshold=0.5)]
-    )
+    train_postprocessing = Compose([Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold=0.5)])

     trainer = SupervisedTrainer(
         device=device,
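Note: both postprocessing chains are identical: a sigmoid activation followed by a 0.5 threshold turns the model's raw logits stored under `pred` into binary tumor/non-tumor labels. A minimal sketch of the behavior:

import torch
from monai.transforms import Activationsd, AsDiscreted, Compose

# Sketch only: sigmoid + 0.5 threshold on raw logits stored under "pred".
post = Compose([Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold=0.5)])
print(post({"pred": torch.tensor([-2.0, 0.1, 3.0])})["pred"])  # tensor([0., 1., 1.])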
@@ -241,24 +244,18 @@
 def main():
     logging.basicConfig(level=logging.INFO)
     parser = ArgumentParser(description="Tumor detection on whole slide pathology images.")
-    parser.add_argument(
-        "--dataset",
-        type=str,
-        default="../dataset_0.json",
-        dest="dataset_json",
-        help="path to dataset json file",
-    )
     parser.add_argument(
         "--root",
         type=str,
-        default="/workspace/data/medical/pathology/",
-        dest="data_root",
-        help="path to root folder of images containing training folder",
+        default="/workspace/data/medical/pathology",
+        help="path to image folder containing training/validation",
     )
+    parser.add_argument("--train-file", type=str, default="training.csv", help="path to training data file")
+    parser.add_argument("--valid-file", type=str, default="validation.csv", help="path to training data file")
     parser.add_argument("--logdir", type=str, default="./logs/", dest="logdir", help="log directory")

     parser.add_argument("--rs", type=int, default=256 * 3, dest="region_size", help="region size")
-    parser.add_argument("--gs", type=int, default=3, dest="grid_shape", help="image grid shape (3x3)")
+    parser.add_argument("--gs", type=int, default=3, dest="grid_shape", help="image grid shape e.g 3 means 3x3")
     parser.add_argument("--ps", type=int, default=224, dest="patch_size", help="patch size")
     parser.add_argument("--bs", type=int, default=64, dest="batch_size", help="batch size")
     parser.add_argument("--ep", type=int, default=10, dest="n_epochs", help="number of epochs")
