Skip to content

Commit e113883

Browse files
committed
Implement equivalent patch to upstream python/cpython#14268
1 parent 4754ff5 commit e113883

File tree

3 files changed

+63
-101
lines changed

3 files changed

+63
-101
lines changed

src/connection.c

Lines changed: 45 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,9 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject
187187
}
188188
self->check_same_thread = check_same_thread;
189189

190-
Py_XSETREF(self->function_pinboard, PyDict_New());
191-
if (!self->function_pinboard) {
192-
return -1;
193-
}
190+
self->function_pinboard_trace_callback = NULL;
191+
self->function_pinboard_progress_handler = NULL;
192+
self->function_pinboard_authorizer_cb = NULL;
194193

195194
Py_XSETREF(self->collations, PyDict_New());
196195
if (!self->collations) {
@@ -250,13 +249,13 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self)
250249

251250
/* Clean up if user has not called .close() explicitly. */
252251
if (self->db) {
253-
Py_BEGIN_ALLOW_THREADS
254252
SQLITE3_CLOSE(self->db);
255-
Py_END_ALLOW_THREADS
256253
}
257254

258255
Py_XDECREF(self->isolation_level);
259-
Py_XDECREF(self->function_pinboard);
256+
Py_XDECREF(self->function_pinboard_trace_callback);
257+
Py_XDECREF(self->function_pinboard_progress_handler);
258+
Py_XDECREF(self->function_pinboard_authorizer_cb);
260259
Py_XDECREF(self->row_factory);
261260
Py_XDECREF(self->text_factory);
262261
Py_XDECREF(self->collations);
@@ -343,9 +342,7 @@ PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args)
343342
pysqlite_do_all_statements(self, ACTION_FINALIZE, 1);
344343

