
Commit e4b99e1

Implement SlidingWindowHoVerNetInferer (#5531)
Fixes #5521

### Description

This PR implements `SlidingWindowHoVerNetInferer` for use with the HoVerNet model, where the output size differs from the input size but cannot simply be rescaled: the window outputs must instead be padded back to the window size and the stitched result cropped.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.

Signed-off-by: Behrooz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 624e832 commit e4b99e1
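For review context, a minimal usage sketch (not part of the commit): `toy_hovernet` is a hypothetical stand-in for a HoVerNet-style network, the 270-to-80 window/output sizes follow HoVerNet's original mode but are otherwise illustrative, and `overlap` is chosen so the 80-pixel valid output bands tile the image.

```python
import torch

from monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer


def toy_hovernet(x: torch.Tensor) -> torch.Tensor:
    # stand-in for HoVerNet: like a valid-convolution network, the output is
    # 95 pixels smaller than the input on every border (270x270 in, 80x80 out)
    return x[..., 95:-95, 95:-95].mean(dim=1, keepdim=True)


inferer = SlidingWindowHoVerNetInferer(
    roi_size=(270, 270),  # window size fed to the network
    sw_batch_size=4,
    overlap=1 - 80 / 270,  # step by roughly the 80-pixel output size so valid bands tile
    padding_mode="constant",
    extra_input_padding=(95, 95, 95, 95),  # pre-pad the image so border pixels get predictions
)

image = torch.rand(1, 3, 480, 480)  # BxCxHxW patch
with torch.no_grad():
    pred = inferer(inputs=image, network=toy_hovernet)  # stitched and cropped back to 1x1x480x480
```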

7 files changed: +533 −7 lines

docs/source/apps.rst

Lines changed: 4 additions & 0 deletions
```diff
@@ -92,6 +92,10 @@ Applications
 .. autoclass:: ProbMapProducer
     :members:

+.. automodule:: monai.apps.pathology.inferers
+.. autoclass:: SlidingWindowHoVerNetInferer
+    :members:
+
 .. automodule:: monai.apps.pathology.losses.hovernet_loss
 .. autoclass:: HoVerNetLoss
     :members:
```
monai/apps/pathology/inferers/__init__.py

Lines changed: 12 additions & 0 deletions
```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .inferer import SlidingWindowHoVerNetInferer
```
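With this package `__init__`, the class is re-exported at the subpackage level, so downstream code can import it directly:

```python
from monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer
```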
monai/apps/pathology/inferers/inferer.py

Lines changed: 205 additions & 0 deletions
```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from monai.inferers import SlidingWindowInferer
from monai.inferers.utils import sliding_window_inference
from monai.utils import BlendMode, PytorchPadMode, look_up_option

__all__ = ["SlidingWindowHoVerNetInferer"]


class SlidingWindowHoVerNetInferer(SlidingWindowInferer):
    """
    Sliding window method for HoVerNet model inference,
    with `sw_batch_size` windows for every model.forward().
    Usage example can be found in the :py:class:`monai.inferers.Inferer` base class.

    Args:
        roi_size: the window size to execute SlidingWindow evaluation.
            If it has non-positive components, the corresponding components of the
            `inputs` size will be used. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of the image is `64`.
        sw_batch_size: the batch size to run window slices.
        overlap: amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode when ``roi_size`` is larger than inputs. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for ``"constant"`` padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a tqdm progress bar.
        cache_roi_weight_map: whether to pre-compute the ROI weight map.
        cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory)
            when input image volume is larger than this threshold (in pixels/voxels).
            Otherwise use ``"device"``. Thus, the output may end up on either cpu or gpu.
        extra_input_padding: the amount of padding for the input image, which is a tuple of even number of pads.
            Refer to the `pad` argument of `torch.nn.functional.pad` for more details.

    Note:
        ``sw_batch_size`` denotes the max number of windows per network inference iteration,
        not the batch size of inputs.

    """

    def __init__(
        self,
        roi_size: Union[Sequence[int], int],
        sw_batch_size: int = 1,
        overlap: float = 0.25,
        mode: Union[BlendMode, str] = BlendMode.CONSTANT,
        sigma_scale: Union[Sequence[float], float] = 0.125,
        padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
        cval: float = 0.0,
        sw_device: Optional[Union[torch.device, str]] = None,
        device: Optional[Union[torch.device, str]] = None,
        progress: bool = False,
        cache_roi_weight_map: bool = False,
        cpu_thresh: Optional[int] = None,
        extra_input_padding: Optional[Tuple[int]] = None,
    ) -> None:
        super().__init__(
            roi_size=roi_size,
            sw_batch_size=sw_batch_size,
            overlap=overlap,
            mode=mode,
            sigma_scale=sigma_scale,
            padding_mode=padding_mode,
            cval=cval,
            sw_device=sw_device,
            device=device,
            progress=progress,
            cache_roi_weight_map=cache_roi_weight_map,
            cpu_thresh=cpu_thresh,
        )
        self.extra_input_padding = extra_input_padding

    def process_output(self, seg_prob_tuple, window_data, importance_map_):
        window_shape = window_data.shape[2:]
        seg_shape = seg_prob_tuple[0].shape[2:]

        window_pad_size = []
        window_pad_slices = []
        for window_s, output_s in zip(window_shape, seg_shape):
            pad_width = max(window_s - output_s, 0)
            pad_half_1 = pad_width // 2
            pad_half_2 = pad_width - pad_half_1
            window_pad_size.extend([pad_half_1, pad_half_2])
            window_pad_slices.append(slice(pad_half_1, window_s - pad_half_2))

        # Make the padding area of the importance map zero
        importance_map = torch.zeros(window_shape, dtype=importance_map_.dtype, device=importance_map_.device)
        importance_map[window_pad_slices] = importance_map_[window_pad_slices]

        seg_prob_tuple = tuple(
            F.pad(seg_prob, pad=tuple(window_pad_size), mode=self.padding_mode, value=self.cval)
            for seg_prob in seg_prob_tuple
        )

        return seg_prob_tuple, importance_map

    def __call__(
        self,
        inputs: torch.Tensor,
        network: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
        *args: Any,
        **kwargs: Any,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
        """
        Args:
            inputs: model input data for inference.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.

        """
        device = self.device
        if device is None and self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh:
            device = "cpu"  # stitch in cpu memory if image is too large

        if self.extra_input_padding:
            image_size_original = inputs.shape[2:]
            num_spatial_dims = len(image_size_original)
            inputs = F.pad(
                inputs,
                pad=tuple(self.extra_input_padding),
                mode=look_up_option(self.padding_mode, PytorchPadMode),
                value=self.cval,
            )

        results = sliding_window_inference(
            inputs,
            self.roi_size,
            self.sw_batch_size,
            network,
            self.overlap,
            self.mode,
            self.sigma_scale,
            self.padding_mode,
            self.cval,
            self.sw_device,
            device,
            self.progress,
            self.roi_weight_map,
            self.process_output,
            *args,
            **kwargs,
        )

        if self.extra_input_padding:
            extra_slicing: List[slice] = []
            num_padded_dims = len(self.extra_input_padding) // 2
            for sp in range(num_padded_dims):
                slice_dim = slice(
                    self.extra_input_padding[sp * 2],
                    image_size_original[num_spatial_dims - sp - 1] + self.extra_input_padding[sp * 2],
                )
                extra_slicing.insert(0, slice_dim)
            for _ in range(len(inputs.shape) - num_padded_dims):
                extra_slicing.insert(0, slice(None))

            if isinstance(results, dict):
                for k, v in results.items():
                    results[k] = v[extra_slicing]
            elif isinstance(results, (list, tuple)):
                results = type(results)([res[extra_slicing] for res in results])
            elif isinstance(results, (torch.Tensor, np.ndarray)):
                results = results[extra_slicing]
            else:
                raise ValueError(
                    f"The output [{type(results)}] should be either dict, list, tuple, torch.Tensor, or numpy array."
                )

        return results
```
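The pad-then-crop bookkeeping in `process_output` is easiest to see on a toy shape. A minimal sketch (the sizes are illustrative, not part of the commit): for a 270-pixel window producing an 80-pixel output, the output is padded by 95 on each side, and the importance map is zeroed outside the central 80-pixel band so that only real predictions contribute to stitching:

```python
import torch
import torch.nn.functional as F

window_s, output_s = 270, 80             # window fed to the net vs. what it returns
pad_width = max(window_s - output_s, 0)  # 190
pad_half_1 = pad_width // 2              # 95
pad_half_2 = pad_width - pad_half_1      # 95

seg_prob = torch.rand(1, 2, output_s)               # BxCx80 toy prediction for one window
padded = F.pad(seg_prob, (pad_half_1, pad_half_2))  # back to BxCx270, matching the window

importance = torch.ones(window_s)                   # per-window blending weights
masked = torch.zeros_like(importance)
masked[pad_half_1 : window_s - pad_half_2] = importance[pad_half_1 : window_s - pad_half_2]

assert padded.shape[-1] == window_s
assert masked.sum() == output_s  # only the 80 valid output pixels carry weight
```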

monai/inferers/inferer.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -213,6 +213,7 @@ def __call__(
             device,
             self.progress,
             self.roi_weight_map,
+            None,
             *args,
             **kwargs,
         )
```

The added `None` fills the new `process_fn` positional parameter of `sliding_window_inference` (see `monai/inferers/utils.py` below), so the base `SlidingWindowInferer` keeps its existing behavior.

monai/inferers/utils.py

Lines changed: 16 additions & 7 deletions
```diff
@@ -10,7 +10,7 @@
 # limitations under the License.

 import warnings
-from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -47,7 +47,8 @@ def sliding_window_inference(
     sw_device: Union[torch.device, str, None] = None,
     device: Union[torch.device, str, None] = None,
     progress: bool = False,
-    roi_weight_map: Union[torch.Tensor, None] = None,
+    roi_weight_map: Optional[torch.Tensor] = None,
+    process_fn: Optional[Callable] = None,
     *args: Any,
     **kwargs: Any,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
@@ -108,6 +109,7 @@ def sliding_window_inference(
         progress: whether to print a `tqdm` progress bar.
         roi_weight_map: pre-computed (non-negative) weight map for each ROI.
             If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
+        process_fn: process inference output and adjust the importance map per window
         args: optional args to be passed to ``predictor``.
         kwargs: optional keyword args to be passed to ``predictor``.
@@ -149,19 +151,21 @@ def sliding_window_inference(
     # Create window-level importance map
     valid_patch_size = get_valid_patch_size(image_size, roi_size)
     if valid_patch_size == roi_size and (roi_weight_map is not None):
-        importance_map = roi_weight_map
+        importance_map_ = roi_weight_map
     else:
         try:
-            importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
+            importance_map_ = compute_importance_map(
+                valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device
+            )
         except BaseException as e:
             raise RuntimeError(
                 "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
             ) from e
-    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]
+    importance_map_ = convert_data_type(importance_map_, torch.Tensor, device, compute_dtype)[0]

     # handle non-positive weights
-    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
-    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
+    min_non_zero = max(importance_map_[importance_map_ != 0].min().item(), 1e-3)
+    importance_map_ = torch.clamp(importance_map_.to(torch.float32), min=min_non_zero).to(compute_dtype)

     # Perform predictions
     dict_key, output_image_list, count_map_list = None, [], []
@@ -193,6 +197,11 @@ def sliding_window_inference(
         seg_prob_tuple = ensure_tuple(seg_prob_out)
         is_tensor_output = False

+        if process_fn:
+            seg_prob_tuple, importance_map = process_fn(seg_prob_tuple, window_data, importance_map_)
+        else:
+            importance_map = importance_map_
+
         # for each output in multi-output list
         for ss, seg_prob in enumerate(seg_prob_tuple):
             seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN
```
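The new `process_fn` hook is what `SlidingWindowHoVerNetInferer.process_output` plugs into, and any callable with the same contract can be passed: it receives the per-window predictions, the window input, and the importance map, and returns the (possibly modified) predictions and importance map. A minimal sketch of a custom hook (`my_process_fn` and its softmax post-processing are hypothetical, not part of the commit):

```python
from typing import Sequence, Tuple

import torch


def my_process_fn(
    seg_prob_tuple: Sequence[torch.Tensor],
    window_data: torch.Tensor,
    importance_map: torch.Tensor,
) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
    # e.g. apply softmax to each per-window output before it is blended into the
    # stitched map; the importance map is returned unchanged here but could be reweighted
    seg_prob_tuple = tuple(torch.softmax(p, dim=1) for p in seg_prob_tuple)
    return seg_prob_tuple, importance_map


# passed positionally after roi_weight_map, mirroring the diff above:
# sliding_window_inference(inputs, roi_size, sw_batch_size, network, ...,
#                          roi_weight_map, my_process_fn)
```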
