Skip to content

Commit db021e9

Browse files
committed
bpo37347: fix refcount problem in sqlite3
1 parent 0c48618 commit db021e9

File tree

4 files changed

+64
-98
lines changed

4 files changed

+64
-98
lines changed

Lib/sqlite3/test/regression.py

Lines changed: 18 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import unittest
2626
import sqlite3 as sqlite
2727
import weakref
28+
import sys
29+
import functools
2830
from test import support
2931

3032
class RegressionTests(unittest.TestCase):
@@ -383,72 +385,30 @@ def CheckDelIsolation_levelSegfault(self):
383385
with self.assertRaises(AttributeError):
384386
del self.con.isolation_level
385387

388+
def CheckBpo37347(self):
389+
class Printer:
390+
def log(self, *args):
391+
return sqlite.SQLITE_OK
386392

387-
class UnhashableFunc:
388-
__hash__ = None
393+
for method in [self.con.set_trace_callback,
394+
functools.partial(self.con.set_progress_handler, n=1),
395+
self.con.set_authorizer]:
396+
printer_instance = Printer()
397+
method(printer_instance.log)
398+
log = printer_instance.log
399+
refcnt_before = sys.getrefcount(log)
400+
method(log)
401+
refcnt_after = sys.getrefcount(log)
402+
self.assertEqual(refcnt_after - refcnt_before, 1, "%r must increase reference-count of arg" % method)
403+
self.con.execute("select 1") # trigger seg fault
404+
method(None)
389405

390-
def __init__(self, return_value=None):
391-
self.calls = 0
392-
self.return_value = return_value
393-
394-
def __call__(self, *args, **kwargs):
395-
self.calls += 1
396-
return self.return_value
397-
398-
399-
class UnhashableCallbacksTestCase(unittest.TestCase):
400-
"""
401-
https://bugs.python.org/issue34052
402-
403-
Registering unhashable callbacks raises TypeError, callbacks are not
404-
registered in SQLite after such registration attempt.
405-
"""
406-
def setUp(self):
407-
self.con = sqlite.connect(':memory:')
408-
409-
def tearDown(self):
410-
self.con.close()
411-
412-
def test_progress_handler(self):
413-
f = UnhashableFunc(return_value=0)
414-
with self.assertRaisesRegex(TypeError, 'unhashable type'):
415-
self.con.set_progress_handler(f, 1)
416-
self.con.execute('SELECT 1')
417-
self.assertFalse(f.calls)
418-
419-
def test_func(self):
420-
func_name = 'func_name'
421-
f = UnhashableFunc()
422-
with self.assertRaisesRegex(TypeError, 'unhashable type'):
423-
self.con.create_function(func_name, 0, f)
424-
msg = 'no such function: %s' % func_name
425-
with self.assertRaisesRegex(sqlite.OperationalError, msg):
426-
self.con.execute('SELECT %s()' % func_name)
427-
self.assertFalse(f.calls)
428-
429-
def test_authorizer(self):
430-
f = UnhashableFunc(return_value=sqlite.SQLITE_DENY)
431-
with self.assertRaisesRegex(TypeError, 'unhashable type'):
432-
self.con.set_authorizer(f)
433-
self.con.execute('SELECT 1')
434-
self.assertFalse(f.calls)
435-
436-
def test_aggr(self):
437-
class UnhashableType(type):
438-
__hash__ = None
439-
aggr_name = 'aggr_name'
440-
with self.assertRaisesRegex(TypeError, 'unhashable type'):
441-
self.con.create_aggregate(aggr_name, 0, UnhashableType('Aggr', (), {}))
442-
msg = 'no such function: %s' % aggr_name
443-
with self.assertRaisesRegex(sqlite.OperationalError, msg):
444-
self.con.execute('SELECT %s()' % aggr_name)
445406

446407

447408
def suite():
448409
regression_suite = unittest.makeSuite(RegressionTests, "Check")
449410
return unittest.TestSuite((
450411
regression_suite,
451-
unittest.makeSuite(UnhashableCallbacksTestCase),
452412
))
453413

