
Update segmentation_3d tutorials for metatensor #776


Merged
merged 11 commits on Jul 14, 2022
31 changes: 12 additions & 19 deletions 3d_segmentation/brats_segmentation_3d.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion 3d_segmentation/challenge_baseline/README.md
@@ -25,7 +25,7 @@ This directory contains a simple baseline method [using MONAI](https://monai.io)

The script is tested with:

-- `Ubuntu 18.04` | `Python 3.6` | `CUDA 10.2`
+- `Ubuntu 20.04` | `Python 3.8` | `CUDA 11.7`

On a GPU with [automatic mixed precision support](https://developer.nvidia.com/automatic-mixed-precision):

15 changes: 7 additions & 8 deletions 3d_segmentation/challenge_baseline/run_net.py
@@ -36,7 +36,6 @@
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
-    EnsureTyped,
)


@@ -69,12 +68,12 @@ def get_xforms(mode="train", keys=("image", "label")):
                RandFlipd(keys, spatial_axis=2, prob=0.5),
            ]
        )
-        dtype = (np.float32, np.uint8)
+        dtype = (torch.float32, torch.uint8)
    if mode == "val":
-        dtype = (np.float32, np.uint8)
+        dtype = (torch.float32, torch.uint8)
    if mode == "infer":
-        dtype = (np.float32,)
-    xforms.extend([CastToTyped(keys, dtype=dtype), EnsureTyped(keys)])
+        dtype = (torch.float32,)
+    xforms.extend([CastToTyped(keys, dtype=dtype)])
    return monai.transforms.Compose(xforms)


@@ -172,7 +171,7 @@ def train(data_folder=".", model_folder="runs"):

    # create evaluator (to be used to measure model quality during training
    val_post_transform = monai.transforms.Compose(
-        [EnsureTyped(keys=("pred", "label")), AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=2)]
+        [AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=2)]
    )
    val_handlers = [
        ProgressBar(),
@@ -246,7 +245,7 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
    saver = monai.data.NiftiSaver(output_dir=prediction_folder, mode="nearest")
    with torch.no_grad():
        for infer_data in infer_loader:
-            logging.info(f"segmenting {infer_data['image_meta_dict']['filename_or_obj']}")
+            logging.info(f"segmenting {infer_data['image'].meta['filename_or_obj']}")
            preds = inferer(infer_data[keys[0]].to(device), net)
            n = 1.0
            for _ in range(4):
@@ -262,7 +261,7 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
                n = n + 1.0
            preds = preds / n
            preds = (preds.argmax(dim=1, keepdims=True)).float()
-            saver.save_batch(preds, infer_data["image_meta_dict"])
+            saver.save_batch(preds, infer_data["image"].meta)

    # copy the saved segmentations into the required folder structure for submission
    submission_dir = os.path.join(prediction_folder, "to_submit")
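The run_net.py changes above are the heart of the metatensor migration: `CastToTyped` now takes `torch` dtypes, and metadata is read from the tensor's own `.meta` attribute instead of a separate `image_meta_dict` entry. A minimal sketch of the new pattern (not part of this PR; assumes MONAI >= 0.9 with nibabel installed):

```python
import os
import tempfile

import nibabel as nib
import numpy as np
import torch
from monai.transforms import CastToTyped, Compose, LoadImaged

with tempfile.TemporaryDirectory() as tempdir:
    # write a tiny synthetic NIfTI volume to load
    path = os.path.join(tempdir, "img.nii.gz")
    nib.save(nib.Nifti1Image(np.random.rand(32, 32, 32).astype(np.float32), np.eye(4)), path)

    xform = Compose([
        LoadImaged(keys="image"),
        CastToTyped(keys="image", dtype=torch.float32),  # torch dtype, as in the diff above
    ])
    data = xform({"image": path})

    img = data["image"]
    print(type(img).__name__)           # MetaTensor: a torch.Tensor subclass with metadata
    print(img.meta["filename_or_obj"])  # replaces data["image_meta_dict"]["filename_or_obj"]
```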
16 changes: 7 additions & 9 deletions 3d_segmentation/ignite/unet_evaluation_array.py
@@ -19,14 +19,13 @@
import numpy as np
import torch
from ignite.engine import Engine
-from torch.utils.data import DataLoader

from monai import config
-from monai.data import ImageDataset, create_test_image_3d, decollate_batch
+from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
from monai.handlers import CheckpointLoader, MeanDice, StatsHandler
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet
-from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, SaveImage, ScaleIntensity, EnsureType
+from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, SaveImage, ScaleIntensity


def main(tempdir):
@@ -47,8 +46,8 @@ def main(tempdir):
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
-    imtrans = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
-    segtrans = Compose([AddChannel(), EnsureType()])
+    imtrans = Compose([ScaleIntensity(), AddChannel()])
+    segtrans = Compose([AddChannel()])
    ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -65,7 +64,7 @@
    roi_size = (96, 96, 96)
    sw_batch_size = 4

-    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
+    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")

    def _sliding_window_processor(engine, batch):
@@ -74,9 +73,8 @@ def _sliding_window_processor(engine, batch):
        val_images, val_labels = batch[0].to(device), batch[1].to(device)
        seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
        seg_probs = [post_trans(i) for i in decollate_batch(seg_probs)]
-        val_data = decollate_batch(batch[2])
-        for seg_prob, data in zip(seg_probs, val_data):
-            save_image(seg_prob, data)
+        for seg_prob in seg_probs:
+            save_image(seg_prob)
        return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)
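Note on the deleted `save_image(seg_prob, data)` pattern: after `decollate_batch`, each `seg_prob` is a MetaTensor that already carries the filename and affine needed for writing, so `SaveImage` no longer takes a separate meta dict. A standalone sketch (not from this PR; the metadata here is hand-made for illustration):

```python
import torch
from monai.data import MetaTensor
from monai.transforms import SaveImage

# stand-in for one decollated network output; in the tutorial the metadata
# comes from the loaded input image rather than being set by hand
seg_prob = MetaTensor(
    torch.rand(1, 96, 96, 96),
    meta={"filename_or_obj": "im0.nii.gz"},  # affine defaults to identity
)

save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")
save_image(seg_prob)  # no second meta_data argument: the tensor supplies its own
```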
10 changes: 4 additions & 6 deletions 3d_segmentation/ignite/unet_evaluation_dict.py
@@ -26,7 +26,7 @@
from monai.handlers import CheckpointLoader, MeanDice, StatsHandler
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet
-from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd, EnsureTyped, EnsureType
+from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd


def main(tempdir):
@@ -53,7 +53,6 @@ def main(tempdir):
LoadImaged(keys=["img", "seg"]),
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
ScaleIntensityd(keys="img"),
EnsureTyped(keys=["img", "seg"]),
]
)
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
@@ -72,7 +71,7 @@ def main(tempdir):
    roi_size = (96, 96, 96)
    sw_batch_size = 4

-    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
+    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")

    def _sliding_window_processor(engine, batch):
@@ -81,9 +80,8 @@ def _sliding_window_processor(engine, batch):
val_images, val_labels = batch["img"].to(device), batch["seg"].to(device)
seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
seg_probs = [post_trans(i) for i in decollate_batch(seg_probs)]
val_data = decollate_batch(batch["img_meta_dict"])
for seg_prob, data in zip(seg_probs, val_data):
save_image(seg_prob, data)
for seg_prob in seg_probs:
save_image(seg_prob)
return seg_probs, val_labels

evaluator = Engine(_sliding_window_processor)
15 changes: 6 additions & 9 deletions 3d_segmentation/ignite/unet_training_array.py
@@ -20,10 +20,9 @@
import torch
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import EarlyStopping, ModelCheckpoint
-from torch.utils.data import DataLoader

import monai
-from monai.data import ImageDataset, create_test_image_3d, decollate_batch
+from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
from monai.handlers import (
    MeanDice,
    StatsHandler,
@@ -39,7 +38,6 @@
    RandSpatialCrop,
    Resize,
    ScaleIntensity,
-    EnsureType,
)


@@ -67,16 +65,15 @@ def main(tempdir):
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
-            EnsureType(),
        ]
    )
    train_segtrans = Compose(
-        [AddChannel(), RandSpatialCrop((96, 96, 96), random_size=False), EnsureType()]
+        [AddChannel(), RandSpatialCrop((96, 96, 96), random_size=False)]
    )
    val_imtrans = Compose(
-        [ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()]
+        [ScaleIntensity(), AddChannel(), Resize((96, 96, 96))]
    )
-    val_segtrans = Compose([AddChannel(), Resize((96, 96, 96)), EnsureType()])
+    val_segtrans = Compose([AddChannel(), Resize((96, 96, 96))])

    # define image dataset, data loader
    check_ds = ImageDataset(
@@ -151,8 +148,8 @@ def main(tempdir):
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: MeanDice()}

-    post_pred = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
-    post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])
+    post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
+    post_label = Compose([AsDiscrete(threshold=0.5)])

    # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
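The `EnsureType()` deletions in this file and the ones below are safe because the transforms themselves now produce tensor outputs: with recent MONAI defaults (>= 0.9.1, an assumption worth checking against your installed version), array transforms convert their inputs to MetaTensor, which is already a `torch.Tensor` subclass. A quick sketch to verify that assumption:

```python
import numpy as np
import torch
from monai.transforms import AddChannel, Compose, ScaleIntensity

# transforms promote numpy inputs to (Meta)Tensor on their own,
# which is why the trailing EnsureType() steps could be dropped
imtrans = Compose([ScaleIntensity(), AddChannel()])
out = imtrans(np.random.rand(96, 96, 96).astype(np.float32))

print(type(out).__name__)             # MetaTensor (with recent MONAI defaults)
print(isinstance(out, torch.Tensor))  # True: usable anywhere a Tensor is expected
```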
8 changes: 2 additions & 6 deletions 3d_segmentation/ignite/unet_training_dict.py
@@ -45,8 +45,6 @@
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
-    EnsureTyped,
-    EnsureType,
)


@@ -85,15 +83,13 @@ def main(tempdir):
                num_samples=4,
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
-            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
-            EnsureTyped(keys=["img", "seg"]),
        ]
    )

@@ -180,8 +176,8 @@ def prepare_batch(batch, device=None, non_blocking=False):
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: MeanDice()}

-    post_pred = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
-    post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])
+    post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
+    post_label = Compose([AsDiscrete(threshold=0.5)])

    # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
27 changes: 9 additions & 18 deletions 3d_segmentation/spleen_segmentation_3d.ipynb

Large diffs are not rendered by default.

27 changes: 12 additions & 15 deletions 3d_segmentation/spleen_segmentation_3d_lightning.ipynb

Large diffs are not rendered by default.

13 changes: 3 additions & 10 deletions 3d_segmentation/spleen_segmentation_3d_visualization_basic.ipynb

Large diffs are not rendered by default.

14 changes: 5 additions & 9 deletions 3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb

Large diffs are not rendered by default.

17 changes: 6 additions & 11 deletions 3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb

Large diffs are not rendered by default.

16 changes: 7 additions & 9 deletions 3d_segmentation/torch/unet_evaluation_array.py
@@ -18,14 +18,13 @@
import nibabel as nib
import numpy as np
import torch
-from torch.utils.data import DataLoader

from monai import config
-from monai.data import ImageDataset, create_test_image_3d, decollate_batch
+from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
-from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, SaveImage, ScaleIntensity, EnsureType
+from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, SaveImage, ScaleIntensity


def main(tempdir):
@@ -46,13 +45,13 @@ def main(tempdir):
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
-    imtrans = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
-    segtrans = Compose([AddChannel(), EnsureType()])
+    imtrans = Compose([ScaleIntensity(), AddChannel()])
+    segtrans = Compose([AddChannel()])
    val_ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
-    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
+    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
@@ -75,11 +74,10 @@ def main(tempdir):
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
-            meta_data = decollate_batch(val_data[2])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
-            for val_output, data in zip(val_outputs, meta_data):
-                saver(val_output, data)
+            for val_output in val_outputs:
+                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
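The removed `meta_data = decollate_batch(val_data[2])` bookkeeping is now covered by `decollate_batch` itself: splitting a batched MetaTensor yields per-item MetaTensors with their share of the batch metadata reattached. A sketch with hand-built metadata (an assumption for illustration; in the tutorial the batch comes from a MONAI DataLoader instead):

```python
import torch
from monai.data import MetaTensor, decollate_batch

# fake batch of two predictions, as if collated by the DataLoader
batch = MetaTensor(
    torch.rand(2, 1, 96, 96, 96),
    meta={"filename_or_obj": ["im0.nii.gz", "im1.nii.gz"]},
)

for item in decollate_batch(batch):  # -> list of two MetaTensors
    # each item keeps its own slice of the metadata, so a separately
    # decollated meta_data list is no longer needed
    print(item.shape, item.meta["filename_or_obj"])
```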
10 changes: 4 additions & 6 deletions 3d_segmentation/torch/unet_evaluation_dict.py
@@ -26,7 +26,7 @@
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
-from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd, EnsureTyped, EnsureType
+from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd


def main(tempdir):
@@ -53,14 +53,13 @@ def main(tempdir):
LoadImaged(keys=["img", "seg"]),
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
ScaleIntensityd(keys="img"),
EnsureTyped(keys=["img", "seg"]),
]
)
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
# sliding window inference need to input 1 image in every iteration
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")
# try to use all the available GPUs
devices = [torch.device("cuda" if torch.cuda.is_available() else "cpu")]
@@ -90,11 +89,10 @@ def main(tempdir):
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
-            meta_data = decollate_batch(val_data["img_meta_dict"])
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
-            for val_output, data in zip(val_outputs, meta_data):
-                saver(val_output, data)
+            for val_output in val_outputs:
+                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
12 changes: 1 addition & 11 deletions 3d_segmentation/torch/unet_inference_dict.py
@@ -34,7 +34,6 @@
    Resized,
    SaveImaged,
    ScaleIntensityd,
-    EnsureTyped,
)


@@ -58,34 +57,25 @@ def main(tempdir):
Orientationd(keys="img", axcodes="RAS"),
Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
ScaleIntensityd(keys="img"),
EnsureTyped(keys="img"),
])
# define dataset and dataloader
dataset = Dataset(data=files, transform=pre_transforms)
dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
# define post transforms
post_transforms = Compose([
EnsureTyped(keys="pred"),
Activationsd(keys="pred", sigmoid=True),
Invertd(
keys="pred", # invert the `pred` data field, also support multiple fields
transform=pre_transforms,
orig_keys="img", # get the previously applied pre_transforms information on the `img` data field,
# then invert `pred` based on this information. we can use same info
# for multiple fields, also support different orig_keys for different fields
meta_keys="pred_meta_dict", # key field to save inverted meta data, every item maps to `keys`
orig_meta_keys="img_meta_dict", # get the meta data from `img_meta_dict` field when inverting,
# for example, may need the `affine` to invert `Spacingd` transform,
# multiple fields can use the same meta data to invert
meta_key_postfix="meta_dict", # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key,
# if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}",
# otherwise, no need this arg during inverting
nearest_interp=False, # don't change the interpolation mode to "nearest" when inverting transforms
# to ensure a smooth output, then execute `AsDiscreted` transform
to_tensor=True, # convert to PyTorch Tensor after inverting
),
AsDiscreted(keys="pred", threshold=0.5),
SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False),
SaveImaged(keys="pred", output_dir="./out", output_postfix="seg", resample=False),
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
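The trimmed `Invertd` call works because a MetaTensor records each invertible transform applied to it in `.applied_operations`, and inversion reads that record rather than the old `*_meta_dict` entries. A self-contained sketch of the tracking (not from this PR; `Flipd` and `Resized` stand in for the tutorial's pre-transforms):

```python
import torch
from monai.data import MetaTensor
from monai.transforms import Compose, Flipd, Resized

# every invertible transform is recorded on the MetaTensor itself
xform = Compose([
    Flipd(keys="img", spatial_axis=0),
    Resized(keys="img", spatial_size=(48, 48, 48)),
])
data = xform({"img": MetaTensor(torch.rand(1, 96, 96, 96))})

print(data["img"].shape)                    # torch.Size([1, 48, 48, 48])
print(len(data["img"].applied_operations))  # 2 recorded operations
restored = xform.inverse(data)              # undo both transforms, newest first
print(restored["img"].shape)                # torch.Size([1, 96, 96, 96])
```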