Skip to content

Commit 9731823

Browse files
authored
Adding a median filter for 3D images (#5307)
Co-authored-by: @ebrahimebrahim. Closes #5264. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder.
1 parent 73378dd commit 9731823

File tree

11 files changed

+419
-12
lines changed

11 files changed

+419
-12
lines changed

docs/source/networks.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,16 @@ Layers
334334
.. autoclass:: GaussianFilter
335335
:members:
336336

337+
`MedianFilter`
338+
~~~~~~~~~~~~~~
339+
.. autoclass:: MedianFilter
340+
:members:
341+
342+
`median_filter`
343+
~~~~~~~~~~~~~~~
344+
.. autoclass:: median_filter
345+
:members:
346+
337347
`BilateralFilter`
338348
~~~~~~~~~~~~~~~~~
339349
.. autoclass:: BilateralFilter

docs/source/transforms.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,14 @@ Intensity
336336
:members:
337337
:special-members: __call__
338338

339+
`MedianSmooth`
340+
""""""""""""""
341+
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/MedianSmooth.png
342+
:alt: example of MedianSmooth
343+
.. autoclass:: MedianSmooth
344+
:members:
345+
:special-members: __call__
346+
339347
`GaussianSmooth`
340348
""""""""""""""""
341349
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSmooth.png
@@ -1415,6 +1423,14 @@ Intensity (Dict)
14151423
:members:
14161424
:special-members: __call__
14171425

1426+
`MedianSmoothd`
1427+
"""""""""""""""
1428+
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/MedianSmoothd.png
1429+
:alt: example of MedianSmoothd
1430+
.. autoclass:: MedianSmoothd
1431+
:members:
1432+
:special-members: __call__
1433+
14181434
`GaussianSmoothd`
14191435
"""""""""""""""""
14201436
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSmoothd.png

monai/networks/layers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
Flatten,
2121
GaussianFilter,
2222
HilbertTransform,
23+
MedianFilter,
2324
Reshape,
2425
SavitzkyGolayFilter,
2526
SkipConnection,
2627
apply_filter,
28+
median_filter,
2729
separable_filtering,
2830
)
2931
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push

monai/networks/layers/simplelayers.py

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import math
1313
from copy import deepcopy
14-
from typing import List, Sequence, Union
14+
from typing import List, Optional, Sequence, Union
1515

1616
import torch
1717
import torch.nn.functional as F
@@ -20,8 +20,16 @@
2020

2121
from monai.networks.layers.convutils import gaussian_1d
2222
from monai.networks.layers.factories import Conv
23-
from monai.utils import ChannelMatching, SkipMode, look_up_option, optional_import, pytorch_after
24-
from monai.utils.misc import issequenceiterable
23+
from monai.utils import (
24+
ChannelMatching,
25+
SkipMode,
26+
convert_to_tensor,
27+
ensure_tuple_rep,
28+
issequenceiterable,
29+
look_up_option,
30+
optional_import,
31+
pytorch_after,
32+
)
2533

2634
_C, _ = optional_import("monai._C")
2735
fft, _ = optional_import("torch.fft")
@@ -32,10 +40,12 @@
3240
"GaussianFilter",
3341
"HilbertTransform",
3442
"LLTM",
43+
"MedianFilter",
3544
"Reshape",
3645
"SavitzkyGolayFilter",
3746
"SkipConnection",
3847
"apply_filter",
48+
"median_filter",
3949
"separable_filtering",
4050
]
4151

@@ -168,7 +178,6 @@ def _separable_filtering_conv(
168178
paddings: List[int],
169179
num_channels: int,
170180
) -> torch.Tensor:
171-
172181
if d < 0:
173182
return input_
174183

@@ -434,6 +443,121 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
434443
return torch.as_tensor(ht, device=ht.device, dtype=ht.dtype)
435444

436445

446+
def get_binary_kernel(window_size: Sequence[int], dtype=torch.float, device=None) -> torch.Tensor:
447+
"""
448+
Create a binary kernel to extract the patches.
449+
The window size HxWxD will create a (H*W*D)xHxWxD kernel.
450+
"""
451+
win_size = convert_to_tensor(window_size, int, wrap_sequence=True)
452+
prod = torch.prod(win_size)
453+
s = [prod, 1, *win_size]
454+
return torch.diag(torch.ones(prod, dtype=dtype, device=device)).view(s) # type: ignore
455+
456+
457+
def median_filter(
458+
in_tensor: torch.Tensor,
459+
kernel_size: Sequence[int] = (3, 3, 3),
460+
spatial_dims: int = 3,
461+
kernel: Optional[torch.Tensor] = None,
462+
**kwargs,
463+
) -> torch.Tensor:
464+
"""
465+
Apply median filter to an image.
466+
467+
Args:
468+
in_tensor: input tensor; median filtering will be applied to the last `spatial_dims` dimensions.
469+
kernel_size: the convolution kernel size.
470+
spatial_dims: number of spatial dimensions to apply median filtering.
471+
kernel: an optional customized kernel.
472+
kwargs: additional parameters to the `conv`.
473+
474+
Returns:
475+
the filtered input tensor, shape remains the same as ``in_tensor``
476+
477+
Example::
478+
479+
>>> from monai.networks.layers import median_filter
480+
>>> import torch
481+
>>> x = torch.rand(4, 5, 7, 6)
482+
>>> output = median_filter(x, (3, 3, 3))
483+
>>> output.shape
484+
torch.Size([4, 5, 7, 6])
485+
486+
"""
487+
if not isinstance(in_tensor, torch.Tensor):
488+
raise TypeError(f"Input type is not a torch.Tensor. Got {type(in_tensor)}")
489+
490+
original_shape = in_tensor.shape
491+
oshape, sshape = original_shape[: len(original_shape) - spatial_dims], original_shape[-spatial_dims:]
492+
oprod = torch.prod(convert_to_tensor(oshape, int, wrap_sequence=True))
493+
# prepare kernel
494+
if kernel is None:
495+
kernel_size = ensure_tuple_rep(kernel_size, spatial_dims)
496+
kernel = get_binary_kernel(kernel_size, in_tensor.dtype, in_tensor.device)
497+
else:
498+
kernel = kernel.to(in_tensor)
499+
# map the local window to single vector
500+
conv = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] # type: ignore
501+
if "padding" not in kwargs:
502+
if pytorch_after(1, 10):
503+
kwargs["padding"] = "same"
504+
else:
505+
# even-sized kernels are not supported
506+
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
507+
elif kwargs["padding"] == "same" and not pytorch_after(1, 10):
508+
# even-sized kernels are not supported
509+
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
510+
features: torch.Tensor = conv(in_tensor.reshape(oprod, 1, *sshape), kernel, stride=1, **kwargs) # type: ignore
511+
features = features.view(oprod, -1, *sshape) # type: ignore
512+
513+
# compute the median along the feature axis
514+
median: torch.Tensor = torch.median(features, dim=1)[0]
515+
median = median.reshape(original_shape)
516+
517+
return median
518+
519+
520+
class MedianFilter(nn.Module):
521+
"""
522+
Apply median filter to an image.
523+
524+
Args:
525+
radius: the blurring kernel radius (radius of 1 corresponds to 3x3x3 kernel when spatial_dims=3).
526+
527+
Returns:
528+
filtered input tensor.
529+
530+
Example::
531+
532+
>>> from monai.networks.layers import MedianFilter
533+
>>> import torch
534+
>>> in_tensor = torch.rand(4, 5, 7, 6)
535+
>>> blur = MedianFilter([1, 1, 1]) # 3x3x3 kernel
536+
>>> output = blur(in_tensor)
537+
>>> output.shape
538+
torch.Size([4, 5, 7, 6])
539+
540+
"""
541+
542+
def __init__(self, radius: Union[Sequence[int], int], spatial_dims: int = 3, device="cpu") -> None:
543+
super().__init__()
544+
self.spatial_dims = spatial_dims
545+
self.radius: Sequence[int] = ensure_tuple_rep(radius, spatial_dims)
546+
self.window: Sequence[int] = [1 + 2 * deepcopy(r) for r in self.radius]
547+
self.kernel = get_binary_kernel(self.window, device=device)
548+
549+
def forward(self, in_tensor: torch.Tensor, number_of_passes=1) -> torch.Tensor:
550+
"""
551+
Args:
552+
in_tensor: input tensor, median filtering will be applied to the last `spatial_dims` dimensions.
553+
number_of_passes: median filtering will be repeated this many times
554+
"""
555+
x = in_tensor
556+
for _ in range(number_of_passes):
557+
x = median_filter(x, kernel=self.kernel, spatial_dims=self.spatial_dims)
558+
return x
559+
560+
437561
class GaussianFilter(nn.Module):
438562
def __init__(
439563
self,

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
IntensityRemap,
100100
KSpaceSpikeNoise,
101101
MaskIntensity,
102+
MedianSmooth,
102103
NormalizeIntensity,
103104
RandAdjustContrast,
104105
RandBiasField,
@@ -152,6 +153,9 @@
152153
MaskIntensityd,
153154
MaskIntensityD,
154155
MaskIntensityDict,
156+
MedianSmoothd,
157+
MedianSmoothD,
158+
MedianSmoothDict,
155159
NormalizeIntensityd,
156160
NormalizeIntensityD,
157161
NormalizeIntensityDict,

monai/transforms/intensity/array.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
2727
from monai.data.meta_obj import get_track_meta
2828
from monai.data.utils import get_random_patch, get_valid_patch_size
29-
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
29+
from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter
3030
from monai.transforms.transform import RandomizableTransform, Transform
3131
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
3232
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
@@ -56,6 +56,7 @@
5656
"MaskIntensity",
5757
"DetectEnvelope",
5858
"SavitzkyGolaySmooth",
59+
"MedianSmooth",
5960
"GaussianSmooth",
6061
"RandGaussianSmooth",
6162
"GaussianSharpen",
@@ -1136,6 +1137,35 @@ def __call__(self, img: NdarrayOrTensor):
11361137
return out
11371138

11381139

1140+
class MedianSmooth(Transform):
1141+
"""
1142+
Apply median filter to the input data based on specified `radius` parameter.
1143+
A default value `radius=1` is provided for reference.
1144+
1145+
See also: :py:func:`monai.networks.layers.median_filter`
1146+
1147+
Args:
1148+
radius: if a list of values, must match the count of spatial dimensions of input data,
1149+
and apply every value in the list to 1 spatial dimension. if only 1 value provided,
1150+
use it for all spatial dimensions.
1151+
"""
1152+
1153+
backend = [TransformBackends.TORCH]
1154+
1155+
def __init__(self, radius: Union[Sequence[int], int] = 1) -> None:
1156+
self.radius = radius
1157+
1158+
def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
1159+
img = convert_to_tensor(img, track_meta=get_track_meta())
1160+
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
1161+
spatial_dims = img_t.ndim - 1
1162+
r = ensure_tuple_rep(self.radius, spatial_dims)
1163+
median_filter_instance = MedianFilter(r, spatial_dims=spatial_dims)
1164+
out_t: torch.Tensor = median_filter_instance(img_t)
1165+
out, *_ = convert_to_dst_type(out_t, dst=img, dtype=out_t.dtype)
1166+
return out
1167+
1168+
11391169
class GaussianSmooth(Transform):
11401170
"""
11411171
Apply Gaussian smooth to the input data based on specified `sigma` parameter.

monai/transforms/intensity/dictionary.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
HistogramNormalize,
3333
KSpaceSpikeNoise,
3434
MaskIntensity,
35+
MedianSmooth,
3536
NormalizeIntensity,
3637
RandAdjustContrast,
3738
RandBiasField,
@@ -78,6 +79,7 @@
7879
"ScaleIntensityRangePercentilesd",
7980
"MaskIntensityd",
8081
"SavitzkyGolaySmoothd",
82+
"MedianSmoothd",
8183
"GaussianSmoothd",
8284
"RandGaussianSmoothd",
8385
"GaussianSharpend",
@@ -124,6 +126,8 @@
124126
"MaskIntensityDict",
125127
"SavitzkyGolaySmoothD",
126128
"SavitzkyGolaySmoothDict",
129+
"MedianSmoothD",
130+
"MedianSmoothDict",
127131
"GaussianSmoothD",
128132
"GaussianSmoothDict",
129133
"RandGaussianSmoothD",
@@ -988,6 +992,35 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
988992
return d
989993

990994

995+
class MedianSmoothd(MapTransform):
996+
"""
997+
Dictionary-based wrapper of :py:class:`monai.transforms.MedianSmooth`.
998+
999+
Args:
1000+
keys: keys of the corresponding items to be transformed.
1001+
See also: :py:class:`monai.transforms.compose.MapTransform`
1002+
radius: if a list of values, must match the count of spatial dimensions of input data,
1003+
and apply every value in the list to 1 spatial dimension. if only 1 value provided,
1004+
use it for all spatial dimensions.
1005+
allow_missing_keys: don't raise exception if key is missing.
1006+
1007+
"""
1008+
1009+
backend = MedianSmooth.backend
1010+
1011+
def __init__(
1012+
self, keys: KeysCollection, radius: Union[Sequence[int], int], allow_missing_keys: bool = False
1013+
) -> None:
1014+
super().__init__(keys, allow_missing_keys)
1015+
self.converter = MedianSmooth(radius)
1016+
1017+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
1018+
d = dict(data)
1019+
for key in self.key_iterator(d):
1020+
d[key] = self.converter(d[key])
1021+
return d
1022+
1023+
9911024
class GaussianSmoothd(MapTransform):
9921025
"""
9931026
Dictionary-based wrapper of :py:class:`monai.transforms.GaussianSmooth`.
@@ -1780,6 +1813,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
17801813
ScaleIntensityRangePercentilesD = ScaleIntensityRangePercentilesDict = ScaleIntensityRangePercentilesd
17811814
MaskIntensityD = MaskIntensityDict = MaskIntensityd
17821815
SavitzkyGolaySmoothD = SavitzkyGolaySmoothDict = SavitzkyGolaySmoothd
1816+
MedianSmoothD = MedianSmoothDict = MedianSmoothd
17831817
GaussianSmoothD = GaussianSmoothDict = GaussianSmoothd
17841818
RandGaussianSmoothD = RandGaussianSmoothDict = RandGaussianSmoothd
17851819
GaussianSharpenD = GaussianSharpenDict = GaussianSharpend

0 commit comments

Comments
 (0)