454414
def test():

Modules/_sqlite/connection.c

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,9 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject
186186
}
187187
self->check_same_thread = check_same_thread;
188188

189-
Py_XSETREF(self->function_pinboard, PyDict_New());
190-
if (!self->function_pinboard) {
191-
return -1;
192-
}
189+
self->function_pinboard_trace_callback = NULL;
190+
self->function_pinboard_progress_handler = NULL;
191+
self->function_pinboard_authorizer_cb = NULL;
193192

194193
Py_XSETREF(self->collations, PyDict_New());
195194
if (!self->collations) {
@@ -249,19 +248,18 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self)
249248

250249
/* Clean up if user has not called .close() explicitly. */
251250
if (self->db) {
252-
Py_BEGIN_ALLOW_THREADS
253251
SQLITE3_CLOSE(self->db);
254-
Py_END_ALLOW_THREADS
255252
}
256253

257254
Py_XDECREF(self->isolation_level);
258-
Py_XDECREF(self->function_pinboard);
255+
Py_XDECREF(self->function_pinboard_trace_callback);
256+
Py_XDECREF(self->function_pinboard_progress_handler);
257+
Py_XDECREF(self->function_pinboard_authorizer_cb);
259258
Py_XDECREF(self->row_factory);
260259
Py_XDECREF(self->text_factory);
261260
Py_XDECREF(self->collations);
262261
Py_XDECREF(self->statements);
263262
Py_XDECREF(self->cursors);
264-
265263
Py_TYPE(self)->tp_free((PyObject*)self);
266264
}
267265

@@ -342,9 +340,7 @@ PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args)
342340
pysqlite_do_all_statements(self, ACTION_FINALIZE, 1);
343341

344342
if (self->db) {
345-
Py_BEGIN_ALLOW_THREADS
346343
rc = SQLITE3_CLOSE(self->db);
347-
Py_END_ALLOW_THREADS
348344

349345
if (rc != SQLITE_OK) {
350346
_pysqlite_seterror(self->db, NULL);
@@ -808,6 +804,11 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
808804
Py_SETREF(self->cursors, new_list);
809805
}
810806

807+
static void _destructor(void* args)
808+
{
809+
Py_DECREF((PyObject*)args);
810+
}
811+
811812
PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
812813
{
813814
static char *kwlist[] = {"name", "narg", "func", "deterministic", NULL};
@@ -843,17 +844,16 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec
843844
flags |= SQLITE_DETERMINISTIC;
844845
#endif
845846
}
846-
if (PyDict_SetItem(self->function_pinboard, func, Py_None) == -1) {
847-
return NULL;
848-
}
849-
rc = sqlite3_create_function(self->db,
850-
name,
851-
narg,
852-
flags,
853-
(void*)func,
854-
_pysqlite_func_callback,
855-
NULL,
856-
NULL);
847+
Py_INCREF(func);
848+
rc = sqlite3_create_function_v2(self->db,
849+
name,
850+
narg,
851+
flags,
852+
(void*)func,
853+
_pysqlite_func_callback,
854+
NULL,
855+
NULL,
856+
&_destructor); // will decref func
857857

858858
if (rc != SQLITE_OK) {
859859
/* Workaround for SQLite bug: no error code or string is available here */
@@ -880,11 +880,16 @@ PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObje
880880
kwlist, &name, &n_arg, &aggregate_class)) {
881881
return NULL;
882882
}
883-
884-
if (PyDict_SetItem(self->function_pinboard, aggregate_class, Py_None) == -1) {
885-
return NULL;
886-
}
887-
rc = sqlite3_create_function(self->db, name, n_arg, SQLITE_UTF8, (void*)aggregate_class, 0, &_pysqlite_step_callback, &_pysqlite_final_callback);
883+
Py_INCREF(aggregate_class);
884+
rc = sqlite3_create_function_v2(self->db,
885+
name,
886+
n_arg,
887+
SQLITE_UTF8,
888+
(void*)aggregate_class,
889+
0,
890+
&_pysqlite_step_callback,
891+
&_pysqlite_final_callback,
892+
&_destructor); // will decref func
888893
if (rc != SQLITE_OK) {
889894
/* Workaround for SQLite bug: no error code or string is available here */
890895
PyErr_SetString(pysqlite_OperationalError, "Error creating aggregate");
@@ -1003,13 +1008,14 @@ static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, P
10031008
return NULL;
10041009
}
10051010

