Skip to content

Commit 17bf2ec

Browse files
authored
1015 Restructure folders and combine examples (#27)
* [DLMED] restructure folders Signed-off-by: Nic Ma <[email protected]> * [DLMED] add examples Signed-off-by: Nic Ma <[email protected]> * [DLMED] update patch in notebooks Signed-off-by: Nic Ma <[email protected]> * [DLMED] update images to figures Signed-off-by: Nic Ma <[email protected]>
1 parent c5f14b0 commit 17bf2ec

File tree

59 files changed

+5321
-86
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+5321
-86
lines changed

mednist_tutorial.ipynb renamed to 2d_classification/mednist_tutorial.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"* Train the model with a PyTorch program\n",
1616
"* Evaluate on test dataset\n",
1717
"\n",
18-
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/Tutorials/blob/master/mednist_tutorial.ipynb)"
18+
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/2d_classification/mednist_tutorial.ipynb)"
1919
]
2020
},
2121
{
@@ -683,7 +683,7 @@
683683
"name": "python",
684684
"nbconvert_exporter": "python",
685685
"pygments_lexer": "ipython3",
686-
"version": "3.6.9"
686+
"version": "3.6.10"
687687
}
688688
},
689689
"nbformat": 4,
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
import os
14+
import sys
15+
import tempfile
16+
from glob import glob
17+
18+
import torch
19+
from PIL import Image
20+
from torch.utils.data import DataLoader
21+
22+
from monai import config
23+
from monai.data import ArrayDataset, PNGSaver, create_test_image_2d
24+
from monai.inferers import sliding_window_inference
25+
from monai.metrics import DiceMetric
26+
from monai.networks.nets import UNet
27+
from monai.transforms import AddChannel, Compose, LoadImage, ScaleIntensity, ToTensor
28+
29+
30+
def main(tempdir):
31+
config.print_config()
32+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
33+
34+
print(f"generating synthetic data to {tempdir} (this may take a while)")
35+
for i in range(5):
36+
im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
37+
Image.fromarray(im.astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
38+
Image.fromarray(seg.astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))
39+
40+
images = sorted(glob(os.path.join(tempdir, "img*.png")))
41+
segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
42+
43+
# define transforms for image and segmentation
44+
imtrans = Compose([LoadImage(image_only=True), ScaleIntensity(), AddChannel(), ToTensor()])
45+
segtrans = Compose([LoadImage(image_only=True), AddChannel(), ToTensor()])
46+
val_ds = ArrayDataset(images, imtrans, segs, segtrans)
47+
# sliding window inference for one image at every iteration
48+
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
49+
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
50+
51+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52+
model = UNet(
53+
dimensions=2,
54+
in_channels=1,
55+
out_channels=1,
56+
channels=(16, 32, 64, 128, 256),
57+
strides=(2, 2, 2, 2),
58+
num_res_units=2,
59+
).to(device)
60+
61+
model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth"))
62+
model.eval()
63+
with torch.no_grad():
64+
metric_sum = 0.0
65+
metric_count = 0
66+
saver = PNGSaver(output_dir="./output")
67+
for val_data in val_loader:
68+
val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
69+
# define sliding window size and batch size for windows inference
70+
roi_size = (96, 96)
71+
sw_batch_size = 4
72+
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
73+
value = dice_metric(y_pred=val_outputs, y=val_labels)
74+
metric_count += len(value)
75+
metric_sum += value.item() * len(value)
76+
val_outputs = val_outputs.sigmoid() >= 0.5
77+
saver.save_batch(val_outputs)
78+
metric = metric_sum / metric_count
79+
print("evaluation metric:", metric)
80+
81+
82+
if __name__ == "__main__":
83+
with tempfile.TemporaryDirectory() as tempdir:
84+
main(tempdir)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
import os
14+
import sys
15+
import tempfile
16+
from glob import glob
17+
18+
import torch
19+
from PIL import Image
20+
from torch.utils.data import DataLoader
21+
22+
import monai
23+
from monai.data import PNGSaver, create_test_image_2d, list_data_collate
24+
from monai.inferers import sliding_window_inference
25+
from monai.metrics import DiceMetric
26+
from monai.networks.nets import UNet
27+
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
28+
29+
30+
def main(tempdir):
31+
monai.config.print_config()
32+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
33+
34+
print(f"generating synthetic data to {tempdir} (this may take a while)")
35+
for i in range(5):
36+
im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
37+
Image.fromarray(im.astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
38+
Image.fromarray(seg.astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))
39+
40+
images = sorted(glob(os.path.join(tempdir, "img*.png")))
41+
segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
42+
val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
43+
44+
# define transforms for image and segmentation
45+
val_transforms = Compose(
46+
[
47+
LoadImaged(keys=["img", "seg"]),
48+
AddChanneld(keys=["img", "seg"]),
49+
ScaleIntensityd(keys="img"),
50+
ToTensord(keys=["img", "seg"]),
51+
]
52+
)
53+
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
54+
# sliding window inference need to input 1 image in every iteration
55+
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
56+
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
57+
58+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59+
model = UNet(
60+
dimensions=2,
61+
in_channels=1,
62+
out_channels=1,
63+
channels=(16, 32, 64, 128, 256),
64+
strides=(2, 2, 2, 2),
65+
num_res_units=2,
66+
).to(device)
67+
68+
model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth"))
69+
70+
model.eval()
71+
with torch.no_grad():
72+
metric_sum = 0.0
73+
metric_count = 0
74+
saver = PNGSaver(output_dir="./output")
75+
for val_data in val_loader:
76+
val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
77+
# define sliding window size and batch size for windows inference
78+
roi_size = (96, 96)
79+
sw_batch_size = 4
80+
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
81+
value = dice_metric(y_pred=val_outputs, y=val_labels)
82+
metric_count += len(value)
83+
metric_sum += value.item() * len(value)
84+
val_outputs = val_outputs.sigmoid() >= 0.5
85+
saver.save_batch(val_outputs)
86+
metric = metric_sum / metric_count
87+
print("evaluation metric:", metric)
88+
89+
90+
if __name__ == "__main__":
91+
with tempfile.TemporaryDirectory() as tempdir:
92+
main(tempdir)
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
import os
14+
import sys
15+
import tempfile
16+
from glob import glob
17+
18+
import torch
19+
from PIL import Image
20+
from torch.utils.data import DataLoader
21+
from torch.utils.tensorboard import SummaryWriter
22+
23+
import monai
24+
from monai.data import ArrayDataset, create_test_image_2d
25+
from monai.inferers import sliding_window_inference
26+
from monai.metrics import DiceMetric
27+
from monai.transforms import AddChannel, Compose, LoadImage, RandRotate90, RandSpatialCrop, ScaleIntensity, ToTensor
28+
from monai.visualize import plot_2d_or_3d_image
29+
30+
31+
def main(tempdir):
32+
monai.config.print_config()
33+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
34+
35+
# create a temporary directory and 40 random image, mask pairs
36+
print(f"generating synthetic data to {tempdir} (this may take a while)")
37+
for i in range(40):
38+
im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
39+
Image.fromarray(im.astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
40+
Image.fromarray(seg.astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))
41+
42+
images = sorted(glob(os.path.join(tempdir, "img*.png")))
43+
segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
44+
train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
45+
val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]
46+
47+
# define transforms for image and segmentation
48+
train_imtrans = Compose(
49+
[
50+
LoadImage(image_only=True),
51+
ScaleIntensity(),
52+
AddChannel(),
53+
RandSpatialCrop((96, 96), random_size=False),
54+
RandRotate90(prob=0.5, spatial_axes=(0, 1)),
55+
ToTensor(),
56+
]
57+
)
58+
train_segtrans = Compose(
59+
[
60+
LoadImage(image_only=True),
61+
AddChannel(),
62+
RandSpatialCrop((96, 96), random_size=False),
63+
RandRotate90(prob=0.5, spatial_axes=(0, 1)),
64+
ToTensor(),
65+
]
66+
)
67+
val_imtrans = Compose([LoadImage(image_only=True), ScaleIntensity(), AddChannel(), ToTensor()])
68+
val_segtrans = Compose([LoadImage(image_only=True), AddChannel(), ToTensor()])
69+
70+
# define array dataset, data loader
71+
check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
72+
check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
73+
im, seg = monai.utils.misc.first(check_loader)
74+
print(im.shape, seg.shape)
75+
76+
# create a training data loader
77+
train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20], train_segtrans)
78+
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
79+
# create a validation data loader
80+
val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
81+
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
82+
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")
83+
84+
# create UNet, DiceLoss and Adam optimizer
85+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
86+
model = monai.networks.nets.UNet(
87+
dimensions=2,
88+
in_channels=1,
89+
out_channels=1,
90+
channels=(16, 32, 64, 128, 256),
91+
strides=(2, 2, 2, 2),
92+
num_res_units=2,
93+
).to(device)
94+
loss_function = monai.losses.DiceLoss(sigmoid=True)
95+
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
96+
97+
# start a typical PyTorch training
98+
val_interval = 2
99+
best_metric = -1
100+
best_metric_epoch = -1
101+
epoch_loss_values = list()
102+
metric_values = list()
103+
writer = SummaryWriter()
104+
for epoch in range(10):
105+
print("-" * 10)
106+
print(f"epoch {epoch + 1}/{10}")
107+
model.train()
108+
epoch_loss = 0
109+
step = 0
110+
for batch_data in train_loader:
111+
step += 1
112+
inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
113+
optimizer.zero_grad()
114+
outputs = model(inputs)
115+
loss = loss_function(outputs, labels)
116+
loss.backward()
117+
optimizer.step()
118+
epoch_loss += loss.item()
119+
epoch_len = len(train_ds) // train_loader.batch_size
120+
print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
121+
writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
122+
epoch_loss /= step
123+
epoch_loss_values.append(epoch_loss)
124+
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
125+
126+
if (epoch + 1) % val_interval == 0:
127+
model.eval()
128+
with torch.no_grad():
129+
metric_sum = 0.0
130+
metric_count = 0
131+
val_images = None
132+
val_labels = None
133+
val_outputs = None
134+
for val_data in val_loader:
135+
val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
136+
roi_size = (96, 96)
137+
sw_batch_size = 4
138+
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
139+
value = dice_metric(y_pred=val_outputs, y=val_labels)
140+
metric_count += len(value)
141+
metric_sum += value.item() * len(value)
142+
metric = metric_sum / metric_count
143+
metric_values.append(metric)
144+
if metric > best_metric:
145+
best_metric = metric
146+
best_metric_epoch = epoch + 1
147+
torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
148+
print("saved new best metric model")
149+
print(
150+
"current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
151+
epoch + 1, metric, best_metric, best_metric_epoch
152+
)
153+
)
154+
writer.add_scalar("val_mean_dice", metric, epoch + 1)
155+
# plot the last model output as GIF image in TensorBoard with the corresponding image and label
156+
plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
157+
plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
158+
plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
159+
160+
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
161+
writer.close()
162+
163+
164+
if __name__ == "__main__":
165+
with tempfile.TemporaryDirectory() as tempdir:
166+
main(tempdir)

0 commit comments

Comments
 (0)