diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index 39c9bf5b61..69ffc354c1 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -92,6 +92,10 @@ def test_shared_cache_deprecated(self): sqlite.enable_shared_cache(enable) self.assertIn("dbapi.py", cm.filename) + def test_complete_statement(self): + self.assertFalse(sqlite.complete_statement("select t")) + self.assertTrue(sqlite.complete_statement("create table t(t);")) + class ConnectionTests(unittest.TestCase): @@ -191,6 +195,24 @@ def test_open_uri(self): with self.assertRaises(sqlite.OperationalError): cx.execute('insert into test(id) values(1)') + def test_interrupt_on_closed_db(self): + cx = sqlite.connect(":memory:") + cx.close() + with self.assertRaises(sqlite.ProgrammingError): + cx.interrupt() + + def test_interrupt(self): + self.assertIsNone(self.cx.interrupt()) + + def test_drop_unused_refs(self): + for n in range(500): + cu = self.cx.execute(f"select {n}") + self.assertEqual(cu.fetchone()[0], n) + + def test_database_keyword(self): + with sqlite.connect(database=":memory:") as cx: + self.assertEqual(type(cx), sqlite.Connection) + class CursorTests(unittest.TestCase): def setUp(self): @@ -522,6 +544,10 @@ def test_last_row_id_insert_o_r(self): ] self.assertEqual(results, expected) + def test_same_query_in_multiple_cursors(self): + cursors = [self.cx.execute("select 1") for _ in range(3)] + for cu in cursors: + self.assertEqual(cu.fetchall(), [(1,)]) class ThreadTests(unittest.TestCase): def setUp(self): @@ -680,6 +706,21 @@ def run(cur, errors): if len(errors) > 0: self.fail("\n".join(errors)) + def test_dont_check_same_thread(self): + def run(con, err): + try: + cur = con.execute("select 1") + except sqlite.Error: + err.append("multi-threading not allowed") + + con = sqlite.connect(":memory:", check_same_thread=False) + err = [] + t = threading.Thread(target=run, kwargs={"con": con, "err": err}) + t.start() + t.join() + self.assertEqual(len(err), 0, "\n".join(err)) + + class ConstructorTests(unittest.TestCase): def test_date(self): d = sqlite.Date(2004, 10, 28) diff --git a/Lib/sqlite3/test/factory.py b/Lib/sqlite3/test/factory.py index 8764284975..7faa9ac8c1 100644 --- a/Lib/sqlite3/test/factory.py +++ b/Lib/sqlite3/test/factory.py @@ -123,6 +123,8 @@ def test_sqlite_row_index(self): row[-3] with self.assertRaises(IndexError): row[2**1000] + with self.assertRaises(IndexError): + row[complex()] # index must be int or string def test_sqlite_row_index_unicode(self): self.con.row_factory = sqlite.Row diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py index 2370dd1693..0b4a6a87b4 100644 --- a/Lib/sqlite3/test/types.py +++ b/Lib/sqlite3/test/types.py @@ -356,9 +356,9 @@ def test_cursor_description_cte(self): class ObjectAdaptationTests(unittest.TestCase): + @staticmethod def cast(obj): return float(obj) - cast = staticmethod(cast) def setUp(self): self.con = sqlite.connect(":memory:") @@ -379,6 +379,43 @@ def test_caster_is_used(self): val = self.cur.fetchone()[0] self.assertEqual(type(val), float) + def test_missing_adapter(self): + with self.assertRaises(sqlite.ProgrammingError): + sqlite.adapt(1.) # No float adapter registered + + def test_missing_protocol(self): + with self.assertRaises(sqlite.ProgrammingError): + sqlite.adapt(1, None) + + def test_defect_proto(self): + class DefectProto(): + def __adapt__(self): + return None + with self.assertRaises(sqlite.ProgrammingError): + sqlite.adapt(1., DefectProto) + + def test_defect_self_adapt(self): + class DefectSelfAdapt(float): + def __conform__(self, _): + return None + with self.assertRaises(sqlite.ProgrammingError): + sqlite.adapt(DefectSelfAdapt(1.)) + + def test_custom_proto(self): + class CustomProto(): + def __adapt__(self): + return "adapted" + self.assertEqual(sqlite.adapt(1., CustomProto), "adapted") + + def test_adapt(self): + val = 42 + self.assertEqual(float(val), sqlite.adapt(val)) + + def test_adapt_alt(self): + alt = "other" + self.assertEqual(alt, sqlite.adapt(1., None, alt)) + + @unittest.skipUnless(zlib, "requires zlib") class BinaryConverterTests(unittest.TestCase): def convert(s): diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index 749ea049c8..0ed4a83ec4 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -21,10 +21,35 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. +import contextlib +import functools +import io import unittest import unittest.mock import sqlite3 as sqlite +def with_tracebacks(strings): + """Convenience decorator for testing callback tracebacks.""" + strings.append('Traceback') + + def decorator(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + # First, run the test with traceback enabled. + sqlite.enable_callback_tracebacks(True) + buf = io.StringIO() + with contextlib.redirect_stderr(buf): + func(self, *args, **kwargs) + tb = buf.getvalue() + for s in strings: + self.assertIn(s, tb) + + # Then run the test with traceback disabled. + sqlite.enable_callback_tracebacks(False) + func(self, *args, **kwargs) + return wrapper + return decorator + def func_returntext(): return "foo" def func_returnunicode(): @@ -227,6 +252,7 @@ def test_func_return_long_long(self): val = cur.fetchone()[0] self.assertEqual(val, 1<<31) + @with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError']) def test_func_exception(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -364,6 +390,7 @@ def test_aggr_no_finalize(self): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") + @with_tracebacks(['__init__', '5/0', 'ZeroDivisionError']) def test_aggr_exception_in_init(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -371,6 +398,7 @@ def test_aggr_exception_in_init(self): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error") + @with_tracebacks(['step', '5/0', 'ZeroDivisionError']) def test_aggr_exception_in_step(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -378,6 +406,7 @@ def test_aggr_exception_in_step(self): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error") + @with_tracebacks(['finalize', '5/0', 'ZeroDivisionError']) def test_aggr_exception_in_finalize(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -479,6 +508,14 @@ def authorizer_cb(action, arg1, arg2, dbname, source): raise ValueError return sqlite.SQLITE_OK + @with_tracebacks(['authorizer_cb', 'ValueError']) + def test_table_access(self): + super().test_table_access() + + @with_tracebacks(['authorizer_cb', 'ValueError']) + def test_column_access(self): + super().test_table_access() + class AuthorizerIllegalTypeTests(AuthorizerTests): @staticmethod def authorizer_cb(action, arg1, arg2, dbname, source):