1006-
if (PyDict_SetItem(self->function_pinboard, authorizer_cb, Py_None) == -1) {
1007-
return NULL;
1008-
}
10091011
rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb);
10101012
if (rc != SQLITE_OK) {
10111013
PyErr_SetString(pysqlite_OperationalError, "Error setting authorizer callback");
1014+
Py_XSETREF(self->function_pinboard_authorizer_cb, NULL);
10121015
return NULL;
1016+
} else {
1017+
Py_INCREF(authorizer_cb);
1018+
Py_XSETREF(self->function_pinboard_authorizer_cb, authorizer_cb);
10131019
}
10141020
Py_RETURN_NONE;
10151021
}
@@ -1033,12 +1039,12 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s
10331039
if (progress_handler == Py_None) {
10341040
/* None clears the progress handler previously set */
10351041
sqlite3_progress_handler(self->db, 0, 0, (void*)0);
1042+
Py_XSETREF(self->function_pinboard_progress_handler, NULL);
10361043
} else {
1037-
if (PyDict_SetItem(self->function_pinboard, progress_handler, Py_None) == -1)
1038-
return NULL;
10391044
sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler);
1045+
Py_INCREF(progress_handler);
1046+
Py_XSETREF(self->function_pinboard_progress_handler, progress_handler);
10401047
}
1041-
10421048
Py_RETURN_NONE;
10431049
}
10441050

@@ -1060,10 +1066,11 @@ static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* sel
10601066
if (trace_callback == Py_None) {
10611067
/* None clears the trace callback previously set */
10621068
sqlite3_trace(self->db, 0, (void*)0);
1069+
Py_XSETREF(self->function_pinboard_trace_callback, NULL);
10631070
} else {
1064-
if (PyDict_SetItem(self->function_pinboard, trace_callback, Py_None) == -1)
1065-
return NULL;
10661071
sqlite3_trace(self->db, _trace_callback, trace_callback);
1072+
Py_INCREF(trace_callback);
1073+
Py_XSETREF(self->function_pinboard_trace_callback, trace_callback);
10671074
}
10681075

10691076
Py_RETURN_NONE;

Modules/_sqlite/connection.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,10 @@ typedef struct
8585
*/
8686
PyObject* text_factory;
8787

88-
/* remember references to functions/classes used in
89-
* create_function/create/aggregate, use these as dictionary keys, so we
90-
* can keep the total system refcount constant by clearing that dictionary
91-
* in connection_dealloc */
92-
PyObject* function_pinboard;
88+
/* remember references to object used in trace_callback/progress_handler/authorizer_cb */
89+
PyObject* function_pinboard_trace_callback;
90+
PyObject* function_pinboard_progress_handler;
91+
PyObject* function_pinboard_authorizer_cb;
9392

9493
/* a dictionary of registered collation name => collation callable mappings */
9594
PyObject* collations;

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1348,7 +1348,7 @@ def detect_sqlite(self):
13481348
]
13491349
if CROSS_COMPILING:
13501350
sqlite_inc_paths = []
1351-
MIN_SQLITE_VERSION_NUMBER = (3, 3, 9)
1351+
MIN_SQLITE_VERSION_NUMBER = (3, 7, 2)
13521352
MIN_SQLITE_VERSION = ".".join([str(x)
13531353
for x in MIN_SQLITE_VERSION_NUMBER])
13541354

0 commit comments

Comments
 (0)