diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -2098,6 +2098,36 @@ ERRORS.codes[ERRORS.XML_ERROR_SYNTAX]) +class KeywordArgsTest(unittest.TestCase): + def test_issue14818(self): + x = ET.XML("foo") + self.assertEqual(x.find('a', None), + x.find(path='a', namespaces=None)) + self.assertEqual(x.findtext('a', None, None), + x.findtext(path='a', default=None, namespaces=None)) + self.assertEqual(x.findall('a', None), + x.findall(path='a', namespaces=None)) + self.assertEqual(list(x.iterfind('a', None)), + list(x.iterfind(path='a', namespaces=None))) + + self.assertEqual(ET.Element('a').attrib, {}) + elements = [ + ET.Element('a', dict(href="#", id="foo")), + ET.Element('a', attrib=dict(href="#", id="foo")), + ET.Element('a', dict(href="#"), id="foo"), + ET.Element('a', href="#", id="foo"), + ET.Element('a', dict(href="#", id="foo"), href="#", id="foo"), + ] + for e in elements: + self.assertEqual(e.tag, 'a') + self.assertEqual(e.attrib, dict(href="#", id="foo")) + + with self.assertRaisesRegex(TypeError, 'must be dict, not str'): + ET.Element('a', "I'm not a dict") + with self.assertRaisesRegex(TypeError, 'must be dict, not str'): + ET.Element('a', attrib="I'm not a dict") + + # -------------------------------------------------------------------- @@ -2153,7 +2183,8 @@ StringIOTest, ParseErrorTest, ElementTreeTest, - TreeBuilderTest] + TreeBuilderTest, + KeywordArgsTest] if module is pyET: # Run the tests specific to the Python implementation test_classes += [NoAcceleratorTest] diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -205,6 +205,8 @@ # constructor def __init__(self, tag, attrib={}, **extra): + if not isinstance(attrib, dict): + raise TypeError("must be dict, not %s" % attrib.__class__.__name__) attrib = attrib.copy() attrib.update(extra) self.tag = tag diff --git a/Modules/_elementtree.c b/Modules/_elementtree.c --- a/Modules/_elementtree.c +++ b/Modules/_elementtree.c @@ -358,12 +358,32 @@ if (!PyArg_ParseTuple(args, "O|O!:Element", &tag, &PyDict_Type, &attrib)) return -1; - if (attrib || kwds) { - attrib = (attrib) ? PyDict_Copy(attrib) : PyDict_New(); + if (attrib) { + /* attrib passed as positional arg */ + attrib = PyDict_Copy(attrib); + if (kwds) + PyDict_Update(attrib, kwds); + } else if (kwds) { + PyObject *attrib_str = PyUnicode_FromString("attrib"); + if (PyDict_Contains(kwds, attrib_str)) { + /* attrib passed as kw arg, ends up in kwds, extract it */ + attrib = PyDict_GetItem(kwds, attrib_str); + if (!PyDict_Check(attrib)) { + Py_DECREF(attrib_str); + PyErr_Format(PyExc_TypeError, "must be dict, not %.100s", + Py_TYPE(attrib)->tp_name); + return -1; + } + attrib = PyDict_Copy(attrib); + PyDict_DelItem(kwds, attrib_str); + } else { + /* attrib wasn't even in kwds */ + attrib = PyDict_New(); + } + Py_DECREF(attrib_str); if (!attrib) return -1; - if (kwds) - PyDict_Update(attrib, kwds); + PyDict_Update(attrib, kwds); } else { Py_INCREF(Py_None); attrib = Py_None; @@ -881,13 +901,15 @@ } static PyObject* -element_find(ElementObject* self, PyObject* args) +element_find(ElementObject* self, PyObject* args, PyObject *kwds) { int i; PyObject* tag; PyObject* namespaces = Py_None; - - if (!PyArg_ParseTuple(args, "O|O:find", &tag, &namespaces)) + static char *kwlist[] = {"path", "namespaces", 0}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:find", kwlist, + &tag, &namespaces)) return NULL; if (checkpath(tag) || namespaces != Py_None) { @@ -913,15 +935,17 @@ } static PyObject* -element_findtext(ElementObject* self, PyObject* args) +element_findtext(ElementObject* self, PyObject* args, PyObject *kwds) { int i; PyObject* tag; PyObject* default_value = Py_None; PyObject* namespaces = Py_None; _Py_IDENTIFIER(findtext); - - if (!PyArg_ParseTuple(args, "O|OO:findtext", &tag, &default_value, &namespaces)) + static char *kwlist[] = {"path", "default", "namespaces", 0}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|OO:findtext", kwlist, + &tag, &default_value, &namespaces)) return NULL; if (checkpath(tag) || namespaces != Py_None) @@ -951,14 +975,16 @@ } static PyObject* -element_findall(ElementObject* self, PyObject* args) +element_findall(ElementObject* self, PyObject* args, PyObject *kwds) { int i; PyObject* out; PyObject* tag; PyObject* namespaces = Py_None; - - if (!PyArg_ParseTuple(args, "O|O:findall", &tag, &namespaces)) + static char *kwlist[] = {"path", "namespaces", 0}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:findall", kwlist, + &tag, &namespaces)) return NULL; if (checkpath(tag) || namespaces != Py_None) { @@ -990,13 +1016,15 @@ } static PyObject* -element_iterfind(ElementObject* self, PyObject* args) +element_iterfind(ElementObject* self, PyObject* args, PyObject *kwds) { PyObject* tag; PyObject* namespaces = Py_None; _Py_IDENTIFIER(iterfind); - - if (!PyArg_ParseTuple(args, "O|O:iterfind", &tag, &namespaces)) + static char *kwlist[] = {"path", "namespaces", 0}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:iterfind", kwlist, + &tag, &namespaces)) return NULL; return _PyObject_CallMethodId( @@ -1567,9 +1595,9 @@ {"get", (PyCFunction) element_get, METH_VARARGS}, {"set", (PyCFunction) element_set, METH_VARARGS}, - {"find", (PyCFunction) element_find, METH_VARARGS}, - {"findtext", (PyCFunction) element_findtext, METH_VARARGS}, - {"findall", (PyCFunction) element_findall, METH_VARARGS}, + {"find", (PyCFunction) element_find, METH_VARARGS | METH_KEYWORDS}, + {"findtext", (PyCFunction) element_findtext, METH_VARARGS | METH_KEYWORDS}, + {"findall", (PyCFunction) element_findall, METH_VARARGS | METH_KEYWORDS}, {"append", (PyCFunction) element_append, METH_VARARGS}, {"extend", (PyCFunction) element_extend, METH_VARARGS}, @@ -1578,7 +1606,7 @@ {"iter", (PyCFunction) element_iter, METH_VARARGS}, {"itertext", (PyCFunction) element_itertext, METH_VARARGS}, - {"iterfind", (PyCFunction) element_iterfind, METH_VARARGS}, + {"iterfind", (PyCFunction) element_iterfind, METH_VARARGS | METH_KEYWORDS}, {"getiterator", (PyCFunction) element_iter, METH_VARARGS}, {"getchildren", (PyCFunction) element_getchildren, METH_VARARGS},