Skip to content

Commit 8956d19

Browse files
authored
Merge branch 'main' into 782-acceleration-metatensor
2 parents 92091f7 + 77cb6eb commit 8956d19

32 files changed

+355
-1414
lines changed

3d_segmentation/brats_segmentation_3d.ipynb

Lines changed: 12 additions & 19 deletions
Large diffs are not rendered by default.

3d_segmentation/challenge_baseline/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ This directory contains a simple baseline method [using MONAI](https://monai.io)
2525

2626
The script is tested with:
2727

28-
- `Ubuntu 18.04` | `Python 3.6` | `CUDA 10.2`
28+
- `Ubuntu 20.04` | `Python 3.8` | `CUDA 11.7`
2929

3030
On a GPU with [automatic mixed precision support](https://developer.nvidia.com/automatic-mixed-precision):
3131

3d_segmentation/challenge_baseline/run_net.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
ScaleIntensityRanged,
3737
Spacingd,
3838
SpatialPadd,
39-
EnsureTyped,
4039
)
4140

4241

@@ -69,12 +68,12 @@ def get_xforms(mode="train", keys=("image", "label")):
6968
RandFlipd(keys, spatial_axis=2, prob=0.5),
7069
]
7170
)
72-
dtype = (np.float32, np.uint8)
71+
dtype = (torch.float32, torch.uint8)
7372
if mode == "val":
74-
dtype = (np.float32, np.uint8)
73+
dtype = (torch.float32, torch.uint8)
7574
if mode == "infer":
76-
dtype = (np.float32,)
77-
xforms.extend([CastToTyped(keys, dtype=dtype), EnsureTyped(keys)])
75+
dtype = (torch.float32,)
76+
xforms.extend([CastToTyped(keys, dtype=dtype)])
7877
return monai.transforms.Compose(xforms)
7978

8079

@@ -172,7 +171,7 @@ def train(data_folder=".", model_folder="runs"):
172171

173172
# create evaluator (to be used to measure model quality during training)
174173
val_post_transform = monai.transforms.Compose(
175-
[EnsureTyped(keys=("pred", "label")), AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=2)]
174+
[AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=2)]
176175
)
177176
val_handlers = [
178177
ProgressBar(),
@@ -246,7 +245,7 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
246245
saver = monai.data.NiftiSaver(output_dir=prediction_folder, mode="nearest")
247246
with torch.no_grad():
248247
for infer_data in infer_loader:
249-
logging.info(f"segmenting {infer_data['image_meta_dict']['filename_or_obj']}")
248+
logging.info(f"segmenting {infer_data['image'].meta['filename_or_obj']}")
250249
preds = inferer(infer_data[keys[0]].to(device), net)
251250
n = 1.0
252251
for _ in range(4):
@@ -262,7 +261,7 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
262261
n = n + 1.0
263262
preds = preds / n
264263
preds = (preds.argmax(dim=1, keepdims=True)).float()
265-
saver.save_batch(preds, infer_data["image_meta_dict"])
264+
saver.save_batch(preds, infer_data["image"].meta)
266265

267266
# copy the saved segmentations into the required folder structure for submission
268267
submission_dir = os.path.join(prediction_folder, "to_submit")

3d_segmentation/ignite/unet_evaluation_array.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
import numpy as np
2020
import torch
2121
from ignite.engine import Engine
22-
from torch.utils.data import DataLoader
2322

2423
from monai import config
25-
from monai.data import ImageDataset, create_test_image_3d, decollate_batch
24+
from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
2625
from monai.handlers import CheckpointLoader, MeanDice, StatsHandler
2726
from monai.inferers import sliding_window_inference
2827
from monai.networks.nets import UNet
29-
from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, SaveImage, ScaleIntensity, EnsureType
28+
from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, SaveImage, ScaleIntensity
3029

3130

3231
def main(tempdir):
@@ -47,8 +46,8 @@ def main(tempdir):
4746
segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
4847

4948
# define transforms for image and segmentation
50-
imtrans = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
51-
segtrans = Compose([AddChannel(), EnsureType()])
49+
imtrans = Compose([ScaleIntensity(), AddChannel()])
50+
segtrans = Compose([AddChannel()])
5251
ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
5352

5453
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -65,7 +64,7 @@ def main(tempdir):
6564
roi_size = (96, 96, 96)
6665
sw_batch_size = 4
6766

68-
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
67+
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
6968
save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")
7069

7170
def _sliding_window_processor(engine, batch):
@@ -74,9 +73,8 @@ def _sliding_window_processor(engine, batch):
7473
val_images, val_labels = batch[0].to(device), batch[1].to(device)
7574
seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
7675
seg_probs = [post_trans(i) for i in decollate_batch(seg_probs)]
77-
val_data = decollate_batch(batch[2])
78-
for seg_prob, data in zip(seg_probs, val_data):
79-
save_image(seg_prob, data)
76+
for seg_prob in seg_probs:
77+
save_image(seg_prob)
8078
return seg_probs, val_labels
8179

8280
evaluator = Engine(_sliding_window_processor)

3d_segmentation/ignite/unet_evaluation_dict.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from monai.handlers import CheckpointLoader, MeanDice, StatsHandler
2727
from monai.inferers import sliding_window_inference
2828
from monai.networks.nets import UNet
29-
from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd, EnsureTyped, EnsureType
29+
from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd
3030

3131

3232
def main(tempdir):
@@ -53,7 +53,6 @@ def main(tempdir):
5353
LoadImaged(keys=["img", "seg"]),
5454
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
5555
ScaleIntensityd(keys="img"),
56-
EnsureTyped(keys=["img", "seg"]),
5756
]
5857
)
5958
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
@@ -72,7 +71,7 @@ def main(tempdir):
7271
roi_size = (96, 96, 96)
7372
sw_batch_size = 4
7473

