Skip to content

Commit 10203f3

Browse files
Nic-Mawyli
andauthored
Unify all tests to use tempdir (#450)
* [DLMED] fix flake8 error * [DLMED] revert * [DLMED] update all tests to use tempdir * [DLMED] fix flake8 issue Co-authored-by: Wenqi Li <[email protected]>
1 parent f37a430 commit 10203f3

File tree

5 files changed

+60
-53
lines changed

5 files changed

+60
-53
lines changed

monai/handlers/tensorboard_handlers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class TensorBoardStatsHandler(object):
3636
def __init__(
3737
self,
3838
summary_writer=None,
39+
log_dir="./runs",
3940
epoch_event_writer=None,
4041
iteration_event_writer=None,
4142
output_transform=lambda x: x,
@@ -46,6 +47,7 @@ def __init__(
4647
Args:
4748
summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter,
4849
default to create a new writer.
50+
log_dir (str): if using default SummaryWriter, write logs to this directory, default is `./runs`.
4951
epoch_event_writer (Callable): customized callable TensorBoard writer for epoch level.
5052
must accept parameter "engine" and "summary_writer", use default event writer if None.
5153
iteration_event_writer (Callable): customized callable TensorBoard writer for iteration level.
@@ -59,7 +61,7 @@ def __init__(
5961
when plotting epoch vs metric curves.
6062
tag_name (string): when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``.
6163
"""
62-
self._writer = SummaryWriter() if summary_writer is None else summary_writer
64+
self._writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer
6365
self.epoch_event_writer = epoch_event_writer
6466
self.iteration_event_writer = iteration_event_writer
6567
self.output_transform = output_transform
@@ -179,6 +181,7 @@ class TensorBoardImageHandler(object):
179181
def __init__(
180182
self,
181183
summary_writer=None,
184+
log_dir="./runs",
182185
batch_transform=lambda x: x,
183186
output_transform=lambda x: x,
184187
global_iter_transform=lambda x: x,
@@ -190,6 +193,7 @@ def __init__(
190193
Args:
191194
summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter,
192195
default to create a new writer.
196+
log_dir (str): if using default SummaryWriter, write logs to this directory, default is `./runs`.
193197
batch_transform (Callable): a callable that is used to transform the
194198
``ignite.engine.batch`` into expected format to extract several label data.
195199
output_transform (Callable): a callable that is used to transform the
@@ -200,7 +204,7 @@ def __init__(
200204
max_channels (int): number of channels to plot.
201205
max_frames (int): number of frames for 2D-t plot.
202206
"""
203-
self._writer = SummaryWriter() if summary_writer is None else summary_writer
207+
self._writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer
204208
self.batch_transform = batch_transform
205209
self.output_transform = output_transform
206210
self.global_iter_transform = global_iter_transform

tests/test_handler_tb_image.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import os
1414
import shutil
1515
import unittest
16-
16+
import tempfile
1717
import numpy as np
1818
import torch
1919
from ignite.engine import Engine, Events
@@ -34,8 +34,8 @@
3434
class TestHandlerTBImage(unittest.TestCase):
3535
@parameterized.expand(TEST_CASES)
3636
def test_tb_image_shape(self, shape):
37-
default_dir = os.path.join(".", "runs")
38-
shutil.rmtree(default_dir, ignore_errors=True)
37+
tempdir = tempfile.mkdtemp()
38+
shutil.rmtree(tempdir, ignore_errors=True)
3939

4040
# set up engine
4141
def _train_func(engine, batch):
@@ -44,15 +44,15 @@ def _train_func(engine, batch):
4444
engine = Engine(_train_func)
4545

4646
# set up testing handler
47-
stats_handler = TensorBoardImageHandler()
47+
stats_handler = TensorBoardImageHandler(log_dir=tempdir)
4848
engine.add_event_handler(Events.ITERATION_COMPLETED, stats_handler)
4949

5050
data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape)))
5151
engine.run(data, epoch_length=10, max_epochs=1)
5252

53-
self.assertTrue(os.path.exists(default_dir))
54-
self.assertTrue(len(glob.glob(default_dir)) > 0)
55-
shutil.rmtree(default_dir)
53+
self.assertTrue(os.path.exists(tempdir))
54+
self.assertTrue(len(glob.glob(tempdir)) > 0)
55+
shutil.rmtree(tempdir)
5656

5757

5858
if __name__ == "__main__":

tests/test_handler_tb_stats.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
class TestHandlerTBStats(unittest.TestCase):
2525
def test_metrics_print(self):
26-
default_dir = os.path.join(".", "runs")
27-
shutil.rmtree(default_dir, ignore_errors=True)
26+
tempdir = tempfile.mkdtemp()
27+
shutil.rmtree(tempdir, ignore_errors=True)
2828

2929
# set up engine
3030
def _train_func(engine, batch):
@@ -39,41 +39,41 @@ def _update_metric(engine):
3939
engine.state.metrics["acc"] = current_metric + 0.1
4040

4141
# set up testing handler
42-
stats_handler = TensorBoardStatsHandler()
42+
stats_handler = TensorBoardStatsHandler(log_dir=tempdir)
4343
stats_handler.attach(engine)
4444
engine.run(range(3), max_epochs=2)
4545
# check logging output
4646

47-
self.assertTrue(os.path.exists(default_dir))
48-
shutil.rmtree(default_dir)
47+
self.assertTrue(os.path.exists(tempdir))
48+
shutil.rmtree(tempdir)
4949

5050
def test_metrics_writer(self):
51-
default_dir = os.path.join(".", "runs")
52-
shutil.rmtree(default_dir, ignore_errors=True)
53-
with tempfile.TemporaryDirectory() as temp_dir:
54-
55-
# set up engine
56-
def _train_func(engine, batch):
57-
return batch + 1.0
58-
59-
engine = Engine(_train_func)
60-
61-
# set up dummy metric
62-
@engine.on(Events.EPOCH_COMPLETED)
63-
def _update_metric(engine):
64-
current_metric = engine.state.metrics.get("acc", 0.1)
65-
engine.state.metrics["acc"] = current_metric + 0.1
66-
67-
# set up testing handler
68-
writer = SummaryWriter(log_dir=temp_dir)
69-
stats_handler = TensorBoardStatsHandler(
70-
writer, output_transform=lambda x: {"loss": x * 2.0}, global_epoch_transform=lambda x: x * 3.0
71-
)
72-
stats_handler.attach(engine)
73-
engine.run(range(3), max_epochs=2)
74-
# check logging output
75-
self.assertTrue(len(glob.glob(temp_dir)) > 0)
76-
self.assertTrue(not os.path.exists(default_dir))
51+
tempdir = tempfile.mkdtemp()
52+
shutil.rmtree(tempdir, ignore_errors=True)
53+
54+
# set up engine
55+
def _train_func(engine, batch):
56+
return batch + 1.0
57+
58+
engine = Engine(_train_func)
59+
60+
# set up dummy metric
61+
@engine.on(Events.EPOCH_COMPLETED)
62+
def _update_metric(engine):
63+
current_metric = engine.state.metrics.get("acc", 0.1)
64+
engine.state.metrics["acc"] = current_metric + 0.1
65+
66+
# set up testing handler
67+
writer = SummaryWriter(log_dir=tempdir)
68+
stats_handler = TensorBoardStatsHandler(
69+
writer, output_transform=lambda x: {"loss": x * 2.0}, global_epoch_transform=lambda x: x * 3.0
70+
)
71+
stats_handler.attach(engine)
72+
engine.run(range(3), max_epochs=2)
73+
# check logging output
74+
self.assertTrue(os.path.exists(tempdir))
75+
self.assertTrue(len(glob.glob(tempdir)) > 0)
76+
shutil.rmtree(tempdir)
7777

7878

7979
if __name__ == "__main__":

tests/test_integration_sliding_window.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# limitations under the License.
1111

1212
import os
13+
import shutil
1314
import tempfile
1415
import unittest
1516

@@ -76,13 +77,14 @@ def tearDown(self):
7677
os.remove(self.seg_name)
7778

7879
def test_training(self):
79-
with tempfile.TemporaryDirectory() as temp_dir:
80-
output_file = run_test(
81-
batch_size=2, img_name=self.img_name, seg_name=self.seg_name, output_dir=temp_dir, device=self.device
82-
)
83-
output_image = nib.load(output_file).get_fdata()
84-
np.testing.assert_allclose(np.sum(output_image), 34070)
85-
np.testing.assert_allclose(output_image.shape, (28, 25, 63, 1))
80+
tempdir = tempfile.mkdtemp()
81+
output_file = run_test(
82+
batch_size=2, img_name=self.img_name, seg_name=self.seg_name, output_dir=tempdir, device=self.device
83+
)
84+
output_image = nib.load(output_file).get_fdata()
85+
np.testing.assert_allclose(np.sum(output_image), 34070)
86+
np.testing.assert_allclose(output_image.shape, (28, 25, 63, 1))
87+
shutil.rmtree(tempdir)
8688

8789

8890
if __name__ == "__main__":

tests/test_plot_2d_or_3d_image.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import glob
1313
import os
14+
import tempfile
1415
import shutil
1516
import unittest
1617
from torch.utils.tensorboard import SummaryWriter
@@ -32,14 +33,14 @@
3233
class TestPlot2dOr3dImage(unittest.TestCase):
3334
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
3435
def test_tb_image_shape(self, shape):
35-
default_dir = os.path.join(".", "runs")
36-
shutil.rmtree(default_dir, ignore_errors=True)
36+
tempdir = tempfile.mkdtemp()
37+
shutil.rmtree(tempdir, ignore_errors=True)
3738

38-
plot_2d_or_3d_image(torch.zeros(shape), 0, SummaryWriter())
39+
plot_2d_or_3d_image(torch.zeros(shape), 0, SummaryWriter(log_dir=tempdir))
3940

40-
self.assertTrue(os.path.exists(default_dir))
41-
self.assertTrue(len(glob.glob(default_dir)) > 0)
42-
shutil.rmtree(default_dir, ignore_errors=True)
41+
self.assertTrue(os.path.exists(tempdir))
42+
self.assertTrue(len(glob.glob(tempdir)) > 0)
43+
shutil.rmtree(tempdir, ignore_errors=True)
4344

4445

4546
if __name__ == "__main__":

0 commit comments

Comments
 (0)