Skip to content

Commit 3181e3e

Browse files
authored
Remove unneccessary tempdir (#446)
* [DLMED] remove tempdir * [DLMED] update all new features * [DLMED] delete tempdir
1 parent 253d1aa commit 3181e3e

10 files changed

+116
-108
lines changed

tests/test_arraydataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import unittest
1313
import os
14+
import shutil
1415
import numpy as np
1516
import tempfile
1617
import nibabel as nib
@@ -68,6 +69,7 @@ def test_shape(self, img_transform, label_transform, indexes, expected_shape):
6869
self.assertTupleEqual(data2[indexes[0]].shape, expected_shape)
6970
self.assertTupleEqual(data2[indexes[1]].shape, expected_shape)
7071
np.testing.assert_allclose(data2[indexes[0]], data2[indexes[0]])
72+
shutil.rmtree(tempdir)
7173

7274

7375
if __name__ == "__main__":

tests/test_data_stats.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import unittest
1313
import os
14+
import shutil
1415
import logging
1516
import tempfile
1617
import numpy as np
@@ -97,25 +98,25 @@ def test_value(self, input_param, input_data, expected_print):
9798

9899
@parameterized.expand([TEST_CASE_6])
99100
def test_file(self, input_data, expected_print):
100-
with tempfile.TemporaryDirectory() as tempdir:
101-
filename = os.path.join(tempdir, "test_stats.log")
102-
handler = logging.FileHandler(filename, mode="w")
103-
input_param = {
104-
"prefix": "test data",
105-
"data_shape": True,
106-
"intensity_range": True,
107-
"data_value": True,
108-
"additional_info": lambda x: np.mean(x),
109-
"logger_handler": handler,
110-
}
111-
transform = DataStats(**input_param)
112-
_ = transform(input_data)
113-
handler.stream.close()
114-
transform._logger.removeHandler(handler)
115-
with open(filename, "r") as f:
116-
content = f.read()
117-
self.assertEqual(content, expected_print)
118-
os.remove(filename)
101+
tempdir = tempfile.mkdtemp()
102+
filename = os.path.join(tempdir, "test_stats.log")
103+
handler = logging.FileHandler(filename, mode="w")
104+
input_param = {
105+
"prefix": "test data",
106+
"data_shape": True,
107+
"intensity_range": True,
108+
"data_value": True,
109+
"additional_info": lambda x: np.mean(x),
110+
"logger_handler": handler,
111+
}
112+
transform = DataStats(**input_param)
113+
_ = transform(input_data)
114+
handler.stream.close()
115+
transform._logger.removeHandler(handler)
116+
with open(filename, "r") as f:
117+
content = f.read()
118+
self.assertEqual(content, expected_print)
119+
shutil.rmtree(tempdir)
119120

120121

121122
if __name__ == "__main__":

tests/test_data_statsd.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import unittest
1313
import os
14+
import shutil
1415
import logging
1516
import tempfile
1617
import numpy as np
@@ -110,26 +111,26 @@ def test_value(self, input_param, input_data, expected_print):
110111

111112
@parameterized.expand([TEST_CASE_7])
112113
def test_file(self, input_data, expected_print):
113-
with tempfile.TemporaryDirectory() as tempdir:
114-
filename = os.path.join(tempdir, "test_stats.log")
115-
handler = logging.FileHandler(filename, mode="w")
116-
input_param = {
117-
"keys": "img",
118-
"prefix": "test data",
119-
"data_shape": True,
120-
"intensity_range": True,
121-
"data_value": True,
122-
"additional_info": lambda x: np.mean(x),
123-
"logger_handler": handler,
124-
}
125-
transform = DataStatsd(**input_param)
126-
_ = transform(input_data)
127-
handler.stream.close()
128-
transform.printer._logger.removeHandler(handler)
129-
with open(filename, "r") as f:
130-
content = f.read()
131-
self.assertEqual(content, expected_print)
132-
os.remove(filename)
114+
tempdir = tempfile.mkdtemp()
115+
filename = os.path.join(tempdir, "test_stats.log")
116+
handler = logging.FileHandler(filename, mode="w")
117+
input_param = {
118+
"keys": "img",
119+
"prefix": "test data",
120+
"data_shape": True,
121+
"intensity_range": True,
122+
"data_value": True,
123+
"additional_info": lambda x: np.mean(x),
124+
"logger_handler": handler,
125+
}
126+
transform = DataStatsd(**input_param)
127+
_ = transform(input_data)
128+
handler.stream.close()
129+
transform.printer._logger.removeHandler(handler)
130+
with open(filename, "r") as f:
131+
content = f.read()
132+
self.assertEqual(content, expected_print)
133+
shutil.rmtree(tempdir)
133134

134135

135136
if __name__ == "__main__":

tests/test_dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import unittest
1313
import os
14+
import shutil
1415
import numpy as np
1516
import tempfile
1617
import nibabel as nib
@@ -71,6 +72,7 @@ def test_shape(self, expected_shape):
7172
self.assertTupleEqual(data2_simple["image"].shape, expected_shape)
7273
self.assertTupleEqual(data2_simple["label"].shape, expected_shape)
7374
self.assertTupleEqual(data2_simple["extra"].shape, expected_shape)
75+
shutil.rmtree(tempdir)
7476

7577

7678
if __name__ == "__main__":

tests/test_handler_checkpoint_loader.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
import os
1312
import tempfile
1413
import shutil
1514
import torch
@@ -33,15 +32,14 @@ def test_one_save_one_load(self):
3332
data2["weight"] = torch.tensor([0.2])
3433
net2.load_state_dict(data2)
3534
engine = Engine(lambda e, b: None)
36-
with tempfile.TemporaryDirectory() as tempdir:
37-
save_dir = os.path.join(tempdir, "checkpoint")
38-
CheckpointSaver(save_dir=save_dir, save_dict={"net": net1}, save_final=True).attach(engine)
39-
engine.run([0] * 8, max_epochs=5)
40-
path = save_dir + "/net_final_iteration=40.pth"
41-
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
42-
engine.run([0] * 8, max_epochs=1)
43-
torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
44-
shutil.rmtree(save_dir)
35+
tempdir = tempfile.mkdtemp()
36+
CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine)
37+
engine.run([0] * 8, max_epochs=5)
38+
path = tempdir + "/net_final_iteration=40.pth"
39+
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
40+
engine.run([0] * 8, max_epochs=1)
41+
torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
42+
shutil.rmtree(tempdir)
4543

