Skip to content

Commit 52e6278

Browse files
authored
883 Fix exception issue in StatsHandler and CheckpointSaver (#890)
* [DLMED] update handlers * [DLMED] add more doc-strings * [DLMED] add more doc-string * [DLMED] update according to comments * [DLMED] update according to Eric's comments
1 parent 04e1c34 commit 52e6278

File tree

4 files changed

+64
-26
lines changed

4 files changed

+64
-26
lines changed

monai/handlers/checkpoint_saver.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,14 @@ class CheckpointSaver:
3737
name: identifier of logging.logger to use, if None, defaulting to ``engine.logger``.
3838
file_prefix: prefix for the filenames to which objects will be saved.
3939
save_final: whether to save checkpoint or session at final iteration or exception.
40+
If checkpoints are to be saved when an exception is raised, put this handler before
41+
`StatsHandler` in the handler list, because the logic with Ignite can only trigger
42+
the first attached handler for `EXCEPTION_RAISED` event.
4043
save_key_metric: whether to save checkpoint or session when the value of key_metric is
4144
higher than all the previous values during training. Keep 4 decimal places of metric,
4245
checkpoint name is: {file_prefix}_key_metric=0.XXXX.pth.
4346
key_metric_name: the name of key_metric in ignite metrics dictionary.
44-
if None, use `engine.state.key_metric` instead.
47+
If None, use `engine.state.key_metric` instead.
4548
key_metric_n_saved: save top N checkpoints or sessions, sorted by the value of key
4649
metric in descending order.
4750
epoch_level: save checkpoint during training for every N epochs or every N iterations.
@@ -168,7 +171,8 @@ def completed(self, engine: Engine) -> None:
168171

169172
def exception_raised(self, engine: Engine, e: Exception) -> None:
170173
"""Callback for train or validation/evaluation exception raised Event.
171-
Save current data as final checkpoint if configure save_final is True.
174+
Save current data as final checkpoint if `save_final` is configured as True. This callback may be skipped
175+
because the logic with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event.
172176
173177
Args:
174178
engine: Ignite Engine, it can be a trainer, validator or evaluator.
@@ -179,6 +183,7 @@ def exception_raised(self, engine: Engine, e: Exception) -> None:
179183
assert self.logger is not None
180184
assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute."
181185
self.logger.info(f"Exception_raised, saved exception checkpoint: {self._final_checkpoint.last_checkpoint}")
186+
raise e
182187

183188
def metrics_completed(self, engine: Engine) -> None:
184189
"""Callback to compare metrics and save models in train or validation when epoch completed.

monai/handlers/stats_handler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
tag_name: scalar_value to logger. Defaults to ``'Loss'``.
7171
key_var_format: a formatting string to control the output string format of key: value.
7272
logger_handler: add additional handler to handle the stats data: save to file, etc.
73-
add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
73+
Add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
7474
"""
7575

7676
self.epoch_print_logger = epoch_print_logger
@@ -133,16 +133,16 @@ def iteration_completed(self, engine: Engine) -> None:
133133
def exception_raised(self, engine: Engine, e: Exception) -> None:
134134
"""
135135
Handler for train or validation/evaluation exception raised Event.
136-
Print the exception information and traceback.
136+
Print the exception information and traceback. This callback may be skipped because the logic
137+
with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event.
137138
138139
Args:
139140
engine: Ignite Engine, it can be a trainer, validator or evaluator.
140141
e: the exception caught in Ignite during engine.run().
141142
142143
"""
143144
self.logger.exception(f"Exception: {e}")
144-
# import traceback
145-
# traceback.print_exc()
145+
raise e
146146

147147
def _default_epoch_print(self, engine: Engine) -> None:
148148
"""

tests/test_handler_checkpoint_saver.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import logging
1313
import os
14-
import shutil
1514
import sys
1615
import tempfile
1716
import unittest
@@ -80,25 +79,43 @@ def _train_func(engine, batch):
8079
if multi_devices:
8180
net = torch.nn.DataParallel(net)
8281
optimizer = optim.SGD(net.parameters(), lr=0.02)
83-
tempdir = tempfile.mkdtemp()
84-
handler = CheckpointSaver(
85-
tempdir,
86-
{"net": net, "opt": optimizer},
87-
"CheckpointSaver",
88-
"test",
89-
save_final,
90-
save_key_metric,
91-
key_metric_name,
92-
key_metric_n_saved,
93-
epoch_level,
94-
save_interval,
95-
n_saved,
96-
)
97-
handler.attach(engine)
98-
engine.run(data, max_epochs=5)
99-
for filename in filenames:
100-
self.assertTrue(os.path.exists(os.path.join(tempdir, filename)))
101-
shutil.rmtree(tempdir)
82+
with tempfile.TemporaryDirectory() as tempdir:
83+
handler = CheckpointSaver(
84+
tempdir,
85+
{"net": net, "opt": optimizer},
86+
"CheckpointSaver",
87+
"test",
88+
save_final,
89+
save_key_metric,
90+
key_metric_name,
91+
key_metric_n_saved,
92+
epoch_level,
93+
save_interval,
94+
n_saved,
95+
)
96+
handler.attach(engine)
97+
engine.run(data, max_epochs=5)
98+
for filename in filenames:
99+
self.assertTrue(os.path.exists(os.path.join(tempdir, filename)))
100+
101+
def test_exception(self):
102+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
103+
net = torch.nn.PReLU()
104+
105+
# set up engine
106+
def _train_func(engine, batch):
107+
raise RuntimeError("test exception.")
108+
109+
engine = Engine(_train_func)
110+
111+
# set up testing handler
112+
with tempfile.TemporaryDirectory() as tempdir:
113+
stats_handler = CheckpointSaver(tempdir, {"net": net}, save_final=True)
114+
stats_handler.attach(engine)
115+
116+
with self.assertRaises(RuntimeError):
117+
engine.run(range(3), max_epochs=2)
118+
self.assertTrue(os.path.exists(os.path.join(tempdir, "net_final_iteration=1.pth")))
102119

103120

104121
if __name__ == "__main__":

tests/test_handler_stats.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,22 @@ def _train_func(engine, batch):
143143
self.assertTrue(has_key_word.match(line))
144144
shutil.rmtree(tempdir)
145145

146+
def test_exception(self):
147+
logging.basicConfig(level=logging.INFO)
148+
149+
# set up engine
150+
def _train_func(engine, batch):
151+
raise RuntimeError("test exception.")
152+
153+
engine = Engine(_train_func)
154+
155+
# set up testing handler
156+
stats_handler = StatsHandler()
157+
stats_handler.attach(engine)
158+
159+
with self.assertRaises(RuntimeError):
160+
engine.run(range(3), max_epochs=2)
161+
146162

147163
if __name__ == "__main__":
148164
unittest.main()

0 commit comments

Comments
 (0)