diff --git a/Include/pystate.h b/Include/pystate.h
--- a/Include/pystate.h
+++ b/Include/pystate.h
@@ -133,6 +133,12 @@
 PyAPI_FUNC(int) PyState_AddModule(PyObject*, struct PyModuleDef*);
 PyAPI_FUNC(int) PyState_RemoveModule(struct PyModuleDef*);
 #endif
+#if !defined(Py_LIMITED_API) || Py_LIMITED_API+0 >= 0x03040000
+/* New in 3.4 */
+PyAPI_FUNC(PyObject *) PyState_GetModuleAttr(struct PyModuleDef *,
+                                             const char *name,
+                                             PyObject *restrict_type);
+#endif
 PyAPI_FUNC(PyObject*) PyState_FindModule(struct PyModuleDef*);
 #ifndef Py_LIMITED_API
 PyAPI_FUNC(void) _PyState_ClearModules(void);
diff --git a/Modules/_csv.c b/Modules/_csv.c
--- a/Modules/_csv.c
+++ b/Modules/_csv.c
@@ -14,39 +14,14 @@
 #include "structmember.h"
 
 
-typedef struct {
-    PyObject *error_obj;   /* CSV exception */
-    PyObject *dialects;   /* Dialect registry */
-    long field_limit;   /* max parsed field size */
-} _csvstate;
-
-#define _csvstate(o) ((_csvstate *)PyModule_GetState(o))
-
-static int
-_csv_clear(PyObject *m)
-{
-    Py_CLEAR(_csvstate(m)->error_obj);
-    Py_CLEAR(_csvstate(m)->dialects);
-    return 0;
-}
-
-static int
-_csv_traverse(PyObject *m, visitproc visit, void *arg)
-{
-    Py_VISIT(_csvstate(m)->error_obj);
-    Py_VISIT(_csvstate(m)->dialects);
-    return 0;
-}
-
-static void
-_csv_free(void *m)
-{
-   _csv_clear((PyObject *)m);
-}
-
 static struct PyModuleDef _csvmodule;
 