75-
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
74+
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
7675
save_image = SaveImage(output_dir="tempdir", output_ext=".nii.gz", output_postfix="seg")
7776

7877
def _sliding_window_processor(engine, batch):
@@ -81,9 +80,8 @@ def _sliding_window_processor(engine, batch):
8180
val_images, val_labels = batch["img"].to(device), batch["seg"].to(device)
8281
seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
8382
seg_probs = [post_trans(i) for i in decollate_batch(seg_probs)]
84-
val_data = decollate_batch(batch["img_meta_dict"])
85-
for seg_prob, data in zip(seg_probs, val_data):
86-
save_image(seg_prob, data)
83+
for seg_prob in seg_probs:
84+
save_image(seg_prob)
8785
return seg_probs, val_labels
8886

8987
evaluator = Engine(_sliding_window_processor)

3d_segmentation/ignite/unet_training_array.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@
2020
import torch
2121
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
2222
from ignite.handlers import EarlyStopping, ModelCheckpoint
23-
from torch.utils.data import DataLoader
2423

2524
import monai
26-
from monai.data import ImageDataset, create_test_image_3d, decollate_batch
25+
from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
2726
from monai.handlers import (
2827
MeanDice,
2928
StatsHandler,
@@ -39,7 +38,6 @@
3938
RandSpatialCrop,
4039
Resize,
4140
ScaleIntensity,
42-
EnsureType,
4341
)
4442

4543

@@ -67,16 +65,15 @@ def main(tempdir):
6765
ScaleIntensity(),
6866
AddChannel(),
6967
RandSpatialCrop((96, 96, 96), random_size=False),
70-
EnsureType(),
7168
]
7269
)
7370
train_segtrans = Compose(
74-
[AddChannel(), RandSpatialCrop((96, 96, 96), random_size=False), EnsureType()]
71+
[AddChannel(), RandSpatialCrop((96, 96, 96), random_size=False)]
7572
)
7673
val_imtrans = Compose(
77-
[ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()]
74+
[ScaleIntensity(), AddChannel(), Resize((96, 96, 96))]
7875
)
79-
val_segtrans = Compose([AddChannel(), Resize((96, 96, 96)), EnsureType()])
76+
val_segtrans = Compose([AddChannel(), Resize((96, 96, 96))])
8077

