Skip to content

Commit ae95bf9

Browse files
authored
Zarr Merger (#6633)
Fixes #6006 ### Description This PR implements `ZarrAvgMerger` which can be used for patch inference. Also a use case is demonstrated [here](Project-MONAI/tutorials#1433). ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Behrooz <[email protected]>
1 parent d4b9552 commit ae95bf9

File tree

10 files changed

+549
-9
lines changed

10 files changed

+549
-9
lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ optuna
3737
opencv-python-headless
3838
onnx>=1.13.0
3939
onnxruntime; python_version <= '3.10'
40+
zarr

docs/source/inferers.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ Mergers
7777
:members:
7878
:special-members: __call__
7979

80+
`ZarrAvgMerger`
81+
~~~~~~~~~~~~~~~
82+
.. autoclass:: ZarrAvgMerger
83+
:members:
84+
:special-members: __call__
8085

8186

8287
Sliding Window Inference Function

docs/source/installation.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
254254
- The options are
255255

256256
```
257-
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime]
257+
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
258258
```
259259

260260
which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
261-
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, respectively.
261+
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively.
262262

263263
- `pip install 'monai[all]'` installs all the optional dependencies.

monai/inferers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@
2020
SlidingWindowInferer,
2121
SlidingWindowInfererAdapt,
2222
)
23-
from .merger import AvgMerger, Merger
23+
from .merger import AvgMerger, Merger, ZarrAvgMerger
2424
from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter
2525
from .utils import sliding_window_inference

monai/inferers/merger.py

Lines changed: 213 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,24 @@
1111

1212
from __future__ import annotations
1313

14+
import threading
1415
from abc import ABC, abstractmethod
1516
from collections.abc import Sequence
16-
from typing import Any
17+
from contextlib import nullcontext
18+
from typing import TYPE_CHECKING, Any
1719

20+
import numpy as np
1821
import torch
1922

20-
from monai.utils import ensure_tuple_size
23+
from monai.utils import ensure_tuple_size, optional_import, require_pkg
2124

22-
__all__ = ["Merger", "AvgMerger"]
25+
if TYPE_CHECKING:
26+
import zarr
27+
else:
28+
zarr, _ = optional_import("zarr")
29+
30+
31+
__all__ = ["Merger", "AvgMerger", "ZarrAvgMerger"]
2332

2433

2534
class Merger(ABC):
@@ -97,9 +106,9 @@ def __init__(
97106
self,
98107
merged_shape: Sequence[int],
99108
cropped_shape: Sequence[int] | None = None,
100-
device: torch.device | str = "cpu",
101109
value_dtype: torch.dtype = torch.float32,
102110
count_dtype: torch.dtype = torch.uint8,
111+
device: torch.device | str = "cpu",
103112
) -> None:
104113
super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape, device=device)
105114
if not self.merged_shape:
@@ -152,12 +161,21 @@ def finalize(self) -> torch.Tensor:
152161

153162
return self.values
154163