345344
if (self->db) {
346-
Py_BEGIN_ALLOW_THREADS
347345
rc = SQLITE3_CLOSE(self->db);
348-
Py_END_ALLOW_THREADS
349346

350347
if (rc != SQLITE_OK) {
351348
_pysqlite_seterror(self->db, NULL);
@@ -906,6 +903,12 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
906903
Py_SETREF(self->cursors, new_list);
907904
}
908905

906+
static void _destructor(void* args)
907+
{
908+
Py_DECREF((PyObject *)args);
909+
}
910+
911+
909912
PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
910913
{
911914
static char *kwlist[] = {"name", "narg", "func", "deterministic", NULL};
@@ -941,17 +944,16 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec
941944
flags |= SQLITE_DETERMINISTIC;
942945
#endif
943946
}
944-
if (PyDict_SetItem(self->function_pinboard, func, Py_None) == -1) {
945-
return NULL;
946-
}
947-
rc = sqlite3_create_function(self->db,
948-
name,
949-
narg,
950-
flags,
951-
(void*)func,
952-
_pysqlite_func_callback,
953-
NULL,
954-
NULL);
947+
Py_INCREF(func);
948+
rc = sqlite3_create_function_v2(self->db,
949+
name,
950+
narg,
951+
flags,
952+
(void*)func,
953+
_pysqlite_func_callback,
954+
NULL,
955+
NULL,
956+
&_destructor);
955957

956958
if (rc != SQLITE_OK) {
957959
/* Workaround for SQLite bug: no error code or string is available here */
@@ -979,10 +981,17 @@ PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObje
979981
return NULL;
980982
}
981983

982-
if (PyDict_SetItem(self->function_pinboard, aggregate_class, Py_None) == -1) {
983-
return NULL;
984-
}
985-
rc = sqlite3_create_function(self->db, name, n_arg, SQLITE_UTF8, (void*)aggregate_class, 0, &_pysqlite_step_callback, &_pysqlite_final_callback);
984+
Py_INCREF(aggregate_class);
985+
rc = sqlite3_create_function_v2(self->db,
986+
name,
987+
n_arg,
988+
SQLITE_UTF8,
989+
(void*)aggregate_class,
990+
0,
991+
&_pysqlite_step_callback,
992+
&_pysqlite_final_callback,
993+
&_destructor);
994+
986995
if (rc != SQLITE_OK) {
987996
/* Workaround for SQLite bug: no error code or string is available here */
988997
PyErr_SetString(pysqlite_OperationalError, "Error creating aggregate");
@@ -1011,10 +1020,7 @@ PyObject* pysqlite_connection_create_window_function(pysqlite_Connection* self,
10111020
return NULL;
10121021
}
10131022

1014-
if (PyDict_SetItem(self->function_pinboard, window_function_class, Py_None) == -1) {
1015-
return NULL;
1016-
}
1017-
1023+
Py_INCREF(window_function_class);
10181024
rc = sqlite3_create_window_function(
10191025
self->db,
10201026
name,
@@ -1025,7 +1031,7 @@ PyObject* pysqlite_connection_create_window_function(pysqlite_Connection* self,
10251031
&_pysqlite_final_callback,
10261032
&_pysqlite_value_callback,
10271033
&_pysqlite_inverse_callback,
1028-
NULL);
1034+
&_destructor);
10291035

10301036
if (rc != SQLITE_OK) {
10311037
/* Workaround for SQLite bug: no error code or string is available here */
@@ -1147,13 +1153,14 @@ static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, P
11471153
return NULL;
11481154
}
11491155

1150-
if (PyDict_SetItem(self->function_pinboard, authorizer_cb, Py_None) == -1) {
1151-
return NULL;
1152-
}
11531156
rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb);
11541157
if (rc != SQLITE_OK) {
11551158
PyErr_SetString(pysqlite_OperationalError, "Error setting authorizer callback");
1159+
Py_XSETREF(self->function_pinboard_authorizer_cb, NULL);
11561160
return NULL;
1161+
} else {
1162+
Py_INCREF(authorizer_cb);
1163+
Py_XSETREF(self->function_pinboard_authorizer_cb, authorizer_cb);
11571164
}
11581165
Py_RETURN_NONE;
11591166
}
@@ -1177,10 +1184,11 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s
11771184
if (progress_handler == Py_None) {
11781185
/* None clears the progress handler previously set */
11791186
sqlite3_progress_handler(self->db, 0, 0, (void*)0);
1187+
Py_XSETREF(self->function_pinboard_progress_handler, NULL);
11801188
} else {
1181-
if (PyDict_SetItem(self->function_pinboard, progress_handler, Py_None) == -1)
1182-
return NULL;
11831189
sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler);
1190+
Py_INCREF(progress_handler);
1191+
Py_XSETREF(self->function_pinboard_progress_handler, progress_handler);
11841192
}
11851193

11861194
Py_RETURN_NONE;
@@ -1204,10 +1212,11 @@ static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* sel
12041212
if (trace_callback == Py_None) {
12051213
/* None clears the trace callback previously set */
12061214
sqlite3_trace(self->db, 0, (void*)0);
1215+
Py_XSETREF(self->function_pinboard_trace_callback, NULL);
12071216
} else {
1208-
if (PyDict_SetItem(self->function_pinboard, trace_callback, Py_None) == -1)
1209-
return NULL;
12101217
sqlite3_trace(self->db, _trace_callback, trace_callback);
1218+
Py_INCREF(trace_callback);
1219+
Py_XSETREF(self->function_pinboard_trace_callback, trace_callback);
12111220
}
12121221

12131222
Py_RETURN_NONE;

src/connection.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,10 @@ typedef struct
8585
*/
8686
PyObject* text_factory;
8787

88-
/* remember references to functions/classes used in
89-
* create_function/create/aggregate, use these as dictionary keys, so we
90-
* can keep the total system refcount constant by clearing that dictionary
91-
* in connection_dealloc */
92-
PyObject* function_pinboard;
88+
/* remember references to functions/classes used in trace/progress/auth cb */
89+
PyObject* function_pinboard_trace_callback;
90+
PyObject* function_pinboard_progress_handler;
91+
PyObject* function_pinboard_authorizer_cb;
9392

