From 4659807f32a8b15975048998db7123e8579125ea Mon Sep 17 00:00:00 2001 From: "Erlend E. Aasland" Date: Sun, 24 May 2020 22:50:36 +0200 Subject: [PATCH] Add support for sqlite3 aggregate window functions --- Doc/library/sqlite3.rst | 8 +++ Modules/_sqlite/connection.c | 113 +++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) diff --git a/Doc/library/sqlite3.rst b/Doc/library/sqlite3.rst index ccb82278bd..61f0eb9223 100644 --- a/Doc/library/sqlite3.rst +++ b/Doc/library/sqlite3.rst @@ -375,6 +375,14 @@ Connection Objects The ``finalize`` method can return any of the types supported by SQLite: bytes, str, int, float and ``None``. + For SQLite 3.25.0 or higher, you can implement ``value`` and ``inverse``methods in order to + create aggregate window functions. See + `User-Defined Aggregate Window Functions ` + in the SQLite documentation. + + .. versionchanged:: 3.10 + Support aggregate window functions. + Example: .. literalinclude:: ../includes/sqlite3/mysumaggr.py diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 958be7d869..04203df667 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -43,6 +43,10 @@ #define HAVE_BACKUP_API #endif +#if SQLITE_VERSION_NUMBER >= 3025003 +#define HAVE_WINDOW_FUNCTIONS +#endif + _Py_IDENTIFIER(cursor); static const char * const begin_statements[] = { @@ -861,6 +865,111 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec Py_RETURN_NONE; } +#ifdef HAVE_WINDOW_FUNCTIONS +/* + * Regarding the 'inverse' aggregate callback: + * This method is only required window aggregate functions, not legacy aggregate function + * implementations. It is invoked to remove a row from the current window. The function arguments, + * if any, correspond to the row being removed. + */ +static void _pysqlite_inverse_callback(sqlite3_context* context, int argc, sqlite3_value** params) +{ + PyObject* method = NULL; + PyObject* method_args = NULL; + PyObject* method_result = NULL; + PyObject** aggregate_instance = NULL; + + PyGILState_STATE gilstate = PyGILState_Ensure(); + if (!(aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)))) { + _sqlite3_result_error(context, "user-defined aggregate's '__init__' method raised error", -1); + } else if (!(method = PyObject_GetAttrString(*aggregate_instance, "inverse"))) { + _sqlite3_result_error(context, "user-defined aggregate's 'inverse' method not defined", -1); + } else if (!(method_args = _pysqlite_build_py_params(context, argc, params))) { + _sqlite3_result_error(context, "unable to build arguments for user-defined aggregate 'inverse' method", -1); + } else if (!(method_result = PyObject_CallObject(method, method_args))) { + _sqlite3_result_error(context, "user-defined aggregate's 'inverse' method raised error", -1); + } + + if (PyErr_Occurred()) { + if (_pysqlite_enable_callback_tracebacks) { + PyErr_Print(); + } else { + PyErr_Clear(); + } + } + Py_XDECREF(method_result); + PyGILState_Release(gilstate); +} + +/* + * Regarding the 'value' aggregate callback: + * This method is only required by window aggregate functions, not legacy aggregate function + * implementations. It is invoked to return the current value of the aggregate. + */ +static void _pysqlite_value_callback(sqlite3_context* context) +{ + _Py_IDENTIFIER(value); + PyObject* method_result = NULL; + PyObject** aggregate_instance = NULL; + + PyGILState_STATE gilstate = PyGILState_Ensure(); + if (!(aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)))) { + _sqlite3_result_error(context, "user-defined aggregate's '__init__' method raised error", -1); + } else if (!(method_result = _PyObject_CallMethodIdNoArgs(*aggregate_instance, &PyId_value))) { + _sqlite3_result_error(context, "user-defined aggregate's 'value' method raised error", -1); + } else if (_pysqlite_set_result(context, method_result) != 0) { + _sqlite3_result_error(context, "unable to set result from user-defined aggregate's 'value' method", -1); + } + + if (PyErr_Occurred()) { + if (_pysqlite_enable_callback_tracebacks) { + PyErr_Print(); + } else { + PyErr_Clear(); + } + } + Py_XDECREF(method_result); + PyGILState_Release(gilstate); +} + +PyObject* pysqlite_connection_create_window_function(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) +{ + PyObject* aggregate_class; + int n_arg; + char* name; + static char *kwlist[] = { "name", "n_arg", "aggregate_class", NULL }; + int rc; + + if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) { + return NULL; + } + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "siO:create_window_function", + kwlist, &name, &n_arg, &aggregate_class)) { + return NULL; + } + Py_INCREF(aggregate_class); + rc = sqlite3_create_window_function(self->db, + name, + n_arg, + SQLITE_UTF8, + (void*)aggregate_class, + &_pysqlite_step_callback, + &_pysqlite_final_callback, + &_pysqlite_value_callback, + &_pysqlite_inverse_callback, + &_destructor); // will decref func + if (rc != SQLITE_OK) { + Py_DECREF(aggregate_class); + + /* Workaround for SQLite bug: no error code or string is available here */ + PyErr_SetString(pysqlite_OperationalError, "Error creating window function"); + return NULL; + } + Py_RETURN_NONE; +} +#endif + PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) { PyObject* aggregate_class; @@ -1778,6 +1887,10 @@ static PyMethodDef connection_methods[] = { PyDoc_STR("Creates a new function. Non-standard.")}, {"create_aggregate", (PyCFunction)(void(*)(void))pysqlite_connection_create_aggregate, METH_VARARGS|METH_KEYWORDS, PyDoc_STR("Creates a new aggregate. Non-standard.")}, +#ifdef HAVE_WINDOW_FUNCTIONS + {"create_window_function", (PyCFunction)(void(*)(void))pysqlite_connection_create_window_function, METH_VARARGS|METH_KEYWORDS, + PyDoc_STR("Creates a new window function. Non-standard.")}, +#endif {"set_authorizer", (PyCFunction)(void(*)(void))pysqlite_connection_set_authorizer, METH_VARARGS|METH_KEYWORDS, PyDoc_STR("Sets authorizer callback. Non-standard.")}, #ifdef HAVE_LOAD_EXTENSION -- 2.27.0.rc1