Skip to content

Commit 73378dd

Browse files
authored
updates to Gaussian map for sliding window inference (#5302)
This updates the calculation of Gaussian map (weights) during sliding window inference with "gaussian" Current version had multiple small issues - it computed Gaussian weight map (image) via nn.Conv1d sequence with an empty image (with a single 1 in the middle). We don't need to run any convolutions, it's much simpler to directly calculate the Gaussian map (it's also faster and takes less memory) - For patch_sizes of even size (e.g. 128x128) it centered Gaussian on patch_size//2 which is 0.5 pixel off-center (I'm not sure why we did it. - Finally the Gaussian 1d convolutions were done approximately (with 'erf' internal approximation and truncated to sigma=4). I'm not sure why we need any approximations here at all, it's trivial to compute the Gaussian weight map directly ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: myron <[email protected]>
1 parent bb81a23 commit 73378dd

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

monai/data/utils.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from monai import config
3131
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
3232
from monai.data.meta_obj import MetaObj
33-
from monai.networks.layers.simplelayers import GaussianFilter
3433
from monai.utils import (
3534
MAX_SEED,
3635
BlendMode,
@@ -1067,17 +1066,16 @@ def compute_importance_map(
10671066
if mode == BlendMode.CONSTANT:
10681067
importance_map = torch.ones(patch_size, device=device, dtype=torch.float)
10691068
elif mode == BlendMode.GAUSSIAN:
1070-
center_coords = [i // 2 for i in patch_size]
1069+
10711070
sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size))
10721071
sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)]
10731072

1074-
importance_map = torch.zeros(patch_size, device=device)
1075-
importance_map[tuple(center_coords)] = 1
1076-
pt_gaussian = GaussianFilter(len(patch_size), sigmas).to(device=device, dtype=torch.float)
1077-
importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0))
1078-
importance_map = importance_map.squeeze(0).squeeze(0)
1079-
importance_map = importance_map / torch.max(importance_map)
1080-
importance_map = importance_map.float()
1073+
for i in range(len(patch_size)):
1074+
x = torch.arange(
1075+
start=-(patch_size[i] - 1) / 2.0, end=(patch_size[i] - 1) / 2.0 + 1, dtype=torch.float, device=device
1076+
)
1077+
x = torch.exp(x**2 / (-2 * sigmas[i] ** 2)) # 1D gaussian
1078+
importance_map = importance_map.unsqueeze(-1) * x[(None,) * i] if i > 0 else x
10811079
else:
10821080
raise ValueError(
10831081
f"Unsupported mode: {mode}, available options are [{BlendMode.CONSTANT}, {BlendMode.CONSTANT}]."

0 commit comments

Comments
 (0)