8178
# define image dataset, data loader
8279
check_ds = ImageDataset(
@@ -151,8 +148,8 @@ def main(tempdir):
151148
# add evaluation metric to the evaluator engine
152149
val_metrics = {metric_name: MeanDice()}
153150

154-
post_pred = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
155-
post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])
151+
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
152+
post_label = Compose([AsDiscrete(threshold=0.5)])
156153

157154
# Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
158155
# user can add output_transform to return other values

3d_segmentation/ignite/unet_training_dict.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@
4545
RandCropByPosNegLabeld,
4646
RandRotate90d,
4747
ScaleIntensityd,
48-
EnsureTyped,
49-
EnsureType,
5048
)
5149

5250

@@ -85,15 +83,13 @@ def main(tempdir):
8583
num_samples=4,
8684
),
8785
RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
88-
EnsureTyped(keys=["img", "seg"]),
8986
]
9087
)
9188
val_transforms = Compose(
9289
[
9390
LoadImaged(keys=["img", "seg"]),
9491
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
9592
ScaleIntensityd(keys="img"),
96-
EnsureTyped(keys=["img", "seg"]),
9793
]
9894
)
9995

@@ -180,8 +176,8 @@ def prepare_batch(batch, device=None, non_blocking=False):
180176
# add evaluation metric to the evaluator engine
181177
val_metrics = {metric_name: MeanDice()}
182178

183-
post_pred = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
184-
post_label = Compose([EnsureType(), AsDiscrete(threshold=0.5)])
179+
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
180+
post_label = Compose([AsDiscrete(threshold=0.5)])
185181

186182
# Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
187183
# user can add output_transform to return other values

3d_segmentation/spleen_segmentation_3d.ipynb

Lines changed: 9 additions & 18 deletions
Large diffs are not rendered by default.

3d_segmentation/spleen_segmentation_3d_lightning.ipynb

Lines changed: 12 additions & 15 deletions
Large diffs are not rendered by default.

3d_segmentation/spleen_segmentation_3d_visualization_basic.ipynb

Lines changed: 3 additions & 10 deletions
Large diffs are not rendered by default.

3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb

Lines changed: 5 additions & 9 deletions
Large diffs are not rendered by default.

3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb

Lines changed: 6 additions & 11 deletions
Large diffs are not rendered by default.

3d_segmentation/torch/unet_evaluation_array.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@
1818
import nibabel as nib
1919
import numpy as np
2020
import torch
21-
from torch.utils.data import DataLoader
2221

2322
from monai import config
24-
from monai.data import ImageDataset, create_test_image_3d, decollate_batch
23+
from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
2524
from monai.inferers import sliding_window_inference
2625
from monai.metrics import DiceMetric
2726
from monai.networks.nets import UNet
28-
from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, SaveImage, ScaleIntensity, EnsureType
27+
from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, SaveImage, ScaleIntensity
2928

3029

3130
def main(tempdir):
@@ -46,13 +45,13 @@ def main(tempdir):
4645
segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
4746

