|
30 | 30 | from monai import config
|
31 | 31 | from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
|
32 | 32 | from monai.data.meta_obj import MetaObj
|
33 |
| -from monai.networks.layers.simplelayers import GaussianFilter |
34 | 33 | from monai.utils import (
|
35 | 34 | MAX_SEED,
|
36 | 35 | BlendMode,
|
@@ -1067,17 +1066,16 @@ def compute_importance_map(
|
1067 | 1066 | if mode == BlendMode.CONSTANT:
|
1068 | 1067 | importance_map = torch.ones(patch_size, device=device, dtype=torch.float)
|
1069 | 1068 | elif mode == BlendMode.GAUSSIAN:
|
1070 |
| - center_coords = [i // 2 for i in patch_size] |
| 1069 | + |
1071 | 1070 | sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size))
|
1072 | 1071 | sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)]
|
1073 | 1072 |
|
1074 |
| - importance_map = torch.zeros(patch_size, device=device) |
1075 |
| - importance_map[tuple(center_coords)] = 1 |
1076 |
| - pt_gaussian = GaussianFilter(len(patch_size), sigmas).to(device=device, dtype=torch.float) |
1077 |
| - importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0)) |
1078 |
| - importance_map = importance_map.squeeze(0).squeeze(0) |
1079 |
| - importance_map = importance_map / torch.max(importance_map) |
1080 |
| - importance_map = importance_map.float() |
| 1073 | + for i in range(len(patch_size)): |
| 1074 | + x = torch.arange( |
| 1075 | + start=-(patch_size[i] - 1) / 2.0, end=(patch_size[i] - 1) / 2.0 + 1, dtype=torch.float, device=device |
| 1076 | + ) |
| 1077 | + x = torch.exp(x**2 / (-2 * sigmas[i] ** 2)) # 1D gaussian |
| 1078 | + importance_map = importance_map.unsqueeze(-1) * x[(None,) * i] if i > 0 else x |
1081 | 1079 | else:
|
1082 | 1080 | raise ValueError(
|
1083 | 1081 | f"Unsupported mode: {mode}, available options are [{BlendMode.CONSTANT}, {BlendMode.CONSTANT}]."
|
|
0 commit comments