Skip to content

Commit bb81a23

Browse files
[ReadyForReview] Auto3DSeg DataAnalyzer OOM and other minor issues (#5278)
Fixes #5277. ### Updated results In my local test environment, I have the following results: - The change in GPU memory usage before/after running DataAnalyzer is less than 5 MB after the fix. Previously, many cached PyTorch tensors and CuPy variables were not released before training, and they took up to several GB of GPU memory. - DataAnalyzer can also process larger images now because the leaks are fixed (a 3D image of size 512x512x512 passed on a 12 GB RTX 3080 Ti). ### Description Auto3DSeg DataAnalyzer occupied a large chunk of GPU memory and was unable to release it during training. The likely reasons are: - Training is done via a subprocess call, and PyTorch in the subprocess is unable to find the memory pool allocated by the main process. - GPU memory leakage (DataAnalyzer math operations use torch functions and CuPy). In addition, the test functions need improvements, and AutoRunner needs to expose an API call to change the device of DataAnalyzer. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Mingxin Zheng <[email protected]>
1 parent b7403ee commit bb81a23

File tree

8 files changed

+174
-112
lines changed

8 files changed

+174
-112
lines changed

.github/workflows/pythonapp-gpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- "PT17+CUDA102"
2424
- "PT18+CUDA102"
2525
- "PT18+CUDA112"
26-
- "PT112+CUDA117"
26+
- "PT112+CUDA118"
2727
- "PT110+CUDA102"
2828
- "PT112+CUDA102"
2929
include:

monai/apps/auto3dseg/data_analyzer.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def __init__(
122122
output_path: str = "./data_stats.yaml",
123123
average: bool = True,
124124
do_ccp: bool = True,
125-
device: Union[str, torch.device] = "cuda",
126-
worker: int = 0,
125+
device: Union[str, torch.device] = "cpu",
126+
worker: int = 2,
127127
image_key: str = "image",
128128
label_key: Optional[str] = "label",
129129
):
@@ -137,13 +137,10 @@ def __init__(
137137
self.average = average
138138
self.do_ccp = do_ccp
139139
self.device = torch.device(device)
140-
self.worker = worker
140+
self.worker = 0 if (self.device.type == "cuda") else worker
141141
self.image_key = image_key
142142
self.label_key = label_key
143143

144-
if (self.device.type == "cuda") and (worker > 0):
145-
raise ValueError("CUDA does not support multiple subprocess. If device is GPU, please set worker to 0")
146-
147144
@staticmethod
148145
def _check_data_uniformity(keys: List[str], result: Dict):
149146
"""
@@ -232,8 +229,14 @@ def get_all_case_stats(self):
232229
result[DataStatsKeys.SUMMARY] = summarizer.summarize(result[DataStatsKeys.BY_CASE])
233230

234231
if not self._check_data_uniformity([ImageStatsKeys.SPACING], result):
235-
logger.warning("Data is not completely uniform. MONAI transforms may provide unexpected result")
232+
logger.warning("data spacing is not completely uniform. MONAI transforms may provide unexpected result")
236233

237234
ConfigParser.export_config_file(result, self.output_path, fmt="yaml", default_flow_style=None)
238235

236+
del d["image"], d["label"]
237+
if self.device.type == "cuda":
238+
# release unreferenced tensors to mitigate OOM
239+
# limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237
240+
torch.cuda.empty_cache()
241+
239242
return result

monai/auto3dseg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from .algo_gen import Algo, AlgoGen
1313
from .analyzer import (
14+
Analyzer,
1415
FgImageStats,
1516
FgImageStatsSumm,
1617
FilenameStats,

monai/auto3dseg/analyzer.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,10 @@ def __call__(self, data):
229229
"""
230230
d = dict(data)
231231
start = time.time()
232-
ndas = data[self.image_key]
233-
ndas = [ndas[i] for i in range(ndas.shape[0])]
232+
restore_grad_state = torch.is_grad_enabled()
233+
torch.set_grad_enabled(False)
234+
235+
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
234236
if "nda_croppeds" not in d:
235237
nda_croppeds = [get_foreground_image(nda) for nda in ndas]
236238

@@ -250,8 +252,10 @@ def __call__(self, data):
250252
if not verify_report_format(report, self.get_report_format()):
251253
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
252254

253-
logger.debug(f"Get image stats spent {time.time()-start}")
254255
d[self.stats_name] = report
256+
257+
torch.set_grad_enabled(restore_grad_state)
258+
logger.debug(f"Get image stats spent {time.time()-start}")
255259
return d
256260

257261

@@ -307,9 +311,11 @@ def __call__(self, data) -> dict:
307311
"""
308312

309313
d = dict(data)
314+
start = time.time()
315+
restore_grad_state = torch.is_grad_enabled()
316+
torch.set_grad_enabled(False)
310317

311-
ndas = d[self.image_key] # (1,H,W,D) or (C,H,W,D)
312-
ndas = [ndas[i] for i in range(ndas.shape[0])]
318+
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
313319
ndas_label = d[self.label_key] # (H,W,D)
314320
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
315321

@@ -324,6 +330,9 @@ def __call__(self, data) -> dict:
324330
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
325331

326332
d[self.stats_name] = report
333+
334+
torch.set_grad_enabled(restore_grad_state)
335+
logger.debug(f"Get foreground image stats spent {time.time()-start}")
327336
return d
328337

329338

@@ -423,9 +432,12 @@ def __call__(self, data):
423432
functions. If the input has nan/inf, the stats results will be nan/inf.
424433
"""
425434
d = dict(data)
435+
start = time.time()
436+
using_cuda = True if d[self.image_key].device.type == "cuda" else False
437+
restore_grad_state = torch.is_grad_enabled()
438+
torch.set_grad_enabled(False)
426439

427-
ndas = d[self.image_key] # (1,H,W,D) or (C,H,W,D)
428-
ndas = [ndas[i] for i in range(ndas.shape[0])]
440+
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
429441
ndas_label = d[self.label_key] # (H,W,D)
430442
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
431443

@@ -435,7 +447,6 @@ def __call__(self, data):
435447

436448
unique_label = unique_label.astype(np.int8).tolist()
437449

438-
start = time.time()
439450
label_substats = [] # each element is one label
440451
pixel_sum = 0
441452
pixel_arr = []
@@ -444,13 +455,20 @@ def __call__(self, data):
444455
label_dict: Dict[str, Any] = {}
445456
mask_index = ndas_label == index
446457

458+
nda_masks = [nda[mask_index] for nda in ndas]
447459
label_dict[LabelStatsKeys.IMAGE_INTST] = [
448-
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda[mask_index]) for nda in ndas
460+
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
449461
]
462+
450463
pixel_count = sum(mask_index)
451464
pixel_arr.append(pixel_count)
452465
pixel_sum += pixel_count
453466
if self.do_ccp: # apply connected component
467+
if using_cuda:
468+
# The back end of get_label_ccp is CuPy
469+
# which is unable to automatically release CUDA GPU memory held by PyTorch
470+
del nda_masks
471+
torch.cuda.empty_cache()
454472
shape_list, ncomponents = get_label_ccp(mask_index)
455473
label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list
456474
label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents
@@ -472,6 +490,8 @@ def __call__(self, data):
472490
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
473491

474492
d[self.stats_name] = report
493+
494+
torch.set_grad_enabled(restore_grad_state)
475495
logger.debug(f"Get label stats spent {time.time()-start}")
476496
return d
477497

monai/auto3dseg/seg_summarizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def add_analyzer(self, case_analyzer, summary_analyzer) -> None:
104104
105105
.. code-block:: python
106106
107-
from monai.auto3dseg.analyzer import Analyzer
107+
from monai.auto3dseg import Analyzer
108108
from monai.auto3dseg.utils import concat_val_to_np
109109
from monai.auto3dseg.analyzer_engine import SegSummarizer
110110

monai/auto3dseg/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ def get_label_ccp(mask_index: MetaTensor, use_gpu: bool = True) -> Tuple[List[An
106106
shape_list.append(bbox_shape)
107107
ncomponents = len(vals)
108108

109+
del mask_cupy, labeled, vals, comp_idx, ncomp
110+
cp.get_default_memory_pool().free_all_blocks()
111+
109112
elif has_measure:
110113
labeled, ncomponents = measure_np.label(mask_index.data.cpu().numpy(), background=-1, return_num=True)
111114
for ncomp in range(1, ncomponents + 1):
@@ -174,7 +177,7 @@ def concat_val_to_np(
174177
elif ragged:
175178
return np.concatenate(np_list, **kwargs) # type: ignore
176179
else:
177-
return np.concatenate([np_list], **kwargs)
180+
return np.concatenate([np_list], **kwargs) # type: ignore
178181

179182

180183
def concat_multikeys_to_dict(

0 commit comments

Comments
 (0)