4847
# define transforms for image and segmentation
49-
imtrans = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
50-
segtrans = Compose([AddChannel(), EnsureType()])
48+
imtrans = Compose([ScaleIntensity(), AddChannel()])
49+
segtrans = Compose([AddChannel()])
5150
val_ds = ImageDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
5251
# sliding window inference for one image at every iteration
5352
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
5453
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
55-
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
54+
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
5655
saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")
5756
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5857
model = UNet(
@@ -75,11 +74,10 @@ def main(tempdir):
7574
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
7675
val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
7776
val_labels = decollate_batch(val_labels)
78-
meta_data = decollate_batch(val_data[2])
7977
# compute metric for current iteration
8078
dice_metric(y_pred=val_outputs, y=val_labels)
81-
for val_output, data in zip(val_outputs, meta_data):
82-
saver(val_output, data)
79+
for val_output in val_outputs:
80+
saver(val_output)
8381
# aggregate the final mean dice result
8482
print("evaluation metric:", dice_metric.aggregate().item())
8583
# reset the status

3d_segmentation/torch/unet_evaluation_dict.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from monai.inferers import sliding_window_inference
2727
from monai.metrics import DiceMetric
2828
from monai.networks.nets import UNet
29-
from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd, EnsureTyped, EnsureType
29+
from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, SaveImage, ScaleIntensityd
3030

3131

3232
def main(tempdir):
@@ -53,14 +53,13 @@ def main(tempdir):
5353
LoadImaged(keys=["img", "seg"]),
5454
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
5555
ScaleIntensityd(keys="img"),
56-
EnsureTyped(keys=["img", "seg"]),
5756
]
5857
)
5958
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
6059
# sliding window inference needs to input 1 image in every iteration
6160
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
6261
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
63-
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
62+
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
6463
saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")
6564
# try to use all the available GPUs
6665
devices = [torch.device("cuda" if torch.cuda.is_available() else "cpu")]
@@ -90,11 +89,10 @@ def main(tempdir):
9089
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
9190
val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
9291
val_labels = decollate_batch(val_labels)
93-
meta_data = decollate_batch(val_data["img_meta_dict"])
9492
# compute metric for current iteration
9593
dice_metric(y_pred=val_outputs, y=val_labels)
96-
for val_output, data in zip(val_outputs, meta_data):
97-
saver(val_output, data)
94+
for val_output in val_outputs:
95+
saver(val_output)
9896
# aggregate the final mean dice result
9997
print("evaluation metric:", dice_metric.aggregate().item())
10098
# reset the status

3d_segmentation/torch/unet_inference_dict.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
Resized,
3535
SaveImaged,
3636
ScaleIntensityd,
37-
EnsureTyped,
3837
)
3938

4039

@@ -58,34 +57,25 @@ def main(tempdir):
5857
Orientationd(keys="img", axcodes="RAS"),
5958
Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
6059
ScaleIntensityd(keys="img"),
61-
EnsureTyped(keys="img"),
6260
])
6361
# define dataset and dataloader
6462
dataset = Dataset(data=files, transform=pre_transforms)
6563
dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
6664
# define post transforms
6765
post_transforms = Compose([
68-
EnsureTyped(keys="pred"),
6966
Activationsd(keys="pred", sigmoid=True),
7067
Invertd(
7168
keys="pred", # invert the `pred` data field, also support multiple fields
7269
transform=pre_transforms,
7370
orig_keys="img", # get the previously applied pre_transforms information on the `img` data field,
7471
# then invert `pred` based on this information. we can use same info
7572
# for multiple fields, also support different orig_keys for different fields
76-
meta_keys="pred_meta_dict", # key field to save inverted meta data, every item maps to `keys`
77-
orig_meta_keys="img_meta_dict", # get the meta data from `img_meta_dict` field when inverting,
78-
# for example, may need the `affine` to invert `Spacingd` transform,
79-
# multiple fields can use the same meta data to invert
80-
meta_key_postfix="meta_dict", # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key,
81-
# if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}",
82-
# otherwise, no need this arg during inverting
8373
nearest_interp=False, # don't change the interpolation mode to "nearest" when inverting transforms
8474
# to ensure a smooth output, then execute `AsDiscreted` transform
8575
to_tensor=True, # convert to PyTorch Tensor after inverting
8676
),
8777
AsDiscreted(keys="pred", threshold=0.5),
88-
SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False),
78+
SaveImaged(keys="pred", output_dir="./out", output_postfix="seg", resample=False),
8979
])
9080

9181
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

0 commit comments

Comments
 (0)