diff --git a/Lib/sqlite3/test/backup.py b/Lib/sqlite3/test/backup.py new file mode 100644 --- /dev/null +++ b/Lib/sqlite3/test/backup.py @@ -0,0 +1,104 @@ +import sqlite3 as sqlite +from tempfile import NamedTemporaryFile +import unittest + +class BackupTests(unittest.TestCase): + def setUp(self): + cx = self.cx = sqlite.connect(":memory:") + cx.execute('CREATE TABLE foo (key INTEGER)') + cx.executemany('INSERT INTO foo (key) VALUES (?)', [(3,), (4,)]) + cx.commit() + + def tearDown(self): + self.cx.close() + + def testBackup(self, bckfn): + cx = sqlite.connect(bckfn) + result = cx.execute("SELECT key FROM foo ORDER BY key").fetchall() + self.assertEqual(result[0][0], 3) + self.assertEqual(result[1][0], 4) + + def CheckSimple(self): + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name) + self.testBackup(bckfn.name) + + def CheckProgress(self): + journal = [] + + def progress(remaining, total): + journal.append(remaining) + + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name, 1, progress) + self.testBackup(bckfn.name) + + self.assertEqual(len(journal), 2) + self.assertEqual(journal[0], 1) + self.assertEqual(journal[1], 0) + + def CheckProgressAllPagesAtOnce_0(self): + journal = [] + + def progress(remaining, total): + journal.append(remaining) + + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name, 0, progress) + self.testBackup(bckfn.name) + + self.assertEqual(len(journal), 1) + self.assertEqual(journal[0], 0) + + def CheckProgressAllPagesAtOnce_1(self): + journal = [] + + def progress(remaining, total): + journal.append(remaining) + + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name, -1, progress) + self.testBackup(bckfn.name) + + self.assertEqual(len(journal), 1) + self.assertEqual(journal[0], 0) + + def CheckNonCallableProgress(self): + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + with self.assertRaises(TypeError) as cm: + self.cx.backup(bckfn.name, 1, 'bar') + self.assertEqual(str(cm.exception), 'progress argument must be a callable') + + def CheckModifyingProgress(self): + journal = [] + + def progress(remaining, total): + if not journal: + self.cx.execute('INSERT INTO foo (key) VALUES (?)', (remaining+1000,)) + self.cx.commit() + journal.append(remaining) + + with NamedTemporaryFile(suffix='.sqlite') as bckfn: + self.cx.backup(bckfn.name, 1, progress) + self.testBackup(bckfn.name) + + cx = sqlite.connect(bckfn.name) + result = cx.execute("SELECT key FROM foo" + " WHERE key >= 1000" + " ORDER BY key").fetchall() + self.assertEqual(result[0][0], 1001) + + self.assertEqual(len(journal), 3) + self.assertEqual(journal[0], 1) + self.assertEqual(journal[1], 1) + self.assertEqual(journal[2], 0) + +def suite(): + return unittest.TestSuite(unittest.makeSuite(BackupTests, "Check")) + +def test(): + runner = unittest.TextTestRunner() + runner.run(suite()) + +if __name__ == "__main__": + test() diff --git a/Lib/test/test_sqlite.py b/Lib/test/test_sqlite.py --- a/Lib/test/test_sqlite.py +++ b/Lib/test/test_sqlite.py @@ -7,7 +7,7 @@ import sqlite3 from sqlite3.test import (dbapi, types, userfunctions, factory, transactions, hooks, regression, - dump) + dump, backup) def load_tests(*args): if test.support.verbose: @@ -18,7 +18,7 @@ def load_tests(*args): userfunctions.suite(), factory.suite(), transactions.suite(), hooks.suite(), regression.suite(), - dump.suite()]) + dump.suite(), backup.suite()]) if __name__ == "__main__": unittest.main() diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -1496,6 +1496,71 @@ finally: } static PyObject * +pysqlite_connection_backup(pysqlite_Connection* self, PyObject* args) +{ + char* filename; + int pages = -1; + PyObject* progress = Py_None; + int rc; + sqlite3 *bckconn; + sqlite3_backup *bckhandle; + + if (!PyArg_ParseTuple(args, "s|iO:backup(filename, pages, progress)", + &filename, &pages, &progress)) { + return NULL; + } + + if (progress != Py_None && !PyCallable_Check(progress)) { + PyErr_SetString(PyExc_TypeError, "progress argument must be a callable"); + return NULL; + } + + if (pages == 0) { + pages = -1; + } + + rc = sqlite3_open(filename, &bckconn); + if (rc == SQLITE_OK) { + bckhandle = sqlite3_backup_init(bckconn, "main", self->db, "main"); + if (bckhandle) { + do { + rc = sqlite3_backup_step(bckhandle, pages); + + if (progress != Py_None) { + if (!PyObject_CallFunction(progress, "ii", + sqlite3_backup_remaining(bckhandle), + sqlite3_backup_pagecount(bckhandle))) { + /* User's callback raised an error: interrupt the loop and + propagate it. */ + rc = -1; + } + } + + /* Sleep for 250ms if there are still further pages to copy */ + if (rc == SQLITE_OK || rc == SQLITE_BUSY || rc == SQLITE_LOCKED) { + sqlite3_sleep(250); + } + } while (rc == SQLITE_OK || rc == SQLITE_BUSY || rc == SQLITE_LOCKED); + + sqlite3_backup_finish(bckhandle); + } + + if (rc != -1) { + rc = _pysqlite_seterror(bckconn, NULL); + } + } + + sqlite3_close(bckconn); + + if (rc != 0) { + /* TODO: should the (probably incomplete/invalid) backup be removed here? */ + return NULL; + } else { + Py_RETURN_NONE; + } +} + +static PyObject * pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args) { PyObject* callable; @@ -1664,6 +1729,8 @@ static PyMethodDef connection_methods[] = { PyDoc_STR("Abort any pending database operation. Non-standard.")}, {"iterdump", (PyCFunction)pysqlite_connection_iterdump, METH_NOARGS, PyDoc_STR("Returns iterator to the dump of the database in an SQL text format. Non-standard.")}, + {"backup", (PyCFunction)pysqlite_connection_backup, METH_VARARGS, + PyDoc_STR("Execute a backup of the database. Non-standard.")}, {"__enter__", (PyCFunction)pysqlite_connection_enter, METH_NOARGS, PyDoc_STR("For context manager. Non-standard.")}, {"__exit__", (PyCFunction)pysqlite_connection_exit, METH_VARARGS,