diff -r 9160557984c7 Lib/sqlite3/test/dbapi.py --- a/Lib/sqlite3/test/dbapi.py Sat Mar 26 23:37:03 2011 +0100 +++ b/Lib/sqlite3/test/dbapi.py Sun Mar 27 03:26:38 2011 +0200 @@ -361,11 +361,17 @@ try: self.cu.executemany("select ?", [(3,)]) self.fail("should have raised a ProgrammingError") - except sqlite.ProgrammingError: + except sqlite.ProgrammingError, e: + if str(e).lower().find("select") == -1: + self.fail("error message should refer to select statement") return except: self.fail("raised wrong exception.") + def CheckExecuteManySelectNoIsolation(self): + self.cx.isolation_level = None + self.CheckExecuteManySelect() + def CheckExecuteManyNotIterable(self): try: self.cu.executemany("insert into test(income) values (?)", 42) diff -r 9160557984c7 Lib/sqlite3/test/transactions.py --- a/Lib/sqlite3/test/transactions.py Sat Mar 26 23:37:03 2011 +0100 +++ b/Lib/sqlite3/test/transactions.py Sun Mar 27 03:26:38 2011 +0200 @@ -53,6 +53,29 @@ except OSError: pass + def CheckHasActiveTransaction(self): + """Test that in_transaction returns the actual transaction state.""" + self.assertFalse(self.con1.in_transaction) + self.cur1.execute("create table test(i)") + self.cur1.execute("insert into test(i) values (5)") + self.assertTrue(self.con1.in_transaction) + self.con1.commit() + self.assertFalse(self.con1.in_transaction) + + # Manage the transaction state manually and check if it is detected correctly. + self.con2.isolation_level = None + self.assertFalse(self.con2.in_transaction) + self.cur2.execute("begin") + self.assertTrue(self.con2.in_transaction) + self.con2.commit() + self.assertFalse(self.con2.in_transaction) + + self.cur2.execute("begin") + self.assertTrue(self.con2.in_transaction) + self.cur2.execute("commit") + self.assertFalse(self.con2.in_transaction) + + def CheckDMLdoesAutoCommitBefore(self): self.cur1.execute("create table test(i)") self.cur1.execute("insert into test(i) values (5)") @@ -168,6 +191,107 @@ except: self.fail("InterfaceError should have been raised") + def CheckDropTableRollback(self): + """ + Checks that drop table can be run inside a transaction and will + roll back correctly. + """ + self.con1.operation_needs_transaction_callback = lambda x: True + self.cur1.execute("create table test(x)") + self.cur1.execute("insert into test(x) values (5)") + self.con1.commit() + self.cur1.execute("drop table test") + self.con1.rollback() + # Table should still exist. + self.cur1.execute("select * from test") + + def CheckCreateTableRollback(self): + """Checks that create table runs inside a transaction and can be rolled back.""" + self.con1.operation_needs_transaction_callback = lambda x: True + self.cur1.execute("create table test(x)") + self.con1.rollback() + # Table test was rolled back so this should work + self.cur1.execute("create table test(x)") + + def CheckSavepoints(self): + """Trivial savepoint check.""" + self.con1.operation_needs_transaction_callback = lambda x: True + self.cur1.execute("create table test(x)") + self.con1.commit() + self.cur1.execute("insert into test(x) values (1)") + self.cur1.execute("savepoint foobar") + self.cur1.execute("insert into test(x) values (2)") + self.cur1.execute("rollback to savepoint foobar") + self.con1.commit() + self.cur2.execute("select x from test") + res = self.cur2.fetchall() + self.assertEqual(len(res), 1) + self.assertEqual(res[0][0], 1) + + def CheckCreateIndexRollback(self): + """Check that create index is transactional.""" + self.con1.operation_needs_transaction_callback = lambda x: True + self.cur1.execute("create table test(x integer)") + self.cur1.execute("insert into test(x) values (1)") + self.con1.commit() + self.cur1.execute("create index myidx on test(x)") + self.assertTrue(self.cur1.execute("pragma index_info(myidx)").fetchone()) + self.cur1.execute("insert into test(x) values (2)") + self.con1.rollback() + self.assertFalse(self.cur1.execute("pragma index_info(myidx)").fetchone()) + + def CheckColumnAddRollback(self): + """Check that adding a column is transactional.""" + self.con1.operation_needs_transaction_callback = lambda x: True + self.cur1.execute("create table test(x integer)") + self.cur1.execute("insert into test(x) values (42)") + self.con1.commit() + self.cur1.execute("alter table test add column y integer default 37") + self.assertEqual(len(self.cur1.execute("select * from test").fetchone()), 2) + self.con1.rollback() + self.assertEqual(len(self.cur1.execute("select * from test").fetchone()), 1) + try: + self.cur1.execute("insert into test(x,y) values (1,2)") + self.fail("Column y should have been rolled back.") + except sqlite.OperationalError: + pass + + def CheckTableRenameRollback(self): + """Check that renaming a table is transactional.""" + self.con1.operation_needs_transaction_callback = lambda x: True + self.cur1.execute("create table foo(x integer)") + self.con1.commit() + self.cur1.execute("alter table foo rename to bar") + self.cur1.execute("select * from bar") + try: + self.cur1.execute("select * from foo") + self.fail("Table foo should have been renamed to bar") + except sqlite.OperationalError: + pass + self.con1.rollback() + self.cur1.execute("select * from foo") + try: + self.cur1.execute("select * from bar") + self.fail("Renaming the table should have been rolled back.") + except sqlite.OperationalError: + pass + + def CheckDropIndexRollback(self): + """Check that dropping an index is transactional.""" + self.con1.operation_needs_transaction_callback = lambda x: True + self.cur1.execute("create table foo(x integer)") + self.cur1.execute("create index myidx on foo(x)") + self.con1.commit() + self.cur1.execute("drop index myidx") + self.con1.rollback() + try: + self.cur1.execute("create index myidx on foo(x)") + self.fail("Index myidx should exist here (dropping it was rolled back).") + except sqlite.OperationalError, e: + # OperationalError: index myidx already exists + pass + + class SpecialCommandTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") diff -r 9160557984c7 Modules/_sqlite/connection.c --- a/Modules/_sqlite/connection.c Sat Mar 26 23:37:03 2011 +0100 +++ b/Modules/_sqlite/connection.c Sun Mar 27 03:26:38 2011 +0200 @@ -91,6 +91,9 @@ Py_INCREF(Py_None); self->row_factory = Py_None; + Py_INCREF(Py_None); + self->operation_needs_transaction_callback = Py_None; + Py_INCREF(&PyUnicode_Type); self->text_factory = (PyObject*)&PyUnicode_Type; @@ -1177,6 +1180,15 @@ } } +static PyObject* pysqlite_connection_in_transaction(pysqlite_Connection* self, void* unused) +{ + if (!pysqlite_check_connection(self)) { + return NULL; + } else { + return PyBool_FromLong(!sqlite3_get_autocommit(self->db)); + } +} + static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level) { PyObject* res; @@ -1613,6 +1625,8 @@ static PyGetSetDef connection_getset[] = { {"isolation_level", (getter)pysqlite_connection_get_isolation_level, (setter)pysqlite_connection_set_isolation_level}, {"total_changes", (getter)pysqlite_connection_get_total_changes, (setter)0}, + {"in_transaction", (getter)pysqlite_connection_in_transaction, (setter)0, + PyDoc_STR("True if connection has an active transaction, False otherwise. Non-standard.")}, {NULL} }; @@ -1674,6 +1688,12 @@ {"NotSupportedError", T_OBJECT, offsetof(pysqlite_Connection, NotSupportedError), RO}, {"row_factory", T_OBJECT, offsetof(pysqlite_Connection, row_factory)}, {"text_factory", T_OBJECT, offsetof(pysqlite_Connection, text_factory)}, + {"operation_needs_transaction_callback", T_OBJECT, offsetof(pysqlite_Connection, operation_needs_transaction_callback), 0, + PyDoc_STR("If this is not None, every operation executed is first passed to this callback to\n" + "decide if it has to be run inside a transaction. It should be safe for this function\n" + "to always return True, making each operation start a transaction. Returning None leaves\n" + "the transaction state unchanged while False will commit a running transaction automatically.\n" + "Non-standard.")}, {NULL} }; diff -r 9160557984c7 Modules/_sqlite/connection.h --- a/Modules/_sqlite/connection.h Sat Mar 26 23:37:03 2011 +0100 +++ b/Modules/_sqlite/connection.h Sun Mar 27 03:26:38 2011 +0200 @@ -104,6 +104,11 @@ * destructor */ PyObject* apsw_connection; + /* Callback to decide if the passed operation requires a running transaction (True), + leaves the state untouched (None) or should trigger a commit (False). A sensible + callback is "lambda operation: True". */ + PyObject* operation_needs_transaction_callback; + /* Exception objects */ PyObject* Warning; PyObject* Error; diff -r 9160557984c7 Modules/_sqlite/cursor.c --- a/Modules/_sqlite/cursor.c Sat Mar 26 23:37:03 2011 +0100 +++ b/Modules/_sqlite/cursor.c Sun Mar 27 03:26:38 2011 +0200 @@ -76,6 +76,77 @@ } } +static int update_transaction_state_via_callback( + pysqlite_Connection* connection, PyObject* needs_transaction_callback, PyObject* operation) +{ + PyObject* result = NULL; + PyObject* want_transaction; + + want_transaction = PyObject_CallFunctionObjArgs(needs_transaction_callback, operation, NULL); + if (want_transaction) { + if (want_transaction == Py_None) { + /* Don't care - leave old transaction state */ + } else if (PyObject_IsTrue(want_transaction)) { + /* Operation requires an active transaction */ + if (!connection->inTransaction) { + result = _pysqlite_connection_begin(connection); + } + } else { + /* Commit before running this operation */ + if (connection->inTransaction) { + result = pysqlite_connection_commit(connection, NULL); + } + } + } + + Py_XDECREF(want_transaction); + Py_XDECREF(result); + return PyErr_Occurred() != NULL; +} + +static int update_transaction_state_for_operation(pysqlite_Cursor* self, PyObject* operation, int statement_type) +{ + PyObject *result; + PyObject *needs_transaction_callback; + + if (self->connection->begin_statement) { + needs_transaction_callback = self->connection->operation_needs_transaction_callback; + if (needs_transaction_callback != Py_None) { + return update_transaction_state_via_callback(self->connection, needs_transaction_callback, operation); + } + + switch (statement_type) { + case STATEMENT_SELECT: + /* Currently does not start a transaction. */ + break; + case STATEMENT_UPDATE: + case STATEMENT_DELETE: + case STATEMENT_INSERT: + case STATEMENT_REPLACE: + if (!self->connection->inTransaction) { + result = _pysqlite_connection_begin(self->connection); + if (!result) { + return 1; + } + Py_DECREF(result); + } + break; + case STATEMENT_OTHER: + /* it's a DDL statement or something similar + - we better COMMIT first so it works for all cases */ + if (self->connection->inTransaction) { + result = pysqlite_connection_commit(self->connection, NULL); + if (!result) { + return 1; + } + Py_DECREF(result); + } + break; + } + } + return 0; +} + static int pysqlite_cursor_init(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs) { pysqlite_Connection* connection; @@ -457,7 +528,6 @@ int i; int rc; PyObject* func_args; - PyObject* result; int numcols; PY_LONG_LONG lastrowid; int statement_type; @@ -596,39 +666,15 @@ pysqlite_statement_mark_dirty(self->statement); statement_type = detect_statement_type(operation_cstr); - if (self->connection->begin_statement) { - switch (statement_type) { - case STATEMENT_UPDATE: - case STATEMENT_DELETE: - case STATEMENT_INSERT: - case STATEMENT_REPLACE: - if (!self->connection->inTransaction) { - result = _pysqlite_connection_begin(self->connection); - if (!result) { - goto error; - } - Py_DECREF(result); - } - break; - case STATEMENT_OTHER: - /* it's a DDL statement or something similar - - we better COMMIT first so it works for all cases */ - if (self->connection->inTransaction) { - result = pysqlite_connection_commit(self->connection, NULL); - if (!result) { - goto error; - } - Py_DECREF(result); - } - break; - case STATEMENT_SELECT: - if (multiple) { - PyErr_SetString(pysqlite_ProgrammingError, - "You cannot execute SELECT statements in executemany()."); - goto error; - } - break; - } + + if (multiple && statement_type == STATEMENT_SELECT) { + PyErr_SetString(pysqlite_ProgrammingError, + "You cannot execute SELECT statements in executemany()."); + goto error; + } + + if (update_transaction_state_for_operation(self, operation, statement_type) != 0) { + goto error; } while (1) {