diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py index 417a53109c..4825155a3b 100644 --- a/Lib/sqlite3/test/regression.py +++ b/Lib/sqlite3/test/regression.py @@ -226,28 +226,6 @@ def __init__(self, name): with self.assertRaises(sqlite.ProgrammingError): cur = con.cursor() - def test_cursor_registration(self): - """ - Verifies that subclassed cursor classes are correctly registered with - the connection object, too. (fetch-across-rollback problem) - """ - class Connection(sqlite.Connection): - def cursor(self): - return Cursor(self) - - class Cursor(sqlite.Cursor): - def __init__(self, con): - sqlite.Cursor.__init__(self, con) - - con = Connection(":memory:") - cur = con.cursor() - cur.execute("create table foo(x)") - cur.executemany("insert into foo(x) values (?)", [(3,), (4,), (5,)]) - cur.execute("select x from foo") - con.rollback() - with self.assertRaises(sqlite.InterfaceError): - cur.fetchall() - def test_auto_commit(self): """ Verifies that creating a connection in autocommit mode works. @@ -375,6 +353,48 @@ def test_commit_cursor_reset(self): counter += 1 self.assertEqual(counter, 3, "should have returned exactly three rows") + def test_action_after_rollback_cursor_reset(self): + """ + Similar to `test_commit_cursor_reset`, `Connection::rollback()` resets + statements and cursors, which could cause statements to be reset again + when they shouldn't be. + + See issue33376 for more details. + """ + con = sqlite.connect(":memory:") + con.executescript(""" + create table t(c); + insert into t values(0); + insert into t values(1); + insert into t values(2); + """) + + self.assertEqual(con.isolation_level, "") + + curs = con.cursor() + curs.execute("BEGIN TRANSACTION") + curs.execute("select c from t") + con.rollback() + + # Reusing the same statement from the statement cache, which has been + # reset by the rollback above. + gen = con.execute("select c from t") + + # Would previously cause a spurious reset of the statement. + del curs + + counter = 0 + for i, row in enumerate(gen): + with self.subTest(i=i, row=row): + if counter == 0: + self.assertEqual(row[0], 0) + elif counter == 1: + self.assertEqual(row[0], 1) + elif counter == 2: + self.assertEqual(row[0], 2) + counter += 1 + self.assertEqual(counter, 3, "should have returned exactly three rows") + def test_bpo31770(self): """ The interpreter shouldn't crash in case Cursor.__init__() is called diff --git a/Lib/sqlite3/test/transactions.py b/Lib/sqlite3/test/transactions.py index 80284902a1..0229d62c8f 100644 --- a/Lib/sqlite3/test/transactions.py +++ b/Lib/sqlite3/test/transactions.py @@ -128,21 +128,6 @@ def test_locking(self): # NO self.con2.rollback() HERE!!! self.con1.commit() - def test_rollback_cursor_consistency(self): - """ - Checks if cursors on the connection are set into a "reset" state - when a rollback is done on the connection. - """ - con = sqlite.connect(":memory:") - cur = con.cursor() - cur.execute("create table test(x)") - cur.execute("insert into test(x) values (5)") - cur.execute("select 1 union select 2 union select 3") - - con.rollback() - with self.assertRaises(sqlite.InterfaceError): - cur.fetchall() - class SpecialCommandTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 28932726b7..95fe401fc0 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -30,9 +30,6 @@ #include "prepare_protocol.h" #include "util.h" -#define ACTION_FINALIZE 1 -#define ACTION_RESET 2 - #if SQLITE_VERSION_NUMBER >= 3014000 #define HAVE_TRACE_V2 #endif @@ -190,39 +187,18 @@ pysqlite_connection_init(pysqlite_Connection *self, PyObject *args, return 0; } -/* action in (ACTION_RESET, ACTION_FINALIZE) */ static void -pysqlite_do_all_statements(pysqlite_Connection *self, int action, - int reset_cursors) +finalize_all_statements(pysqlite_Connection *self) { - int i; - PyObject* weakref; - PyObject* statement; - pysqlite_Cursor* cursor; - - for (i = 0; i < PyList_Size(self->statements); i++) { - weakref = PyList_GetItem(self->statements, i); - statement = PyWeakref_GetObject(weakref); + for (int i = 0; i < PyList_Size(self->statements); i++) { + PyObject *weakref = PyList_GetItem(self->statements, i); + PyObject *statement = PyWeakref_GetObject(weakref); if (statement != Py_None) { Py_INCREF(statement); - if (action == ACTION_RESET) { - (void)pysqlite_statement_reset((pysqlite_Statement*)statement); - } else { - (void)pysqlite_statement_finalize((pysqlite_Statement*)statement); - } + (void)pysqlite_statement_finalize((pysqlite_Statement*)statement); Py_DECREF(statement); } } - - if (reset_cursors) { - for (i = 0; i < PyList_Size(self->cursors); i++) { - weakref = PyList_GetItem(self->cursors, i); - cursor = (pysqlite_Cursor*)PyWeakref_GetObject(weakref); - if ((PyObject*)cursor != Py_None) { - cursor->reset = 1; - } - } - } } static void @@ -336,7 +312,7 @@ pysqlite_connection_close_impl(pysqlite_Connection *self) return NULL; } - pysqlite_do_all_statements(self, ACTION_FINALIZE, 1); + finalize_all_statements(self); if (self->db) { rc = sqlite3_close_v2(self->db); @@ -475,8 +451,6 @@ pysqlite_connection_rollback_impl(pysqlite_Connection *self) } if (!sqlite3_get_autocommit(self->db)) { - pysqlite_do_all_statements(self, ACTION_RESET, 1); - Py_BEGIN_ALLOW_THREADS rc = sqlite3_prepare_v2(self->db, "ROLLBACK", -1, &statement, NULL); Py_END_ALLOW_THREADS diff --git a/Modules/_sqlite/cursor.c b/Modules/_sqlite/cursor.c index b71f780a0b..236ffcb6b4 100644 --- a/Modules/_sqlite/cursor.c +++ b/Modules/_sqlite/cursor.c @@ -32,8 +32,6 @@ class _sqlite3.Cursor "pysqlite_Cursor *" "pysqlite_CursorType" [clinic start generated code]*/ /*[clinic end generated code: output=da39a3ee5e6b4b0d input=b2072d8db95411d5]*/ -static const char errmsg_fetch_across_rollback[] = "Cursor needed to be reset because of commit/rollback and can no longer be fetched from."; - /*[clinic input] _sqlite3.Cursor.__init__ as pysqlite_cursor_init @@ -61,7 +59,6 @@ pysqlite_cursor_init_impl(pysqlite_Cursor *self, self->arraysize = 1; self->closed = 0; - self->reset = 0; self->rowcount = -1L; @@ -248,11 +245,6 @@ _pysqlite_fetch_one_row(pysqlite_Cursor* self) const char* colname; PyObject* error_msg; - if (self->reset) { - PyErr_SetString(pysqlite_InterfaceError, errmsg_fetch_across_rollback); - return NULL; - } - Py_BEGIN_ALLOW_THREADS numcols = sqlite3_data_count(self->statement->st); Py_END_ALLOW_THREADS @@ -414,7 +406,6 @@ _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* operation } self->locked = 1; - self->reset = 0; Py_CLEAR(self->next_row); @@ -687,8 +678,6 @@ pysqlite_cursor_executescript(pysqlite_Cursor *self, PyObject *script_obj) return NULL; } - self->reset = 0; - if (PyUnicode_Check(script_obj)) { script_cstr = PyUnicode_AsUTF8(script_obj); if (!script_cstr) { @@ -764,11 +753,6 @@ pysqlite_cursor_iternext(pysqlite_Cursor *self) return NULL; } - if (self->reset) { - PyErr_SetString(pysqlite_InterfaceError, errmsg_fetch_across_rollback); - return NULL; - } - if (!self->next_row) { if (self->statement) { (void)pysqlite_statement_reset(self->statement); diff --git a/Modules/_sqlite/cursor.h b/Modules/_sqlite/cursor.h index b26b288674..c75387a893 100644 --- a/Modules/_sqlite/cursor.h +++ b/Modules/_sqlite/cursor.h @@ -42,7 +42,6 @@ typedef struct PyObject* row_factory; pysqlite_Statement* statement; int closed; - int reset; int locked; int initialized;