4644
def test_two_save_one_load(self):
4745
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -55,16 +53,15 @@ def test_two_save_one_load(self):
5553
data2["weight"] = torch.tensor([0.2])
5654
net2.load_state_dict(data2)
5755
engine = Engine(lambda e, b: None)
58-
with tempfile.TemporaryDirectory() as tempdir:
59-
save_dir = os.path.join(tempdir, "checkpoint")
60-
save_dict = {"net": net1, "opt": optimizer}
61-
CheckpointSaver(save_dir=save_dir, save_dict=save_dict, save_final=True).attach(engine)
62-
engine.run([0] * 8, max_epochs=5)
63-
path = save_dir + "/checkpoint_final_iteration=40.pth"
64-
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
65-
engine.run([0] * 8, max_epochs=1)
66-
torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
67-
shutil.rmtree(save_dir)
56+
tempdir = tempfile.mkdtemp()
57+
save_dict = {"net": net1, "opt": optimizer}
58+
CheckpointSaver(save_dir=tempdir, save_dict=save_dict, save_final=True).attach(engine)
59+
engine.run([0] * 8, max_epochs=5)
60+
path = tempdir + "/checkpoint_final_iteration=40.pth"
61+
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
62+
engine.run([0] * 8, max_epochs=1)
63+
torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1)
64+
shutil.rmtree(tempdir)
6865

