1
1
"""Test script for the dbm.open function based on testdumbdbm.py"""
2
2
3
3
import unittest
4
- import glob
4
+ import dbm
5
5
import os
6
6
from test .support import import_helper
7
7
from test .support import os_helper
8
8
9
- # Skip tests if dbm module doesn't exist.
10
- dbm = import_helper .import_module ('dbm' )
11
-
12
9
try :
13
10
from dbm import ndbm
14
11
except ImportError :
15
12
ndbm = None
16
13
17
- _fname = os_helper .TESTFN
14
+ dirname = os_helper .TESTFN
15
+ _fname = os .path .join (dirname , os_helper .TESTFN )
18
16
19
17
#
20
- # Iterates over every database module supported by dbm currently available,
21
- # setting dbm to use each in turn, and yielding that module
18
+ # Iterates over every database module supported by dbm currently available.
22
19
#
23
20
def dbm_iterator ():
24
21
for name in dbm ._names :
@@ -32,11 +29,12 @@ def dbm_iterator():
32
29
#
33
30
# Clean up all scratch databases we might have created during testing
34
31
#
35
- def delete_files ():
36
- # we don't know the precise name the underlying database uses
37
- # so we use glob to locate all names
38
- for f in glob .glob (glob .escape (_fname ) + "*" ):
39
- os_helper .unlink (f )
32
+ def cleaunup_test_dir ():
33
+ os_helper .rmtree (dirname )
34
+
35
+ def setup_test_dir ():
36
+ cleaunup_test_dir ()
37
+ os .mkdir (dirname )
40
38
41
39
42
40
class AnyDBMTestCase :
@@ -144,86 +142,76 @@ def read_helper(self, f):
144
142
for key in self ._dict :
145
143
self .assertEqual (self ._dict [key ], f [key .encode ("ascii" )])
146
144
147
- def tearDown (self ):
148
- delete_files ()
145
+ def test_keys (self ):
146
+ with dbm .open (_fname , 'c' ) as d :
147
+ self .assertEqual (d .keys (), [])
148
+ a = [(b'a' , b'b' ), (b'12345678910' , b'019237410982340912840198242' )]
149
+ for k , v in a :
150
+ d [k ] = v
151
+ self .assertEqual (sorted (d .keys ()), sorted (k for (k , v ) in a ))
152
+ for k , v in a :
153
+ self .assertIn (k , d )
154
+ self .assertEqual (d [k ], v )
155
+ self .assertNotIn (b'xxx' , d )
156
+ self .assertRaises (KeyError , lambda : d [b'xxx' ])
149
157
150
158
def setUp (self ):
159
+ self .addCleanup (setattr , dbm , '_defaultmod' , dbm ._defaultmod )
151
160
dbm ._defaultmod = self .module
152
- delete_files ()
161
+ self .addCleanup (cleaunup_test_dir )
162
+ setup_test_dir ()
153
163
154
164
155
165
class WhichDBTestCase (unittest .TestCase ):
156
166
def test_whichdb (self ):
167
+ self .addCleanup (setattr , dbm , '_defaultmod' , dbm ._defaultmod )
157
168
_bytes_fname = os .fsencode (_fname )
158
- for path in [_fname , os_helper .FakePath (_fname ),
159
- _bytes_fname , os_helper .FakePath (_bytes_fname )]:
160
- for module in dbm_iterator ():
161
- # Check whether whichdb correctly guesses module name
162
- # for databases opened with "module" module.
163
- # Try with empty files first
164
- name = module .__name__
165
- if name == 'dbm.dumb' :
166
- continue # whichdb can't support dbm.dumb
167
- delete_files ()
168
- f = module .open (path , 'c' )
169
- f .close ()
169
+ fnames = [_fname , os_helper .FakePath (_fname ),
170
+ _bytes_fname , os_helper .FakePath (_bytes_fname )]
171
+ for module in dbm_iterator ():
172
+ # Check whether whichdb correctly guesses module name
173
+ # for databases opened with "module" module.
174
+ name = module .__name__
175
+ setup_test_dir ()
176
+ dbm ._defaultmod = module
177
+ # Try with empty files first
178
+ with module .open (_fname , 'c' ): pass
179
+ for path in fnames :
170
180
self .assertEqual (name , self .dbm .whichdb (path ))
171
- # Now add a key
172
- f = module .open (path , 'w' )
181
+ # Now add a key
182
+ with module .open (_fname , 'w' ) as f :
173
183
f [b"1" ] = b"1"
174
184
# and test that we can find it
175
185
self .assertIn (b"1" , f )
176
186
# and read it
177
187
self .assertEqual (f [b"1" ], b"1" )
178
- f . close ()
188
+ for path in fnames :
179
189
self .assertEqual (name , self .dbm .whichdb (path ))
180
190
181
191
@unittest .skipUnless (ndbm , reason = 'Test requires ndbm' )
182
192
def test_whichdb_ndbm (self ):
183
193
# Issue 17198: check that ndbm which is referenced in whichdb is defined
184
- db_file = '{}_ndbm.db' .format (_fname )
185
- with open (db_file , 'w' ):
186
- self .addCleanup (os_helper .unlink , db_file )
187
- db_file_bytes = os .fsencode (db_file )
188
- self .assertIsNone (self .dbm .whichdb (db_file [:- 3 ]))
189
- self .assertIsNone (self .dbm .whichdb (os_helper .FakePath (db_file [:- 3 ])))
190
- self .assertIsNone (self .dbm .whichdb (db_file_bytes [:- 3 ]))
191
- self .assertIsNone (self .dbm .whichdb (os_helper .FakePath (db_file_bytes [:- 3 ])))
192
-
193
- def tearDown (self ):
194
- delete_files ()
194
+ with open (_fname + '.db' , 'wb' ): pass
195
+ _bytes_fname = os .fsencode (_fname )
196
+ fnames = [_fname , os_helper .FakePath (_fname ),
197
+ _bytes_fname , os_helper .FakePath (_bytes_fname )]
198
+ for path in fnames :
199
+ self .assertIsNone (self .dbm .whichdb (path ))
195
200
196
201
def setUp (self ):
197
- delete_files ()
198
- self .filename = os_helper .TESTFN
199
- self .d = dbm .open (self .filename , 'c' )
200
- self .d .close ()
202
+ self .addCleanup (cleaunup_test_dir )
203
+ setup_test_dir ()
201
204
self .dbm = import_helper .import_fresh_module ('dbm' )
202
205
203
- def test_keys (self ):
204
- self .d = dbm .open (self .filename , 'c' )
205
- self .assertEqual (self .d .keys (), [])
206
- a = [(b'a' , b'b' ), (b'12345678910' , b'019237410982340912840198242' )]
207
- for k , v in a :
208
- self .d [k ] = v
209
- self .assertEqual (sorted (self .d .keys ()), sorted (k for (k , v ) in a ))
210
- for k , v in a :
211
- self .assertIn (k , self .d )
212
- self .assertEqual (self .d [k ], v )
213
- self .assertNotIn (b'xxx' , self .d )
214
- self .assertRaises (KeyError , lambda : self .d [b'xxx' ])
215
- self .d .close ()
216
-
217
-
218
- def load_tests (loader , tests , pattern ):
219
- classes = []
220
- for mod in dbm_iterator ():
221
- classes .append (type ("TestCase-" + mod .__name__ ,
222
- (AnyDBMTestCase , unittest .TestCase ),
223
- {'module' : mod }))
224
- for c in classes :
225
- tests .addTest (loader .loadTestsFromTestCase (c ))
226
- return tests
206
+
207
+ for mod in dbm_iterator ():
208
+ assert mod .__name__ .startswith ('dbm.' )
209
+ suffix = mod .__name__ [4 :]
210
+ testname = f'TestCase_{ suffix } '
211
+ globals ()[testname ] = type (testname ,
212
+ (AnyDBMTestCase , unittest .TestCase ),
213
+ {'module' : mod })
214
+
227
215
228
216
if __name__ == "__main__" :
229
217
unittest .main ()
0 commit comments