from __future__ import annotations

+ import threading
from abc import ABC, abstractmethod
from collections.abc import Sequence
- from typing import Any
+ from contextlib import nullcontext
+ from typing import TYPE_CHECKING, Any

+ import numpy as np
import torch

- from monai.utils import ensure_tuple_size
+ from monai.utils import ensure_tuple_size, optional_import, require_pkg

- __all__ = ["Merger", "AvgMerger"]
+ if TYPE_CHECKING:
+     import zarr
+ else:
+     zarr, _ = optional_import("zarr")
+
+
+ __all__ = ["Merger", "AvgMerger", "ZarrAvgMerger"]

class Merger(ABC):
@@ -97,9 +106,9 @@ def __init__(
        self,
        merged_shape: Sequence[int],
        cropped_shape: Sequence[int] | None = None,
-         device: torch.device | str = "cpu",
        value_dtype: torch.dtype = torch.float32,
        count_dtype: torch.dtype = torch.uint8,
+         device: torch.device | str = "cpu",
    ) -> None:
        super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape, device=device)
        if not self.merged_shape:
@@ -152,12 +161,21 @@ def finalize(self) -> torch.Tensor:
        return self.values

+     def get_output(self) -> torch.Tensor:
+         """
+         Get the final merged output.
+
+         Returns:
+             torch.Tensor: merged output.
+         """
+         return self.finalize()
+
    def get_values(self) -> torch.Tensor:
        """
        Get the accumulated values during aggregation or final averaged values after it is finalized.

        Returns:
-             Merged (averaged) output tensor.
+             torch.Tensor: aggregated values.

        Notes:
            - If called before calling `finalize()`, this method returns the accumulating values.
@@ -170,6 +188,195 @@ def get_counts(self) -> torch.Tensor:
        Get the aggregator tensor for number of samples.

        Returns:
-             torch.Tensor: Number of accumulated samples at each location.
+             torch.Tensor: number of accumulated samples at each location.
        """
        return self.counts
+
+
+ @require_pkg(pkg_name="zarr")
+ class ZarrAvgMerger(Merger):
+     """Merge patches by taking the average of the overlapping area and store the results in a zarr array.
+
+     Zarr is a format for the storage of chunked, compressed, N-dimensional arrays.
+     Zarr data can be stored in any storage system that can be represented as a key-value store,
+     like POSIX file systems, cloud object storage, zip files, and relational and document databases.
+     See https://zarr.readthedocs.io/en/stable/ for more details.
+     It is particularly useful for storing N-dimensional arrays too large to fit into memory.
+     One specific use case of this class is to merge patches extracted from whole slide images (WSI),
+     where the merged results do not fit into memory and need to be stored on a file system.
+
+     Args:
+         merged_shape: the shape of the tensor required to merge the patches.
+         cropped_shape: the shape of the final merged output tensor.
+             If not provided, it will be the same as `merged_shape`.
+         dtype: the dtype for the final merged result. Default is `float32`.
+         value_dtype: the dtype for the value aggregating tensor. Default is `float32`.
+         count_dtype: the dtype for the sample counting tensor. Default is `uint8`.
+         store: the zarr store to save the final results. Default is "merged.zarr".
+         value_store: the zarr store to save the value aggregating tensor. Default is a temporary store.
+         count_store: the zarr store to save the sample counting tensor. Default is a temporary store.
+         compressor: the compressor for the final merged zarr array. Default is "default".
+         value_compressor: the compressor for the value aggregating zarr array. Default is None.
+         count_compressor: the compressor for the sample counting zarr array. Default is None.
+         chunks: int or tuple of ints that defines the chunk shape, or boolean. Default is True.
+             If True, chunk shape will be guessed from `shape` and `dtype`.
+             If False, it will be set to `shape`, i.e., single chunk for the whole array.
+             If an int, the chunk size in each dimension will be given by the value of `chunks`.
+         thread_locking: if True (default), use a lock to protect the in-place addition during aggregation,
+             so that patches can be aggregated safely from multiple threads.
+     """
+
+     def __init__(
+         self,
+         merged_shape: Sequence[int],
+         cropped_shape: Sequence[int] | None = None,
+         dtype: np.dtype | str = "float32",
+         value_dtype: np.dtype | str = "float32",
+         count_dtype: np.dtype | str = "uint8",
+         store: zarr.storage.Store | str = "merged.zarr",
+         value_store: zarr.storage.Store | str | None = None,
+         count_store: zarr.storage.Store | str | None = None,
+         compressor: str = "default",
+         value_compressor: str | None = None,
+         count_compressor: str | None = None,
+         chunks: Sequence[int] | bool = True,
+         thread_locking: bool = True,
+     ) -> None:
+         super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape)
+         if not self.merged_shape:
+             raise ValueError(f"`merged_shape` must be provided for `ZarrAvgMerger`. {self.merged_shape} is given.")
+         self.output_dtype = dtype
+         self.value_dtype = value_dtype
+         self.count_dtype = count_dtype
+         self.store = store
+         self.value_store = zarr.storage.TempStore() if value_store is None else value_store
+         self.count_store = zarr.storage.TempStore() if count_store is None else count_store
+         self.chunks = chunks
+         self.compressor = compressor
+         self.value_compressor = value_compressor
+         self.count_compressor = count_compressor
+         self.output = zarr.empty(
+             shape=self.merged_shape,
+             chunks=self.chunks,
+             dtype=self.output_dtype,
+             compressor=self.compressor,
+             store=self.store,
+             overwrite=True,
+         )
+         self.values = zarr.zeros(
+             shape=self.merged_shape,
+             chunks=self.chunks,
+             dtype=self.value_dtype,
+             compressor=self.value_compressor,
+             store=self.value_store,
+             overwrite=True,
+         )
+         self.counts = zarr.zeros(
+             shape=self.merged_shape,
+             chunks=self.chunks,
+             dtype=self.count_dtype,
+             compressor=self.count_compressor,
+             store=self.count_store,
+             overwrite=True,
+         )
+         self.lock: threading.Lock | nullcontext
+         if thread_locking:
+             # use a lock to protect the in-place addition during aggregation
+             self.lock = threading.Lock()
+         else:
+             # use nullcontext to avoid locking when it is not needed
+             self.lock = nullcontext()
+
+     def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:
+         """
+         Aggregate values for merging.
+
+         Args:
+             values: a tensor of shape BCHW[D], representing the values of inference output.
+             location: a tuple/list giving the top left location of the patch in the original image.
+         """
+         if self.is_finalized:
+             raise ValueError("`ZarrAvgMerger` is already finalized. Please instantiate a new object to aggregate.")
+         patch_size = values.shape[2:]
+         map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size))
+         map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
+         with self.lock:
+             self.values[map_slice] += values.numpy()
+             self.counts[map_slice] += 1
+
+     def finalize(self) -> zarr.Array:
+         """
+         Finalize merging by dividing values by counts and return the merged zarr array.
+
+         Notes:
+             To avoid loading the whole arrays into memory, the averaging is performed chunk by chunk
+             and written into the `output` zarr array; `get_values()` keeps returning the accumulated
+             (non-averaged) values, and the averaged result is available via `get_output()`.
+             Calling `finalize()` multiple times does not have any effect.
+
+         Returns:
+             zarr.Array: a zarr array of merged patches.
+         """
+         # guard against multiple calls to finalize
+         if not self.is_finalized:
+             # use chunks for division to fit into memory
+             for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):
+                 self.output[chunk] = self.values[chunk] / self.counts[chunk]
+             # finalize the shape
+             self.output.resize(self.cropped_shape)
+             # set the finalized flag to avoid performing the division again
+             self.is_finalized = True
+
+         return self.output
+
+     def get_output(self) -> zarr.Array:
+         """
+         Get the final merged output.
+
+         Returns:
+             zarr.Array: merged (averaged) output array.
+         """
+         return self.output
+
+     def get_values(self) -> zarr.Array:
+         """
+         Get the accumulated values during aggregation.
+
+         Returns:
+             zarr.Array: aggregated values.
+         """
+         return self.values
+
+     def get_counts(self) -> zarr.Array:
+         """
+         Get the aggregator array for the number of samples.
+
+         Returns:
+             zarr.Array: number of accumulated samples at each location.
+         """
+         return self.counts
+
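For reference, a minimal usage sketch of the new merger follows (an editor's illustration, not part of this diff). It assumes `zarr` is installed; the import path is a guess since the file path is not shown here, and the shapes, patch locations, and store name are arbitrary.

import torch

from monai.inferers.merger import ZarrAvgMerger  # hypothetical import path

# a 1x3x32x32 output backed by a zarr store on disk
merger = ZarrAvgMerger(merged_shape=(1, 3, 32, 32), store="merged.zarr")

# four 16x16 patches of ones tile the output exactly ...
for loc in [(0, 0), (0, 16), (16, 0), (16, 16)]:
    merger.aggregate(torch.ones(1, 3, 16, 16), location=loc)
# ... and a fifth patch of threes overlaps the center region
merger.aggregate(torch.full((1, 3, 16, 16), 3.0), location=(8, 8))

output = merger.finalize()   # zarr.Array written to "merged.zarr"
print(output.shape)          # (1, 3, 32, 32)
print(output[0, 0, 0, 0])    # 1.0: covered by a single patch
print(output[0, 0, 16, 16])  # 2.0: overlapping area averaged, (1 + 3) / 2

Values and counts accumulate in the (temporary) value and count stores, and `finalize()` writes the averaged result chunk by chunk into the "merged.zarr" store.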
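The slice arithmetic inside `aggregate` can also be seen in isolation (again an editor's sketch, using the same `ensure_tuple_size` call as the diff): the spatial slices built from `location` and the patch size are padded from the start with `slice(None)` so the batch and channel dimensions are taken whole.

import torch

from monai.utils import ensure_tuple_size

values = torch.ones(1, 3, 16, 16)  # a BCHW patch
location = (8, 24)                 # spatial top-left corner of the patch

patch_size = values.shape[2:]      # torch.Size([16, 16])
map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size))
map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
print(map_slice)
# (slice(None, None, None), slice(None, None, None), slice(8, 24, None), slice(24, 40, None))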
+
+ def iterate_over_chunks(chunks, cdata_shape, slice_tuple=()):
+     """
+     Iterate over chunks of a given shape.
+
+     Args:
+         chunks: the chunk shape
+         cdata_shape: the shape of the data in chunks
+         slice_tuple: the slice tuple to be used for indexing
+
+     Raises:
+         ValueError: when the lengths of `chunks` and `cdata_shape` are not the same.
+
+     Yields:
+         slices of the data
+     """
+     if len(chunks) != len(cdata_shape):
+         raise ValueError("chunks and cdata_shape must have the same length")
+     if len(chunks) == 1:
+         for i in range(cdata_shape[0]):
+             yield slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),)
+     else:
+         for i in range(cdata_shape[0]):
+             yield from iterate_over_chunks(
+                 chunks[1:], cdata_shape[1:], slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),)
+             )
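To see how the chunk bookkeeping fits together (an editor's sketch using the helper defined above): a (4, 6) zarr array stored in (2, 3) chunks has a chunk grid (`cdata_shape`) of (2, 2), and the generator yields one slice tuple per chunk. This is how `finalize()` walks the value and count arrays without loading them into memory whole.

import zarr

z = zarr.zeros(shape=(4, 6), chunks=(2, 3))
print(z.chunks)       # (2, 3)
print(z.cdata_shape)  # (2, 2): number of chunks along each dimension

for s in iterate_over_chunks(z.chunks, z.cdata_shape):
    print(s)
# (slice(0, 2, None), slice(0, 3, None))
# (slice(0, 2, None), slice(3, 6, None))
# (slice(2, 4, None), slice(0, 3, None))
# (slice(2, 4, None), slice(3, 6, None))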