6966
def test_save_single_device_load_multi_devices(self):
7067
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -78,15 +75,14 @@ def test_save_single_device_load_multi_devices(self):
7875
net2.load_state_dict(data2)
7976
net2 = torch.nn.DataParallel(net2)
8077
engine = Engine(lambda e, b: None)
81-
with tempfile.TemporaryDirectory() as tempdir:
82-
save_dir = os.path.join(tempdir, "checkpoint")
83-
CheckpointSaver(save_dir=save_dir, save_dict={"net": net1}, save_final=True).attach(engine)
84-
engine.run([0] * 8, max_epochs=5)
85-
path = save_dir + "/net_final_iteration=40.pth"
86-
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
87-
engine.run([0] * 8, max_epochs=1)
88-
torch.testing.assert_allclose(net2.state_dict()["module.weight"], 0.1)
89-
shutil.rmtree(save_dir)
78+
tempdir = tempfile.mkdtemp()
79+
CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine)
80+
engine.run([0] * 8, max_epochs=5)
81+
path = tempdir + "/net_final_iteration=40.pth"
82+
CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine)
83+
engine.run([0] * 8, max_epochs=1)
84+
torch.testing.assert_allclose(net2.state_dict()["module.weight"], 0.1)
85+
shutil.rmtree(tempdir)
9086

9187

9288
if __name__ == "__main__":

tests/test_handler_checkpoint_saver.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -78,26 +78,25 @@ def _train_func(engine, batch):
7878
if multi_devices:
7979
net = torch.nn.DataParallel(net)
8080
optimizer = optim.SGD(net.parameters(), lr=0.02)
81-
with tempfile.TemporaryDirectory() as tempdir:
82-
save_dir = os.path.join(tempdir, "checkpoint")
83-
handler = CheckpointSaver(
84-
save_dir,
85-
{"net": net, "opt": optimizer},
86-
"CheckpointSaver",
87-
"test",
88-
save_final,
89-
save_key_metric,
90-
key_metric_name,
91-
key_metric_n_saved,
92-
epoch_level,
93-
save_interval,
94-
n_saved,
95-
)
96-
handler.attach(engine)
97-
engine.run(data, max_epochs=5)
98-
for filename in filenames:
99-
self.assertTrue(os.path.exists(os.path.join(save_dir, filename)))
100-
shutil.rmtree(save_dir)
81+
tempdir = tempfile.mkdtemp()
82+
handler = CheckpointSaver(
83+
tempdir,
84+
{"net": net, "opt": optimizer},
85+
"CheckpointSaver",
86+
"test",
87+
save_final,
88+
save_key_metric,
89+
key_metric_name,
90+
key_metric_n_saved,
91+
epoch_level,
92+
save_interval,
93+
n_saved,
94+
)
95+
handler.attach(engine)
96+
engine.run(data, max_epochs=5)
97+
for filename in filenames:
98+
self.assertTrue(os.path.exists(os.path.join(tempdir, filename)))
99+
shutil.rmtree(tempdir)
101100

102101

103102
if __name__ == "__main__":

tests/test_load_nifti.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import unittest
1313
import os
14+
import shutil
1415
import numpy as np
1516
import tempfile
1617
import nibabel as nib
@@ -38,18 +39,20 @@ class TestLoadNifti(unittest.TestCase):
3839
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
3940
def test_shape(self, input_param, filenames, expected_shape):
4041
test_image = np.random.randint(0, 2, size=[128, 128, 128])
41-
with tempfile.TemporaryDirectory() as tempdir:
42-
for i, name in enumerate(filenames):
43-
filenames[i] = os.path.join(tempdir, name)
44-
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
45-
result = LoadNifti(**input_param)(filenames)
42+
tempdir = tempfile.mkdtemp()
43+
for i, name in enumerate(filenames):
44+
filenames[i] = os.path.join(tempdir, name)
45+
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
46+
result = LoadNifti(**input_param)(filenames)
47+
4648
if isinstance(result, tuple):
4749
result, header = result
4850
self.assertTrue("affine" in header)
4951
np.testing.assert_allclose(header["affine"], np.eye(4))
5052
if input_param["as_closest_canonical"]:
5153
np.testing.asesrt_allclose(header["original_affine"], np.eye(4))
5254
self.assertTupleEqual(result.shape, expected_shape)
55+
shutil.rmtree(tempdir)
5356