164+
def get_output(self) -> torch.Tensor:
165+
"""
166+
Get the final merged output.
167+
168+
Returns:
169+
torch.Tensor: merged output.
170+
"""
171+
return self.finalize()
172+
155173
def get_values(self) -> torch.Tensor:
156174
"""
157175
Get the accumulated values during aggregation or final averaged values after it is finalized.
158176
159177
Returns:
160-
Merged (averaged) output tensor.
178+
torch.tensor: aggregated values.
161179
162180
Notes:
163181
- If called before calling `finalize()`, this method returns the accumulating values.
@@ -170,6 +188,195 @@ def get_counts(self) -> torch.Tensor:
170188
Get the aggregator tensor for number of samples.
171189
172190
Returns:
173-
torch.Tensor: Number of accumulated samples at each location.
191+
torch.Tensor: number of accumulated samples at each location.
174192
"""
175193
return self.counts
194+
195+
196+
@require_pkg(pkg_name="zarr")
197+
class ZarrAvgMerger(Merger):
198+
"""Merge patches by taking average of the overlapping area and store the results in zarr array.
199+
200+
Zarr is a format for the storage of chunked, compressed, N-dimensional arrays.
201+
Zarr data can be stored in any storage system that can be represented as a key-value store,
202+
like POSIX file systems, cloud object storage, zip files, and relational and document databases.
203+
See https://zarr.readthedocs.io/en/stable/ for more details.
204+
It is particularly useful for storing N-dimensional arrays too large to fit into memory.
205+
One specific use case of this class is to merge patches extracted from whole slide images (WSI),
206+
where the merged results do not fit into memory and need to be stored on a file system.
207+
208+
Args:
209+
merged_shape: the shape of the tensor required to merge the patches.
210+
cropped_shape: the shape of the final merged output tensor.
211+
If not provided, it will be the same as `merged_shape`.
212+
dtype: the dtype for the final merged result. Default is `float32`.
213+
value_dtype: the dtype for value aggregating tensor and the final result. Default is `float32`.
214+
count_dtype: the dtype for sample counting tensor. Default is `uint8`.
215+
store: the zarr store to save the final results. Default is "merged.zarr".
216+
value_store: the zarr store to save the value aggregating tensor. Default is a temporary store.
217+
count_store: the zarr store to save the sample counting tensor. Default is a temporary store.
218+
compressor: the compressor for final merged zarr array. Default is "default".
219+
value_compressor: the compressor for value aggregating zarr array. Default is None.
220+
count_compressor: the compressor for sample counting zarr array. Default is None.
221+
chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True.
222+
If True, chunk shape will be guessed from `shape` and `dtype`.
223+
If False, it will be set to `shape`, i.e., single chunk for the whole array.
224+
If an int, the chunk size in each dimension will be given by the value of `chunks`.
225+
"""
226+
227+
def __init__(
228+
self,
229+
merged_shape: Sequence[int],
230+
cropped_shape: Sequence[int] | None = None,
231+
dtype: np.dtype | str = "float32",
232+
value_dtype: np.dtype | str = "float32",
233+
count_dtype: np.dtype | str = "uint8",
234+
store: zarr.storage.Store | str = "merged.zarr",
235+
value_store: zarr.storage.Store | str | None = None,
236+
count_store: zarr.storage.Store | str | None = None,
237+
compressor: str = "default",
238+
value_compressor: str | None = None,
239+
count_compressor: str | None = None,
240+
chunks: Sequence[int] | bool = True,
241+
thread_locking: bool = True,
242+
) -> None:
243+
super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape)
244+
if not self.merged_shape:
245+
raise ValueError(f"`merged_shape` must be provided for `ZarrAvgMerger`. {self.merged_shape} is give.")
246+
self.output_dtype = dtype
247+
self.value_dtype = value_dtype
248+
self.count_dtype = count_dtype
249+
self.store = store
250+
self.value_store = zarr.storage.TempStore() if value_store is None else value_store
251+
self.count_store = zarr.storage.TempStore() if count_store is None else count_store
252+
self.chunks = chunks
253+
self.compressor = compressor
254+
self.value_compressor = value_compressor
255+
self.count_compressor = count_compressor
256+
self.output = zarr.empty(
257+
shape=self.merged_shape,
258+
chunks=self.chunks,
259+
dtype=self.output_dtype,
260+
compressor=self.compressor,
261+
store=self.store,
262+
overwrite=True,
263+
)
264+
self.values = zarr.zeros(
265+
shape=self.merged_shape,
266+
chunks=self.chunks,
267+
dtype=self.value_dtype,
268+
compressor=self.value_compressor,
269+
store=self.value_store,
270+
overwrite=True,
271+
)
272+
self.counts = zarr.zeros(
273+
shape=self.merged_shape,
274+
chunks=self.chunks,
275+
dtype=self.count_dtype,
276+
compressor=self.count_compressor,
277+
store=self.count_store,
278+
overwrite=True,
279+
)
280+
self.lock: threading.Lock | nullcontext
281+
if thread_locking:
282+
# use lock to protect the in-place addition during aggregation
283+
self.lock = threading.Lock()
284+
else:
285+
# use nullcontext to avoid the locking if not needed
286+
self.lock = nullcontext()
287+
288+
def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:
289+
"""
290+
Aggregate values for merging.
291+
292+
Args:
293+
values: a tensor of shape BCHW[D], representing the values of inference output.
294+
location: a tuple/list giving the top left location of the patch in the original image.
295+
"""
296+
if self.is_finalized:
297+
raise ValueError("`ZarrAvgMerger` is already finalized. Please instantiate a new object to aggregate.")
298+
patch_size = values.shape[2:]
299+
map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size))
300+
map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
301+
with self.lock:
302+
self.values[map_slice] += values.numpy()
303+
self.counts[map_slice] += 1
304+
305+
def finalize(self) -> zarr.Array:
306+
"""
307+
Finalize merging by dividing values by counts and return the merged tensor.
308+
309+
Notes:
310+
To avoid creating a new tensor for the final results (to save memory space),
311+
after this method is called, `get_values()` method will return the "final" averaged values,
312+
and not the accumulating values. Also calling `finalize()` multiple times does not have any effect.
313+
314+
Returns:
315+
zarr.Array: a zarr array of of merged patches
316+
"""
317+
# guard against multiple calls to finalize
318+
if not self.is_finalized:
319+
# use chunks for division to fit into memory
320+
for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):
321+
self.output[chunk] = self.values[chunk] / self.counts[chunk]
322+
# finalize the shape
323+
self.output.resize(self.cropped_shape)
324+
# set finalize flag to protect performing in-place division again
325+
self.is_finalized = True
326+
327+
return self.output
328+
329+
def get_output(self) -> zarr.Array:
330+
"""
331+
Get the final merged output.
332+
333+
Returns:
334+
zarr.Array: Merged (averaged) output tensor.
335+
"""
336+
return self.output
337+
338+
def get_values(self) -> zarr.Array:
339+
"""
340+
Get the accumulated values during aggregation
341+
342+
Returns:
343+
zarr.Array: aggregated values.
344+
345+
"""
346+
return self.values
347+
348+
def get_counts(self) -> zarr.Array:
349+
"""
350+
Get the aggregator tensor for number of samples.
351+
352+
Returns:
353+
zarr.Array: Number of accumulated samples at each location.
354+
"""
355+
return self.counts
356+
357+
358+
def iterate_over_chunks(chunks, cdata_shape, slice_tuple=()):
359+
"""
360+
Iterate over chunks of a given shape.
361+
362+
Args:
363+
chunks: the chunk shape
364+
cdata_shape: the shape of the data in chunks
365+
slice_tuple: the slice tuple to be used for indexing
366+
367+
Raises:
368+
ValueError: When the length of chunks and cdata_shape are not the same.
369+
370+
Yields:
371+
slices of the data
372+
"""
373+
if len(chunks) != len(cdata_shape):
374+
raise ValueError("chunks and cdata_shape must have the same length")
375+
if len(chunks) == 1:
376+
for i in range(cdata_shape[0]):
377+
yield slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),)
378+
else:
379+
for i in range(cdata_shape[0]):
380+
yield from iterate_over_chunks(
381+
chunks[1:], cdata_shape[1:], slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),)
382+
)

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,4 @@ onnx>=1.13.0
5252
onnxruntime; python_version <= '3.10'
5353
typeguard<3 # https://github.com/microsoft/nni/issues/5457
5454
filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
55+
zarr

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ all =
7979
optuna
8080
onnx>=1.13.0
8181
onnxruntime; python_version <= '3.10'
82+
zarr
8283
nibabel =
8384
nibabel
8485
ninja =
@@ -142,6 +143,8 @@ optuna =
142143
onnx =
143144
onnx>=1.13.0
144145
onnxruntime; python_version <= '3.10'
146+
zarr =
147+
zarr
145148
# # workaround https://github.com/Project-MONAI/MONAI/issues/5882
146149
# MetricsReloaded =
147150
# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def run_testsuit():
202202
"test_metrics_reloaded",
203203
"test_spatial_combine_transforms",
204204
"test_bundle_workflow",
205+
"test_zarr_avg_merger",
205206
]
206207
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
207208

tests/test_download_url_yandex.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030

3131
class TestDownloadUrlYandex(unittest.TestCase):
32+
@unittest.skip("data source unstable")
3233
def test_verify(self):
3334
with tempfile.TemporaryDirectory() as tempdir:
3435
download_url(url=YANDEX_MODEL_URL, filepath=os.path.join(tempdir, "model.pt"))

0 commit comments

Comments
 (0)