
Commit dae09ff

Authored by myronwyli and Wenqi Li
SlidingWindowInferer: option to adaptively stitch in cpu memory for large images (#5297)
This adds an option to provide a maximum input image volume (number of elements) above which stitching dynamically switches to cpu memory (to avoid gpu out-of-memory crashes). For example, with `cpu_thresh=400*400*400`, any input image with a larger volume will be stitched on cpu. At the moment, a user must decide beforehand whether to stitch ALL images on cpu or gpu (by specifying the `device` parameter). But in many datasets only a few large images require `device="cpu"`, and running inference on cpu for ALL of them would be unnecessarily slow.

Related to #4625 #4495 #3497 #4726 #4588

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: myron <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
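The device-selection rule described above can be sketched as a small standalone function. This is an illustration only: `choose_stitch_device` is a hypothetical helper, not part of MONAI; the real change embeds the same condition directly in `SlidingWindowInferer.__call__`.

```python
from math import prod
from typing import Optional, Tuple

def choose_stitch_device(
    spatial_shape: Tuple[int, ...],
    device: Optional[str],
    cpu_thresh: Optional[int],
) -> Optional[str]:
    """Hypothetical sketch of the PR's logic: fall back to cpu stitching
    only when no explicit device was given and the spatial volume of the
    input exceeds the threshold."""
    if device is None and cpu_thresh is not None and prod(spatial_shape) > cpu_thresh:
        return "cpu"  # stitch in cpu memory if the image is too large
    return device

# A small image keeps the default (device=None) path:
print(choose_stitch_device((128, 128, 128), None, 400 * 400 * 400))      # → None
# A large image is redirected to cpu stitching:
print(choose_stitch_device((512, 512, 512), None, 400 * 400 * 400))      # → cpu
# An explicitly requested device always wins, regardless of size:
print(choose_stitch_device((512, 512, 512), "cuda:0", 400 * 400 * 400))  # → cuda:0
```

This matches the commit's intent: only the few oversized images in a dataset pay the cpu-stitching cost, while everything else stays on gpu.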
1 parent 02a6a6d commit dae09ff

File tree

1 file changed: +12 −2 lines

monai/inferers/inferer.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -122,6 +122,9 @@ class SlidingWindowInferer(Inferer):
             `inputs` and `roi_size`. Output is on the `device`.
         progress: whether to print a tqdm progress bar.
         cache_roi_weight_map: whether to precompute the ROI weight map.
+        cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory)
+            when input image volume is larger than this threshold (in pixels/voxels).
+            Otherwise use ``"device"``. Thus, the output may end up on either cpu or gpu.
 
     Note:
         ``sw_batch_size`` denotes the max number of windows per network inference iteration,
@@ -142,8 +145,9 @@ def __init__(
         device: Union[torch.device, str, None] = None,
         progress: bool = False,
         cache_roi_weight_map: bool = False,
+        cpu_thresh: Optional[int] = None,
     ) -> None:
-        Inferer.__init__(self)
+        super().__init__()
         self.roi_size = roi_size
         self.sw_batch_size = sw_batch_size
         self.overlap = overlap
@@ -154,6 +158,7 @@ def __init__(
         self.sw_device = sw_device
         self.device = device
         self.progress = progress
+        self.cpu_thresh = cpu_thresh
 
         # compute_importance_map takes long time when computing on cpu. We thus
         # compute it once if it's static and then save it for future usage
@@ -189,6 +194,11 @@ def __call__(
             kwargs: optional keyword args to be passed to ``network``.
 
         """
+
+        device = self.device
+        if device is None and self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh:
+            device = "cpu"  # stitch in cpu memory if image is too large
+
         return sliding_window_inference(
             inputs,
             self.roi_size,
@@ -200,7 +210,7 @@ def __call__(
             self.padding_mode,
             self.cval,
             self.sw_device,
-            self.device,
+            device,
             self.progress,
             self.roi_weight_map,
             *args,
```
0 commit comments