-#define _csvstate_global ((_csvstate *)PyModule_GetState(PyState_FindModule(&_csvmodule)))
+#define GET_DIALECTS() \
+    PyState_GetModuleAttr(&_csvmodule, "_dialects", (PyObject *) &PyDict_Type)
+#define GET_ERROR_OBJ() \
+    PyState_GetModuleAttr(&_csvmodule, "Error", NULL)
+#define GET_FIELD_LIMIT_OBJ() \
+    PyState_GetModuleAttr(&_csvmodule, "_field_limit", (PyObject *) &PyLong_Type)
 
 typedef enum {
     START_RECORD, START_FIELD, ESCAPED_CHAR, IN_FIELD,
@@ -122,6 +97,51 @@
 static PyTypeObject Writer_Type;
 
 
+
+/* Raise _csv.Error with a printf-style message; always returns NULL.
+   If the Error class cannot be fetched from the module, the exception
+   set by the failed lookup is left in place. */
+static PyObject *
+format_csv_error(const char *format, ...)
+{
+    va_list vargs;
+    PyObject *string, *error_obj;
+
+#ifdef HAVE_STDARG_PROTOTYPES
+    va_start(vargs, format);
+#else
+    va_start(vargs);
+#endif
+    error_obj = GET_ERROR_OBJ();
+    if (error_obj != NULL) {
+        string = PyUnicode_FromFormatV(format, vargs);
+        /* If formatting failed an exception is already set; do not
+           replace it with a value-less _csv.Error. */
+        if (string != NULL) {
+            PyErr_SetObject(error_obj, string);
+            Py_DECREF(string);
+        }
+        Py_DECREF(error_obj);
+    }
+    va_end(vargs);
+    return NULL;
+}
+
+/* Return the module's current field size limit.  On failure returns -1
+   with an exception set, so callers must check PyErr_Occurred(). */
+static long
+get_field_limit(void)
+{
+    PyObject *field_limit = GET_FIELD_LIMIT_OBJ();
+    if (field_limit != NULL) {
+        long x = PyLong_AsLong(field_limit);
+        Py_DECREF(field_limit);
+        return x;
+    }
+    return -1;
+}
+
+
 /*
  * DIALECT class
  */
@@ -130,11 +150,15 @@
 get_dialect_from_registry(PyObject * name_obj)
 {
     PyObject *dialect_obj;
+    PyObject *dialects = GET_DIALECTS();
 
-    dialect_obj = PyDict_GetItem(_csvstate_global->dialects, name_obj);
+    if (dialects == NULL)
+        return NULL;
+    dialect_obj = PyDict_GetItem(dialects, name_obj);
+    Py_DECREF(dialects);
     if (dialect_obj == NULL) {
         if (!PyErr_Occurred())
-            PyErr_Format(_csvstate_global->error_obj, "unknown dialect");
+            format_csv_error("unknown dialect");
     }
     else
         Py_INCREF(dialect_obj);
@@ -579,9 +603,12 @@
 static int
 parse_add_char(ReaderObj *self, Py_UCS4 c)
 {
-    if (self->field_len >= _csvstate_global->field_limit) {
-        PyErr_Format(_csvstate_global->error_obj, "field larger than field limit (%ld)",
-                     _csvstate_global->field_limit);
+    long field_limit = get_field_limit();
+    if (field_limit == -1 && PyErr_Occurred())
+        return -1;
+    if (self->field_len >= field_limit) {
+        format_csv_error("field larger than field limit (%ld)",
+                         field_limit);
         return -1;
     }
     if (self->field_len == self->field_size && !parse_grow_buff(self))
@@ -749,9 +776,9 @@
         }
         else {
             /* illegal */
-            PyErr_Format(_csvstate_global->error_obj, "'%c' expected after '%c'",
-                            dialect->delimiter,
-                            dialect->quotechar);
+            format_csv_error("'%c' expected after '%c'",
+                             dialect->delimiter,
+                             dialect->quotechar);
             return -1;
         }
         break;
@@ -762,7 +789,9 @@
         else if (c == '\0')
             self->state = START_RECORD;
         else {
-            PyErr_Format(_csvstate_global->error_obj, "new-line character seen in unquoted field - do you need to open the file in universal-newline mode?");
+            format_csv_error("new-line character seen in unquoted field - "
+                             "do you need to open the file in "
+                             "universal-newline mode?");
             return -1;
         }
         break;
@@ -802,21 +831,19 @@
             /* End of input OR exception */
             if (!PyErr_Occurred() && (self->field_len != 0 ||
                                       self->state == IN_QUOTED_FIELD)) {
-                if (self->dialect->strict)
-                    PyErr_SetString(_csvstate_global->error_obj,
-                                    "unexpected end of data");
+                if (self->dialect->strict) {
+                    format_csv_error("unexpected end of data");
+                }
                 else if (parse_save_field(self) >= 0)
                     break;
             }
             return NULL;
         }
         if (!PyUnicode_Check(lineobj)) {
-            PyErr_Format(_csvstate_global->error_obj,
-                         "iterator should return strings, "
-                         "not %.200s "
-                         "(did you open the file in text mode?)",
-                         lineobj->ob_type->tp_name
-                );
+            format_csv_error("iterator should return strings, "
+                             "not %.200s "
+                             "(did you open the file in text mode?)",
+                             lineobj->ob_type->tp_name);
             Py_DECREF(lineobj);
             return NULL;
         }
@@ -833,8 +860,7 @@
             c = PyUnicode_READ(kind, data, pos);
             if (c == '\0') {
                 Py_DECREF(lineobj);
-                PyErr_Format(_csvstate_global->error_obj,
-                             "line contains NULL byte");
+                format_csv_error("line contains NULL byte");
                 goto err;
             }
             if (parse_process_char(self, c) < 0) {
@@ -1049,8 +1075,7 @@
             }
             if (want_escape) {
                 if (!dialect->escapechar) {
-                    PyErr_Format(_csvstate_global->error_obj,
-                                 "need to escape, but no escapechar set");
+                    format_csv_error("need to escape, but no escapechar set");
                     return -1;
                 }
                 ADDCH(dialect->escapechar);
@@ -1065,8 +1090,7 @@
      */
     if (i == 0 && quote_empty) {
         if (dialect->quoting == QUOTE_NONE) {
-            PyErr_Format(_csvstate_global->error_obj,
-                "single empty field record must be quoted");
+            format_csv_error("single empty field record must be quoted");
             return -1;
         }
         else
@@ -1184,7 +1208,7 @@
     PyObject *line, *result;
 
     if (!PySequence_Check(seq))
-        return PyErr_Format(_csvstate_global->error_obj, "sequence expected");
+        return format_csv_error("sequence expected");
 
     len = PySequence_Length(seq);
     if (len < 0)
@@ -1410,14 +1434,20 @@
 static PyObject *
 csv_list_dialects(PyObject *module, PyObject *args)
 {
-    return PyDict_Keys(_csvstate_global->dialects);
+    PyObject *keys;
+    PyObject *dialects = GET_DIALECTS();
+    if (dialects == NULL)
+        return NULL;
+    keys = PyDict_Keys(dialects);
+    Py_DECREF(dialects);
+    return keys;
 }
 
 static PyObject *
 csv_register_dialect(PyObject *module, PyObject *args, PyObject *kwargs)
 {
     PyObject *name_obj, *dialect_obj = NULL;
-    PyObject *dialect;
+    PyObject *dialects = NULL, *dialect = NULL, *res = NULL;
 
     if (!PyArg_UnpackTuple(args, "", 1, 2, &name_obj, &dialect_obj))
         return NULL;
@@ -1428,25 +1458,34 @@
     }
     if (PyUnicode_READY(name_obj) == -1)
         return NULL;
+    dialects = GET_DIALECTS();
+    if (dialects == NULL)
+        return NULL;
     dialect = _call_dialect(dialect_obj, kwargs);
     if (dialect == NULL)
-        return NULL;
-    if (PyDict_SetItem(_csvstate_global->dialects, name_obj, dialect) < 0) {
-        Py_DECREF(dialect);
-        return NULL;
-    }
-    Py_DECREF(dialect);
-    Py_INCREF(Py_None);
-    return Py_None;
+        goto error;
+    if (PyDict_SetItem(dialects, name_obj, dialect) < 0)
+        goto error;
+    res = Py_None;
+error:
+    Py_XDECREF(dialects);
+    Py_XDECREF(dialect);
+    Py_XINCREF(res);
+    return res;
 }
 
 static PyObject *
 csv_unregister_dialect(PyObject *module, PyObject *name_obj)
 {
-    if (PyDict_DelItem(_csvstate_global->dialects, name_obj) < 0)
-        return PyErr_Format(_csvstate_global->error_obj, "unknown dialect");
-    Py_INCREF(Py_None);
-    return Py_None;
+    PyObject *dialects = GET_DIALECTS();
+    if (dialects == NULL)
+        return NULL;
+    if (PyDict_DelItem(dialects, name_obj) < 0) {
+        Py_DECREF(dialects);
+        return format_csv_error("unknown dialect");
+    }
+    Py_DECREF(dialects);
+    Py_RETURN_NONE;
 }
 
 static PyObject *
@@ -1458,24 +1497,41 @@
 static PyObject *
 csv_field_size_limit(PyObject *module, PyObject *args)
 {
-    PyObject *new_limit = NULL;
-    long old_limit = _csvstate_global->field_limit;
+    PyObject *old_limit, *new_limit = NULL;
 
     if (!PyArg_UnpackTuple(args, "field_size_limit", 0, 1, &new_limit))
         return NULL;
+    old_limit = GET_FIELD_LIMIT_OBJ();
+    if (old_limit == NULL)
+        return NULL;
     if (new_limit != NULL) {
+        /* Not named "module": that would shadow the parameter. */
+        PyObject *csv_module = PyState_FindModule(&_csvmodule);
+        if (csv_module == NULL) {
+            /* PyState_FindModule() does not set an exception itself */
+            PyErr_Format(PyExc_ImportError,
+                         "Module %s not loaded", _csvmodule.m_name);
+            Py_DECREF(old_limit);
+            return NULL;
+        }
         if (!PyLong_CheckExact(new_limit)) {
             PyErr_Format(PyExc_TypeError,
                          "limit must be an integer");
+            Py_DECREF(old_limit);
             return NULL;
         }
-        _csvstate_global->field_limit = PyLong_AsLong(new_limit);
-        if (_csvstate_global->field_limit == -1 && PyErr_Occurred()) {
-            _csvstate_global->field_limit = old_limit;
+        /* Reject values that do not fit in a C long up front, so a bad
+           limit is never stored and the old one stays in effect. */
+        if (PyLong_AsLong(new_limit) == -1 && PyErr_Occurred()) {
+            Py_DECREF(old_limit);
+            return NULL;
+        }
+        if (PyObject_SetAttrString(csv_module, "_field_limit", new_limit)) {
+            Py_DECREF(old_limit);
             return NULL;
         }
     }
-    return PyLong_FromLong(old_limit);
+    return old_limit;
 }
 
 /*
@@ -1614,18 +1670,18 @@
     PyModuleDef_HEAD_INIT,
     "_csv",
     csv_module_doc,
-    sizeof(_csvstate),
+    0,
     csv_methods,
     NULL,
-    _csv_traverse,
-    _csv_clear,
-    _csv_free
+    NULL,
+    NULL,
+    NULL,
 };
 
 PyMODINIT_FUNC
 PyInit__csv(void)
 {
-    PyObject *module;
+    PyObject *module, *dialects, *error_obj, *field_limit;
     StyleDesc *style;
 
     if (PyType_Ready(&Dialect_Type) < 0)
@@ -1648,15 +1704,22 @@
         return NULL;
 
     /* Set the field limit */
-    _csvstate(module)->field_limit = 128 * 1024;
-    /* Do I still need to add this var to the Module Dict? */
+    field_limit = PyLong_FromLong(128 * 1024);
+    if (field_limit == NULL)
+        return NULL;
+    /* PyModule_AddObject() steals our reference, but only on success */
+    if (PyModule_AddObject(module, "_field_limit", field_limit) < 0) {
+        Py_DECREF(field_limit);
+        return NULL;
+    }
 
     /* Add _dialects dictionary */
-    _csvstate(module)->dialects = PyDict_New();
-    if (_csvstate(module)->dialects == NULL)
+    dialects = PyDict_New();
+    if (dialects == NULL)
         return NULL;
-    Py_INCREF(_csvstate(module)->dialects);
-    if (PyModule_AddObject(module, "_dialects", _csvstate(module)->dialects))
-        return NULL;
+    if (PyModule_AddObject(module, "_dialects", dialects) < 0) {
+        Py_DECREF(dialects);
+        return NULL;
+    }
 
     /* Add quote styles into dictionary */
@@ -1672,10 +1735,13 @@
         return NULL;
 
     /* Add the CSV exception object to the module. */
-    _csvstate(module)->error_obj = PyErr_NewException("_csv.Error", NULL, NULL);
-    if (_csvstate(module)->error_obj == NULL)
+    error_obj = PyErr_NewException("_csv.Error", NULL, NULL);
+    if (error_obj == NULL)
         return NULL;
-    Py_INCREF(_csvstate(module)->error_obj);
-    PyModule_AddObject(module, "Error", _csvstate(module)->error_obj);
+    /* Hand our only reference to the module; release it on failure */
+    if (PyModule_AddObject(module, "Error", error_obj) < 0) {
+        Py_DECREF(error_obj);
+        return NULL;
+    }
     return module;
 }
diff --git a/Python/pystate.c b/Python/pystate.c
--- a/Python/pystate.c
+++ b/Python/pystate.c
@@ -260,6 +260,43 @@
     return res==Py_None ? NULL : res;
 }
 
+/* Fetch attribute `name` from the module found for `def` and return a
+   new reference to it.  If `restrict_type` is non-NULL, the value must
+   be an instance of it.  Returns NULL with an exception set when the
+   module is not loaded, the attribute lookup fails, or the check fails. */
+PyObject *
+PyState_GetModuleAttr(struct PyModuleDef *def,
+                      const char *name,
+                      PyObject *restrict_type)
+{
+    PyObject *module, *value;
+    module = PyState_FindModule(def);
+    if (module == NULL) {
+        PyErr_Format(PyExc_ImportError,
+                     "Module %s not loaded", def->m_name);
+        return NULL;
+    }
+    Py_INCREF(module);
+    value = PyObject_GetAttrString(module, name);
+    Py_DECREF(module);
+    if (value == NULL)
+        return NULL;
+    if (restrict_type != NULL) {
+        int r = PyObject_IsInstance(value, restrict_type);
+        if (r == 0)
+            PyErr_Format(
+                PyExc_TypeError,
+                "Attribute '%s' of module '%s' should be "
+                "an instance of %R, not %R",
+                name, def->m_name, restrict_type, Py_TYPE(value));
+        if (r <= 0) {
+            Py_DECREF(value);
+            return NULL;
+        }
+    }
+    return value;
+}
+
 int
 _PyState_AddModule(PyObject* module, struct PyModuleDef* def)
 {