diff --git a/Lib/sqlite3/test/factory.py b/Lib/sqlite3/test/factory.py index 8764284975..a80f181c3a 100644 --- a/Lib/sqlite3/test/factory.py +++ b/Lib/sqlite3/test/factory.py @@ -49,6 +49,12 @@ def tearDown(self): def test_is_instance(self): self.assertIsInstance(self.con, MyConnection) + def test_invalid_connection_factory(self): + class DefectFactory(sqlite.Connection): + def __init__(self, *args, **kwargs): + return None + self.con = sqlite.connect(":memory:", factory=DefectFactory) + class CursorFactoryTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 9bf2a35ab0..f15e25b6f6 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -184,17 +184,14 @@ pysqlite_connection_init(pysqlite_Connection *self, PyObject *args, /* action in (ACTION_RESET, ACTION_FINALIZE) */ static void -pysqlite_do_all_statements(pysqlite_Connection *self, int action, - int reset_cursors) +pysqlite_do_all_statements(pysqlite_Connection *self, int action) { - int i; - PyObject* weakref; - PyObject* statement; - pysqlite_Cursor* cursor; - - for (i = 0; i < PyList_Size(self->statements); i++) { - weakref = PyList_GetItem(self->statements, i); - statement = PyWeakref_GetObject(weakref); + if (self->statements == NULL) { + return; + } + for (int i = 0; i < PyList_Size(self->statements); i++) { + PyObject *weakref = PyList_GetItem(self->statements, i); + PyObject *statement = PyWeakref_GetObject(weakref); if (statement != Py_None) { Py_INCREF(statement); if (action == ACTION_RESET) { @@ -205,14 +202,19 @@ pysqlite_do_all_statements(pysqlite_Connection *self, int action, Py_DECREF(statement); } } +} - if (reset_cursors) { - for (i = 0; i < PyList_Size(self->cursors); i++) { - weakref = PyList_GetItem(self->cursors, i); - cursor = (pysqlite_Cursor*)PyWeakref_GetObject(weakref); - if ((PyObject*)cursor != Py_None) { - cursor->reset = 1; - } +static void +reset_cursors(pysqlite_Connection *self) +{ + if (self->cursors == NULL) { + return; + } + for (int i = 0; i < PyList_Size(self->cursors); i++) { + PyObject *weakref = PyList_GetItem(self->cursors, i); + PyObject *cursor = PyWeakref_GetObject(weakref); + if (cursor != Py_None) { + ((pysqlite_Cursor *)cursor)->reset = 1; } } } @@ -328,7 +330,8 @@ pysqlite_connection_close_impl(pysqlite_Connection *self) return NULL; } - pysqlite_do_all_statements(self, ACTION_FINALIZE, 1); + pysqlite_do_all_statements(self, ACTION_FINALIZE); + reset_cursors(self); if (self->db) { rc = sqlite3_close_v2(self->db); @@ -467,7 +470,8 @@ pysqlite_connection_rollback_impl(pysqlite_Connection *self) } if (!sqlite3_get_autocommit(self->db)) { - pysqlite_do_all_statements(self, ACTION_RESET, 1); + pysqlite_do_all_statements(self, ACTION_RESET); + reset_cursors(self); Py_BEGIN_ALLOW_THREADS rc = sqlite3_prepare_v2(self->db, "ROLLBACK", -1, &statement, NULL);