5457

5558
if __name__ == "__main__":

tests/test_load_niftid.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import unittest
1313
import os
14+
import shutil
1415
import numpy as np
1516
import tempfile
1617
import nibabel as nib
@@ -27,13 +28,14 @@ class TestLoadNiftid(unittest.TestCase):
2728
def test_shape(self, input_param, expected_shape):
2829
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
2930
test_data = dict()
30-
with tempfile.TemporaryDirectory() as tempdir:
31-
for key in KEYS:
32-
nib.save(test_image, os.path.join(tempdir, key + ".nii.gz"))
33-
test_data.update({key: os.path.join(tempdir, key + ".nii.gz")})
34-
result = LoadNiftid(**input_param)(test_data)
31+
tempdir = tempfile.mkdtemp()
32+
for key in KEYS:
33+
nib.save(test_image, os.path.join(tempdir, key + ".nii.gz"))
34+
test_data.update({key: os.path.join(tempdir, key + ".nii.gz")})
35+
result = LoadNiftid(**input_param)(test_data)
3536
for key in KEYS:
3637
self.assertTupleEqual(result[key].shape, expected_shape)
38+
shutil.rmtree(tempdir)
3739

3840

3941
if __name__ == "__main__":

tests/test_load_png.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import unittest
1313
import os
14+
import shutil
1415
import numpy as np
1516
import tempfile
1617
from PIL import Image
@@ -29,13 +30,13 @@ class TestLoadPNG(unittest.TestCase):
2930
def test_shape(self, data_shape, filenames, expected_shape, meta_shape):
3031
test_image = np.random.randint(0, 256, size=data_shape)
3132
tempdir = tempfile.mkdtemp()
32-
with tempfile.TemporaryDirectory() as tempdir:
33-
for i, name in enumerate(filenames):
34-
filenames[i] = os.path.join(tempdir, name)
35-
Image.fromarray(test_image.astype("uint8")).save(filenames[i])
36-
result = LoadPNG()(filenames)
33+
for i, name in enumerate(filenames):
34+
filenames[i] = os.path.join(tempdir, name)
35+
Image.fromarray(test_image.astype("uint8")).save(filenames[i])
36+
result = LoadPNG()(filenames)
3737
self.assertTupleEqual(result[1]["spatial_shape"], meta_shape)
3838
self.assertTupleEqual(result[0].shape, expected_shape)
39+
shutil.rmtree(tempdir)
3940

4041

4142
if __name__ == "__main__":

tests/test_load_pngd.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import unittest
1313
import os
14+
import shutil
1415
import numpy as np
1516
import tempfile
1617
from PIL import Image
@@ -28,13 +29,13 @@ def test_shape(self, input_param, expected_shape):
2829
test_image = np.random.randint(0, 256, size=[128, 128, 3])
2930
tempdir = tempfile.mkdtemp()
3031
test_data = dict()
31-
with tempfile.TemporaryDirectory() as tempdir:
32-
for key in KEYS:
33-
Image.fromarray(test_image.astype("uint8")).save(os.path.join(tempdir, key + ".png"))
34-
test_data.update({key: os.path.join(tempdir, key + ".png")})
35-
result = LoadPNGd(**input_param)(test_data)
32+
for key in KEYS:
33+
Image.fromarray(test_image.astype("uint8")).save(os.path.join(tempdir, key + ".png"))
34+
test_data.update({key: os.path.join(tempdir, key + ".png")})
35+
result = LoadPNGd(**input_param)(test_data)
3636
for key in KEYS:
3737
self.assertTupleEqual(result[key].shape, expected_shape)
38+
shutil.rmtree(tempdir)
3839

3940

4041
if __name__ == "__main__":

0 commit comments

Comments
 (0)