diff -r 86ddd32068a1 Doc/library/sqlite3.rst --- a/Doc/library/sqlite3.rst Tue Jan 03 11:20:15 2017 +0200 +++ b/Doc/library/sqlite3.rst Fri Jan 06 00:12:34 2017 +0200 @@ -345,6 +345,10 @@ .. literalinclude:: ../includes/sqlite3/md5func.py + .. versionchanged:: 3.7 + ``TEXT`` data types params that are accpected by *func* are encoded using + :attr:`~Connection.text_factory`. + .. method:: create_aggregate(name, num_params, aggregate_class) @@ -362,6 +366,10 @@ .. literalinclude:: ../includes/sqlite3/mysumaggr.py + .. versionchanged:: 3.7 + ``TEXT`` data types params that are accpected by the aggregate class + ``step`` method are encoded using :attr:`~Connection.text_factory`. + .. method:: create_collation(name, callable) diff -r 86ddd32068a1 Lib/sqlite3/test/userfunctions.py --- a/Lib/sqlite3/test/userfunctions.py Tue Jan 03 11:20:15 2017 +0200 +++ b/Lib/sqlite3/test/userfunctions.py Fri Jan 06 00:12:34 2017 +0200 @@ -452,6 +452,53 @@ return sqlite.SQLITE_OK +class UserFunctionTestsTextFactory(unittest.TestCase): + def setUp(self): + self.con = sqlite.connect(":memory:") + self.con.create_function("test", 1, lambda x: x) + + class AggrSum: + def __init__(self): + self.val = 0.0 + + def step(self, val): + self.val += val + + def finalize(self): + return self.val + + self.con.create_aggregate("test_aggr", 1, AggrSum) + + def tearDown(self): + self.con.close() + + def test_function_use_text_factory(self): + self.con.text_factory = lambda x: 5 + cur = self.con.cursor() + cur.execute("select test(?)", ("abc",)) + val = cur.fetchone()[0] + self.assertEqual(val, 5) + + def test_change_text_factory(self): + self.con.text_factory = lambda x: 5 + cur = self.con.cursor() + cur.execute("select test(?)", ("abc",)) + val = cur.fetchone()[0] + self.assertEqual(val, 5) + self.con.text_factory = lambda x: 6 + cur.execute("select test(?)", ("abc",)) + val = cur.fetchone()[0] + self.assertEqual(val, 6) + + def test_aggrregate_use_text_factory(self): + self.con.text_factory = lambda x: 5 + cur = self.con.cursor() + cur.execute("select test_aggr(?)", ("abc",)) + val = cur.fetchone()[0] + self.assertEqual(val, 5) + + + def suite(): function_suite = unittest.makeSuite(FunctionTests, "Check") aggregate_suite = unittest.makeSuite(AggregateTests, "Check") @@ -463,6 +510,7 @@ unittest.makeSuite(AuthorizerRaiseExceptionTests), unittest.makeSuite(AuthorizerIllegalTypeTests), unittest.makeSuite(AuthorizerLargeIntegerTests), + unittest.makeSuite(UserFunctionTestsTextFactory), )) def test(): diff -r 86ddd32068a1 Modules/_sqlite/connection.c --- a/Modules/_sqlite/connection.c Tue Jan 03 11:20:15 2017 +0200 +++ b/Modules/_sqlite/connection.c Fri Jan 06 00:12:34 2017 +0200 @@ -530,7 +530,7 @@ return 0; } -PyObject* _pysqlite_build_py_params(sqlite3_context *context, int argc, sqlite3_value** argv) +PyObject* _pysqlite_build_py_params(sqlite3_context *context, int argc, sqlite3_value** argv, PyObject* text_factory) { PyObject* args; int i; @@ -555,7 +555,15 @@ break; case SQLITE_TEXT: val_str = (const char*)sqlite3_value_text(cur_value); - cur_py_value = PyUnicode_FromString(val_str); + if (text_factory == (PyObject*)&PyUnicode_Type) { + cur_py_value = PyUnicode_FromString(val_str); + } else if (text_factory == (PyObject*)&PyBytes_Type) { + cur_py_value = PyBytes_FromString(val_str); + } else if (text_factory == (PyObject*)&PyByteArray_Type) { + cur_py_value = PyByteArray_FromStringAndSize(val_str, strlen(val_str)); + } else { + cur_py_value = PyObject_CallFunction(text_factory, "s", val_str); + } /* TODO: have a way to show errors here */ if (!cur_py_value) { PyErr_Clear(); @@ -586,12 +594,19 @@ return args; } +struct _pysqlite_func_callback_user_data { + PyObject* py_func; + pysqlite_Connection* connection; +}; + void _pysqlite_func_callback(sqlite3_context* context, int argc, sqlite3_value** argv) { PyObject* args; PyObject* py_func; PyObject* py_retval = NULL; int ok; + struct _pysqlite_func_callback_user_data* user_data; + PyObject* text_factory; #ifdef WITH_THREAD PyGILState_STATE threadstate; @@ -599,9 +614,11 @@ threadstate = PyGILState_Ensure(); #endif - py_func = (PyObject*)sqlite3_user_data(context); + user_data = (struct _pysqlite_func_callback_user_data*)sqlite3_user_data(context); + py_func = user_data->py_func; + text_factory = user_data->connection->text_factory; - args = _pysqlite_build_py_params(context, argc, argv); + args = _pysqlite_build_py_params(context, argc, argv, text_factory); if (args) { py_retval = PyObject_CallObject(py_func, args); Py_DECREF(args); @@ -633,6 +650,8 @@ PyObject* aggregate_class; PyObject** aggregate_instance; PyObject* stepmethod = NULL; + struct _pysqlite_func_callback_user_data* user_data; + PyObject* text_factory; #ifdef WITH_THREAD PyGILState_STATE threadstate; @@ -640,7 +659,9 @@ threadstate = PyGILState_Ensure(); #endif - aggregate_class = (PyObject*)sqlite3_user_data(context); + user_data = (struct _pysqlite_func_callback_user_data*)sqlite3_user_data(context); + aggregate_class = user_data->py_func; + text_factory = user_data->connection->text_factory; aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)); @@ -664,7 +685,7 @@ goto error; } - args = _pysqlite_build_py_params(context, argc, params); + args = _pysqlite_build_py_params(context, argc, params, text_factory); if (!args) { goto error; } @@ -825,6 +846,7 @@ char* name; int narg; int rc; + struct _pysqlite_func_callback_user_data* user_data; if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) { return NULL; @@ -836,7 +858,16 @@ return NULL; } - rc = sqlite3_create_function(self->db, name, narg, SQLITE_UTF8, (void*)func, _pysqlite_func_callback, NULL, NULL); + user_data = PyMem_Malloc(sizeof(*user_data)); + if (NULL == user_data) { + PyErr_NoMemory(); + return NULL; + } + user_data->py_func = func; + user_data->connection = self; + + rc = sqlite3_create_function_v2(self->db, name, narg, SQLITE_UTF8, (void*)user_data, + _pysqlite_func_callback, NULL, NULL, PyMem_Free); if (rc != SQLITE_OK) { /* Workaround for SQLite bug: no error code or string is available here */ @@ -858,6 +889,7 @@ char* name; static char *kwlist[] = { "name", "n_arg", "aggregate_class", NULL }; int rc; + struct _pysqlite_func_callback_user_data* user_data; if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) { return NULL; @@ -868,7 +900,16 @@ return NULL; } - rc = sqlite3_create_function(self->db, name, n_arg, SQLITE_UTF8, (void*)aggregate_class, 0, &_pysqlite_step_callback, &_pysqlite_final_callback); + user_data = PyMem_Malloc(sizeof(*user_data)); + if (NULL == user_data) { + PyErr_NoMemory(); + return NULL; + } + user_data->py_func = aggregate_class; + user_data->connection = self; + + rc = sqlite3_create_function_v2(self->db, name, n_arg, SQLITE_UTF8, (void*)user_data, NULL, + &_pysqlite_step_callback, &_pysqlite_final_callback, PyMem_Free); if (rc != SQLITE_OK) { /* Workaround for SQLite bug: no error code or string is available here */ PyErr_SetString(pysqlite_OperationalError, "Error creating aggregate");