
Commit f9b9832

Mb/metrics reloaded (#1222)
Closes Project-MONAI/MONAI#6025

Co-authored-by: Mikael Brudfors <[email protected]>
1 parent 4d4c565 commit f9b9832

File tree

3 files changed: +380 -0 lines changed


modules/metrics_reloaded/README.md

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# MetricsReloaded

These scripts show how to use the [MetricsReloaded package](https://github.com/Project-MONAI/MetricsReloaded) with MONAI to compute a range of metrics for a binary segmentation task.

## Install

Besides installing MONAI, make sure to install the MetricsReloaded package, e.g.:

```sh
pip install git+https://github.com/Project-MONAI/MetricsReloaded@monai-support
```
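
After installation, a quick import check can confirm that both packages are visible in the active environment (a minimal sketch; the `ProcessEvaluation` import mirrors the one used by the evaluation script below):

```py
# Sanity check: both MONAI and MetricsReloaded should import cleanly.
import monai
from MetricsReloaded.processes.overall_process import ProcessEvaluation

print(monai.__version__)
```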

## Run

First, run the training script:

```sh
python unet_training.py
```

to train a UNet on synthetic data. This script shows how to use MetricsReloaded during validation.

Next, run the evaluation script:

```sh
python unet_evaluation.py
```

to predict on unseen cases and compute MetricsReloaded metrics from the predictions and references, which have been saved to disk. The requested metrics are printed to screen as well as saved to `results_metrics_reloaded.csv`.
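
The per-case results can then be inspected with standard tools, for instance (a minimal sketch, assuming pandas is installed):

```py
# Load the per-case metrics written by the evaluation script.
import pandas as pd

df = pd.read_csv("results_metrics_reloaded.csv")
print(df.head())
```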
modules/metrics_reloaded/unet_evaluation.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
# Copyright 2023 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from ignite.engine import Engine

from monai import config
from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
from monai.handlers import CheckpointLoader, StatsHandler
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet
from monai.transforms import Activations, EnsureChannelFirst, AsDiscrete, Compose, SaveImage, ScaleIntensity

from MetricsReloaded.processes.overall_process import ProcessEvaluation


def get_metrics_reloaded_dict(pth_ref, pth_pred):
    """Prepare input dictionary for MetricsReloaded package."""
    preds = []
    refs = []
    names = []
    for r, p in zip(pth_ref, pth_pred):
        name = r.split(os.sep)[-1].split(".nii.gz")[0]
        names.append(name)

        ref = nib.load(r).get_fdata()
        pred = nib.load(p).get_fdata()
        refs.append(ref)
        preds.append(pred)

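    # For a binary segmentation task the same arrays serve as the localization
    # ("pred_loc"/"ref_loc"), probability ("pred_prob") and class
    # ("pred_class"/"ref_class") inputs; "list_values" lists the foreground
    # label(s) to evaluate.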
    dict_file = {}
    dict_file["pred_loc"] = preds
    dict_file["ref_loc"] = refs
    dict_file["pred_prob"] = preds
    dict_file["ref_class"] = refs
    dict_file["pred_class"] = preds
    dict_file["list_values"] = [1]
    dict_file["file"] = pth_pred
    dict_file["names"] = names

    return dict_file


def main(tempdir, img_size=96):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # Set patch size
    patch_size = (int(img_size / 2.0),) * 3

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(img_size, img_size, img_size, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"lab{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "lab*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), EnsureChannelFirst()])
    segtrans = Compose([EnsureChannelFirst()])
    ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

    # Compute UNet levels and strides from image size
    min_size = 4  # minimum size allowed at coarsest resolution level
    num_levels = int(np.maximum(np.ceil(np.log2(np.min(img_size)) - np.log2(min_size)), 1))
    channels = [2 ** (i + 4) for i in range(num_levels)]
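    # e.g. with the default img_size=96: num_levels = ceil(log2(96) - log2(4)) = 5,
    # giving channels = [16, 32, 64, 128, 256] and strides = (2, 2, 2, 2) below.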

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=channels,
        strides=(2,) * (num_levels - 1),
        num_res_units=2,
    ).to(device)

    # define window size and batch size for sliding-window inference
    roi_size = patch_size
    sw_batch_size = 4

    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    save_image = SaveImage(output_dir=tempdir, output_ext=".nii.gz", output_postfix="pred", separate_folder=False)

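    # Predictions are written next to the inputs with a "_pred" postfix
    # (see SaveImage above), which the "*_pred.nii.gz" glob below relies on.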
    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            seg_probs = [post_trans(i) for i in decollate_batch(seg_probs)]
            for seg_prob in seg_probs:
                save_image(seg_prob)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # we don't need to print loss for the evaluator, so just print metrics (the print functions can also be customized)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # the model was trained by the "unet_training.py" example
    cwd = os.sep.join(os.path.abspath(__file__).split(os.sep)[:-1])
    load_path = sorted(list(filter(os.path.isfile, glob(cwd + os.sep + "runs_array" + os.sep + "*.pt"))))
    ckpt_loader = CheckpointLoader(load_path=load_path[-1], load_dict={"net": net})
    ckpt_loader.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    print(state)

    # Prepare MetricsReloaded input
    pth_ref = sorted(list(filter(os.path.isfile, glob(tempdir + os.sep + "lab*.nii.gz"))))
    pth_pred = sorted(list(filter(os.path.isfile, glob(tempdir + os.sep + "*_pred.nii.gz"))))

    # Prepare input dictionary for MetricsReloaded package
    dict_file = get_metrics_reloaded_dict(pth_ref, pth_pred)

    # Run MetricsReloaded evaluation process
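    # "SemS" selects MetricsReloaded's semantic segmentation category;
    # measures_overlap and measures_boundary name the per-case overlap and
    # boundary metrics to compute.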
    PE = ProcessEvaluation(
        dict_file,
        "SemS",
        localization="mask_iou",
        file=dict_file["file"],
        flag_map=True,
        assignment="greedy_matching",
        measures_overlap=[
            "numb_ref",
            "numb_pred",
            "numb_tp",
            "numb_fp",
            "numb_fn",
            "iou",
            "fbeta",
        ],
        measures_boundary=[
            "assd",
            "boundary_iou",
            "hd",
            "hd_perc",
            "masd",
            "nsd",
        ],
        case=True,
        thresh_ass=0.000001,
    )

    # Save results as CSV
    PE.resseg.to_csv(cwd + os.sep + "results_metrics_reloaded.csv")

    return


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)
modules/metrics_reloaded/unet_training.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
# Copyright 2023 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import EarlyStopping, ModelCheckpoint

import monai
from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
from monai.handlers import (
    MetricsReloadedBinaryHandler,
    StatsHandler,
    stopping_fn_from_metric,
)
from monai.transforms import (
    Activations,
    EnsureChannelFirst,
    AsDiscrete,
    Compose,
    RandSpatialCrop,
    Resize,
    ScaleIntensity,
)


def main(tempdir, img_size=96):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # Set patch size
    patch_size = (int(img_size / 2.0),) * 3

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(img_size, img_size, img_size, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"lab{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "lab*.nii.gz")))

    # define transforms for image and segmentation
    train_imtrans = Compose(
        [
            ScaleIntensity(),
            EnsureChannelFirst(),
            RandSpatialCrop(patch_size, random_size=False),
        ]
    )
    train_segtrans = Compose([EnsureChannelFirst(), RandSpatialCrop(patch_size, random_size=False)])
    val_imtrans = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize(patch_size)])
    val_segtrans = Compose([EnsureChannelFirst(), Resize(patch_size)])

    # define image dataset, data loader
    check_ds = ImageDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ImageDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
    train_loader = DataLoader(
        train_ds,
        batch_size=5,
        shuffle=True,
        num_workers=8,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = ImageDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

    # Compute UNet levels and strides from image size
    min_size = 4  # minimum size allowed at coarsest resolution level
    num_levels = int(np.maximum(np.ceil(np.log2(np.min(img_size)) - np.log2(min_size)), 1))
    channels = [2 ** (i + 4) for i in range(num_levels)]

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=channels,
        strides=(2,) * (num_levels - 1),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    lr = 1e-3
    opt = torch.optim.Adam(net.parameters(), lr)

    # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration;
    # users can add an output_transform to return other values, like: y_pred, y, etc.
    trainer = create_supervised_trainer(net, opt, loss, device, False)

    # add a checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs_array/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save={"net": net, "opt": opt},
    )

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # we don't set metrics for the trainer here, so it just prints loss. Users can also customize the print
    # functions and use an output_transform to convert engine.state.output if it's not a loss value.
    train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
    train_stats_handler.attach(trainer)

    # Set parameters for validation
    validation_every_n_epochs = 1
    # Use a validation metric from MetricsReloaded
    metric_name = "Intersection_Over_Union"
    # add evaluation metric to the evaluator engine
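    # MetricsReloadedBinaryHandler wraps a binary MetricsReloaded measure
    # (here "Intersection Over Union") so it can be attached to an Ignite
    # engine like any other MONAI metric handler.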
    val_metrics = {metric_name: MetricsReloadedBinaryHandler("Intersection Over Union")}

    post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    post_label = Compose([AsDiscrete(threshold=0.5)])

    # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration;
    # users can add an output_transform to return other values
    evaluator = create_supervised_evaluator(
        net,
        val_metrics,
        device,
        True,
        output_transform=lambda x, y, y_pred: (
            [post_pred(i) for i in decollate_batch(y_pred)],
            [post_label(i) for i in decollate_batch(y)],
        ),
    )

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add early stopping handler to evaluator
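    # stopping_fn_from_metric builds a score function that reads the named
    # metric from engine.state.metrics; training stops after `patience`
    # validation runs without improvement.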
    early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
    )
    val_stats_handler.attach(evaluator)

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
    print(state)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)
