diff --git a/Include/odictobject.h b/Include/odictobject.h --- a/Include/odictobject.h +++ b/Include/odictobject.h @@ -30,6 +30,8 @@ /* wrappers around PyDict* functions */ #define PyODict_GetItem(od, key) PyDict_GetItem((PyObject *)od, key) +#define PyODict_GetItemWithError(od, key) \ + PyDict_GetItemWithError((PyObject *)od, key) #define PyODict_Contains(od, key) PyDict_Contains((PyObject *)od, key) #define PyODict_Size(od) PyDict_Size((PyObject *)od) #define PyODict_GetItemString(od, key) \ diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -2037,6 +2037,24 @@ del od[colliding] self.assertEqual(list(od.items()), [(key, ...), ('after', ...)]) + def test_issue24347(self): + OrderedDict = self.module.OrderedDict + + class Key: + def __hash__(self): + return randrange(100000) + + od = OrderedDict() + for i in range(100): + key = Key() + od[key] = i + + # These should not crash. + with self.assertRaises(KeyError): + repr(od) + with self.assertRaises(KeyError): + od.copy() + class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): diff --git a/Objects/odictobject.c b/Objects/odictobject.c --- a/Objects/odictobject.c +++ b/Objects/odictobject.c @@ -511,7 +511,7 @@ (node->key) /* borrowed reference */ #define _odictnode_VALUE(node, od) \ - PyODict_GetItem((PyObject *)od, _odictnode_KEY(node)) + PyODict_GetItemWithError((PyObject *)od, _odictnode_KEY(node)) /* If needed we could also have _odictnode_HASH. 
*/ #define _odictnode_PREV(node) (node->prev) #define _odictnode_NEXT(node) (node->next) @@ -1313,10 +1313,14 @@ if (PyODict_CheckExact(od)) { _odict_FOREACH(od, node) { - int res = PyODict_SetItem((PyObject *)od_copy, - _odictnode_KEY(node), - _odictnode_VALUE(node, od)); - if (res != 0) + PyObject *key = _odictnode_KEY(node); + PyObject *value = _odictnode_VALUE(node, od); + if (value == NULL) { + if (!PyErr_Occurred()) + PyErr_SetObject(PyExc_KeyError, key); + goto fail; + } + if (PyODict_SetItem((PyObject *)od_copy, key, value) != 0) goto fail; } } @@ -1538,7 +1542,6 @@ Py_ssize_t count = -1; PyObject *pieces = NULL, *result = NULL, *cls = NULL; PyObject *classname = NULL, *format = NULL, *args = NULL; - _ODictNode *node; i = Py_ReprEnter((PyObject *)self); if (i != 0) { @@ -1551,13 +1554,21 @@ } if (PyODict_CheckExact(self)) { + _ODictNode *node; pieces = PyList_New(PyODict_SIZE(self)); if (pieces == NULL) goto Done; _odict_FOREACH(self, node) { - PyObject *pair = PyTuple_Pack(2, _odictnode_KEY(node), - _odictnode_VALUE(node, self)); + PyObject *pair; + PyObject *key = _odictnode_KEY(node); + PyObject *value = _odictnode_VALUE(node, self); + if (value == NULL) { + if (!PyErr_Occurred()) + PyErr_SetObject(PyExc_KeyError, key); + goto Done; /* use the same cleanup exit as the other error paths in this loop */ + } + pair = PyTuple_Pack(2, key, value); if (pair == NULL) goto Done; @@ -1813,7 +1824,7 @@ odictiter_dealloc(odictiterobject *di) { _PyObject_GC_UNTRACK(di); - Py_DECREF(di->di_odict); + Py_XDECREF(di->di_odict); Py_XDECREF(di->di_current); if (di->kind & (_odict_ITER_KEYS | _odict_ITER_VALUES)) { Py_DECREF(di->di_result); @@ -1830,16 +1841,21 @@ return 0; } +/* In order to protect against modifications during iteration, we track + * the current key instead of the current node. 
*/ static PyObject * odictiter_nextkey(odictiterobject *di) { - PyObject *key; + PyObject *key = NULL; _ODictNode *node; int reversed = di->kind & _odict_ITER_REVERSED; + if (di->di_odict == NULL) + return NULL; + if (di->di_current == NULL) + goto done; /* We're already done. */ + /* Get the key. */ - if (di->di_current == NULL) - return NULL; node = _odict_find_node(di->di_odict, di->di_current); if (node == NULL) { /* Must have been deleted. */ @@ -1860,6 +1876,10 @@ } return key; + +done: + Py_CLEAR(di->di_odict); + return key; } static PyObject * @@ -1882,8 +1902,10 @@ value = PyODict_GetItem((PyObject *)di->di_odict, key); /* borrowed */ if (value == NULL) { + if (!PyErr_Occurred()) + PyErr_SetObject(PyExc_KeyError, key); Py_DECREF(key); - return NULL; + goto done; } Py_INCREF(value); @@ -1899,7 +1921,7 @@ if (result == NULL) { Py_DECREF(key); Py_DECREF(value); - return NULL; + goto done; } } @@ -1911,10 +1933,21 @@ /* Handle the values case. */ else { value = PyODict_GetItem((PyObject *)di->di_odict, key); - Py_XINCREF(value); + if (value == NULL) { + if (!PyErr_Occurred()) + PyErr_SetObject(PyExc_KeyError, key); + Py_DECREF(key); /* release key only after it has been used for the KeyError */ + goto done; + } + Py_INCREF(value); Py_DECREF(key); return value; } + +done: + Py_CLEAR(di->di_current); + Py_CLEAR(di->di_odict); + return NULL; } /* No need for tp_clear because odictiterobject is not mutable. */