Skip to content

Commit 325b2c2

Browse files
[3.10] bpo-45500: Rewrite test_dbm (GH-29002) (GH-29069)
* Generate test classes at import time. It allows to filter them when run with unittest. E.g: "./python -m unittest test.test_dbm.TestCase_gnu -v". * Create a database class in a new directory which will be removed after test. It guarantees that all created files and directories be removed and will not conflict with other dbm tests. * Restore dbm._defaultmod after tests. Previously it was set to the last dbm module (dbm.dumb) which affected other tests. * Enable the whichdb test for dbm.dumb. * Move test_keys to the correct test class. It does not test whichdb(). * Remove some outdated code and comments.. (cherry picked from commit 975b94b) Co-authored-by: Serhiy Storchaka <[email protected]>
1 parent 092ec4b commit 325b2c2

File tree

1 file changed

+50
-64
lines changed

1 file changed

+50
-64
lines changed

Lib/test/test_dbm.py

Lines changed: 50 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,21 @@
11
"""Test script for the dbm.open function based on testdumbdbm.py"""
22

33
import unittest
4-
import glob
4+
import dbm
5+
import os
56
from test.support import import_helper
67
from test.support import os_helper
78

8-
# Skip tests if dbm module doesn't exist.
9-
dbm = import_helper.import_module('dbm')
10-
119
try:
1210
from dbm import ndbm
1311
except ImportError:
1412
ndbm = None
1513

16-
_fname = os_helper.TESTFN
14+
dirname = os_helper.TESTFN
15+
_fname = os.path.join(dirname, os_helper.TESTFN)
1716

1817
#
19-
# Iterates over every database module supported by dbm currently available,
20-
# setting dbm to use each in turn, and yielding that module
18+
# Iterates over every database module supported by dbm currently available.
2119
#
2220
def dbm_iterator():
2321
for name in dbm._names:
@@ -31,11 +29,12 @@ def dbm_iterator():
3129
#
3230
# Clean up all scratch databases we might have created during testing
3331
#
34-
def delete_files():
35-
# we don't know the precise name the underlying database uses
36-
# so we use glob to locate all names
37-
for f in glob.glob(glob.escape(_fname) + "*"):
38-
os_helper.unlink(f)
32+
def cleaunup_test_dir():
33+
os_helper.rmtree(dirname)
34+
35+
def setup_test_dir():
36+
cleaunup_test_dir()
37+
os.mkdir(dirname)
3938

4039

4140
class AnyDBMTestCase:
@@ -134,80 +133,67 @@ def read_helper(self, f):
134133
for key in self._dict:
135134
self.assertEqual(self._dict[key], f[key.encode("ascii")])
136135

137-
def tearDown(self):
138-
delete_files()
136+
def test_keys(self):
137+
with dbm.open(_fname, 'c') as d:
138+
self.assertEqual(d.keys(), [])
139+
a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')]
140+
for k, v in a:
141+
d[k] = v
142+
self.assertEqual(sorted(d.keys()), sorted(k for (k, v) in a))
143+
for k, v in a:
144+
self.assertIn(k, d)
145+
self.assertEqual(d[k], v)
146+
self.assertNotIn(b'xxx', d)
147+
self.assertRaises(KeyError, lambda: d[b'xxx'])
139148

140149
def setUp(self):
150+
self.addCleanup(setattr, dbm, '_defaultmod', dbm._defaultmod)
141151
dbm._defaultmod = self.module
142-
delete_files()
152+
self.addCleanup(cleaunup_test_dir)
153+
setup_test_dir()
143154

144155

145156
class WhichDBTestCase(unittest.TestCase):
146157
def test_whichdb(self):
158+
self.addCleanup(setattr, dbm, '_defaultmod', dbm._defaultmod)
147159
for module in dbm_iterator():
148160
# Check whether whichdb correctly guesses module name
149161
# for databases opened with "module" module.
150-
# Try with empty files first
151162
name = module.__name__
152-
if name == 'dbm.dumb':
153-
continue # whichdb can't support dbm.dumb
154-
delete_files()
155-
f = module.open(_fname, 'c')
156-
f.close()
163+
setup_test_dir()
164+
dbm._defaultmod = module
165+
# Try with empty files first
166+
with module.open(_fname, 'c'): pass
157167
self.assertEqual(name, self.dbm.whichdb(_fname))
158168
# Now add a key
159-
f = module.open(_fname, 'w')
160-
f[b"1"] = b"1"
161-
# and test that we can find it
162-
self.assertIn(b"1", f)
163-
# and read it
164-
self.assertEqual(f[b"1"], b"1")
165-
f.close()
169+
with module.open(_fname, 'w') as f:
170+
f[b"1"] = b"1"
171+
# and test that we can find it
172+
self.assertIn(b"1", f)
173+
# and read it
174+
self.assertEqual(f[b"1"], b"1")
166175
self.assertEqual(name, self.dbm.whichdb(_fname))
167176

168177
@unittest.skipUnless(ndbm, reason='Test requires ndbm')
169178
def test_whichdb_ndbm(self):
170179
# Issue 17198: check that ndbm which is referenced in whichdb is defined
171-
db_file = '{}_ndbm.db'.format(_fname)
172-
with open(db_file, 'w'):
173-
self.addCleanup(os_helper.unlink, db_file)
174-
self.assertIsNone(self.dbm.whichdb(db_file[:-3]))
175-
176-
def tearDown(self):
177-
delete_files()
180+
with open(_fname + '.db', 'wb'): pass
181+
self.assertIsNone(self.dbm.whichdb(_fname))
178182

179183
def setUp(self):
180-
delete_files()
181-
self.filename = os_helper.TESTFN
182-
self.d = dbm.open(self.filename, 'c')
183-
self.d.close()
184+
self.addCleanup(cleaunup_test_dir)
185+
setup_test_dir()
184186
self.dbm = import_helper.import_fresh_module('dbm')
185187

186-
def test_keys(self):
187-
self.d = dbm.open(self.filename, 'c')
188-
self.assertEqual(self.d.keys(), [])
189-
a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')]
190-
for k, v in a:
191-
self.d[k] = v
192-
self.assertEqual(sorted(self.d.keys()), sorted(k for (k, v) in a))
193-
for k, v in a:
194-
self.assertIn(k, self.d)
195-
self.assertEqual(self.d[k], v)
196-
self.assertNotIn(b'xxx', self.d)
197-
self.assertRaises(KeyError, lambda: self.d[b'xxx'])
198-
self.d.close()
199-
200-
201-
def load_tests(loader, tests, pattern):
202-
classes = []
203-
for mod in dbm_iterator():
204-
classes.append(type("TestCase-" + mod.__name__,
205-
(AnyDBMTestCase, unittest.TestCase),
206-
{'module': mod}))
207-
suites = [unittest.makeSuite(c) for c in classes]
208-
209-
tests.addTests(suites)
210-
return tests
188+
189+
for mod in dbm_iterator():
190+
assert mod.__name__.startswith('dbm.')
191+
suffix = mod.__name__[4:]
192+
testname = f'TestCase_{suffix}'
193+
globals()[testname] = type(testname,
194+
(AnyDBMTestCase, unittest.TestCase),
195+
{'module': mod})
196+
211197

212198
if __name__ == "__main__":
213199
unittest.main()

0 commit comments

Comments
 (0)