9493
/* a dictionary of registered collation name => collation callable mappings */
9594
PyObject* collations;

test/regression.py

Lines changed: 14 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
# 3. This notice may not be removed or altered from any source distribution.
2323

2424
import datetime
25+
import functools
2526
import unittest
26-
from sqlcipher3 import dbapi2 as sqlite
27+
from pysqlite3 import dbapi2 as sqlite
2728
import weakref
2829
#from test import support
2930

@@ -383,72 +384,25 @@ def CheckDelIsolation_levelSegfault(self):
383384
with self.assertRaises(AttributeError):
384385
del self.con.isolation_level
385386

387+
def CheckBpo37347(self):
388+
class Printer:
389+
def log(self, *args):
390+
return sqlite.SQLITE_OK
386391

387-
class UnhashableFunc:
388-
__hash__ = None
389-
390-
def __init__(self, return_value=None):
391-
self.calls = 0
392-
self.return_value = return_value
393-
394-
def __call__(self, *args, **kwargs):
395-
self.calls += 1
396-
return self.return_value
397-
398-
399-
class UnhashableCallbacksTestCase(unittest.TestCase):
400-
"""
401-
https://bugs.python.org/issue34052
402-
403-
Registering unhashable callbacks raises TypeError, callbacks are not
404-
registered in SQLite after such registration attempt.
405-
"""
406-
def setUp(self):
407-
self.con = sqlite.connect(':memory:')
408-
409-
def tearDown(self):
410-
self.con.close()
411-
412-
def test_progress_handler(self):
413-
f = UnhashableFunc(return_value=0)
414-
with self.assertRaisesRegex(TypeError, 'unhashable type'):
415-
self.con.set_progress_handler(f, 1)
416-
self.con.execute('SELECT 1')
417-
self.assertFalse(f.calls)
418-
419-
def test_func(self):
420-
func_name = 'func_name'
421-
f = UnhashableFunc()
422-
with self.assertRaisesRegex(TypeError, 'unhashable type'):
423-
self.con.create_function(func_name, 0, f)
424-
msg = 'no such function: %s' % func_name
425-
with self.assertRaisesRegex(sqlite.OperationalError, msg):
426-
self.con.execute('SELECT %s()' % func_name)
427-
self.assertFalse(f.calls)
428-
429-
def test_authorizer(self):
430-
f = UnhashableFunc(return_value=sqlite.SQLITE_DENY)
431-
with self.assertRaisesRegex(TypeError, 'unhashable type'):
432-
self.con.set_authorizer(f)
433-
self.con.execute('SELECT 1')
434-
self.assertFalse(f.calls)
435-
436-
def test_aggr(self):
437-
class UnhashableType(type):
438-
__hash__ = None
439-
aggr_name = 'aggr_name'
440-
with self.assertRaisesRegex(TypeError, 'unhashable type'):
441-
self.con.create_aggregate(aggr_name, 0, UnhashableType('Aggr', (), {}))
442-
msg = 'no such function: %s' % aggr_name
443-
with self.assertRaisesRegex(sqlite.OperationalError, msg):
444-
self.con.execute('SELECT %s()' % aggr_name)
392+
for method in [self.con.set_trace_callback,
393+
functools.partial(self.con.set_progress_handler, n=1),
394+
self.con.set_authorizer]:
395+
printer_instance = Printer()
396+
method(printer_instance.log)
397+
method(printer_instance.log) # Register twice, incref twice.
398+
self.con.execute('select 1') # Triggers segfault.
399+
method(None)
445400

446401

447402
def suite():
448403
regression_suite = unittest.makeSuite(RegressionTests, "Check")
449404
return unittest.TestSuite((
450405
regression_suite,
451-
unittest.makeSuite(UnhashableCallbacksTestCase),
452406
))
453407

454408
def test():

0 commit comments

Comments
 (0)