From e4365935397d5f85d98b0c17278740e53992ba74 Mon Sep 17 00:00:00 2001 From: Tim Mitchell Date: Mon, 12 Sep 2016 16:25:22 +1200 Subject: [PATCH] Issue #27945: Fixed segfaults in dict.fromkeys when iterable modifies itself. * insertdict and dict_equal now keep a reference to key as dk_lookup may call arbitrary python code and free key. * _PyDict_FromKeys now checks the iterable does not change whilst iterating when input is pure dict or set. --- Lib/test/test_dict.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ Objects/dictobject.c | 27 +++++++++++++++++++++++---- 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index fb954c8..bfbb981 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -1023,6 +1023,55 @@ class DictTest(unittest.TestCase): d = {X(): 0, 1: 1} self.assertRaises(RuntimeError, d.update, other) + def test_equal_operator_modifying_operand(self): + # test fix for seg fault reported in issue 27945 part 3. + class X(): + def __del__(self): + dict_b.clear() + + def __eq__(self, other): + dict_a.clear() + return True + + def __hash__(self): + return 13 + + dict_a = {X(): 0} + dict_b = {X(): X()} + self.assertTrue(dict_a == dict_b) + + def test_fromkeys_operator_modifying_dict_operand(self): + # test fix for seg fault reported in issue 27945 part 4a. + class X(int): + def __hash__(self): + return 13 + + def __eq__(self, other): + if len(d) > 1: + d.clear() + return False + + d = {} # this is required to exist so that d can be constructed! + d = {X(1): 1, X(2): 2} + with self.assertRaises(RuntimeError): + dict.fromkeys(d) + + def test_fromkeys_operator_modifying_set_operand(self): + # test fix for seg fault reported in issue 27945 part 4b. + class X(int): + def __hash__(self): + return 13 + + def __eq__(self, other): + if len(d) > 1: + d.clear() + return False + + d = {} # this is required to exist so that d can be constructed! + d = {X(1): 1, X(2): 2} + with self.assertRaises(RuntimeError): + dict.fromkeys(d) + def test_free_after_iterating(self): support.check_free_after_iterating(self, iter, dict) support.check_free_after_iterating(self, lambda d: iter(d.keys()), dict) diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 4bcc3db..bf334f4 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -1038,9 +1038,10 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value) if (insertion_resize(mp) < 0) return -1; } - + Py_INCREF(key); ix = mp->ma_keys->dk_lookup(mp, key, hash, &value_addr, &hashpos); if (ix == DKIX_ERROR) { + Py_DECREF(key); return -1; } @@ -1056,6 +1057,7 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value) (ix == DKIX_EMPTY && mp->ma_used != mp->ma_keys->dk_nentries))) { if (insertion_resize(mp) < 0) { Py_DECREF(value); + Py_DECREF(key); return -1; } find_empty_slot(mp, key, hash, &value_addr, &hashpos); @@ -1068,6 +1070,7 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value) /* Need to resize. */ if (insertion_resize(mp) < 0) { Py_DECREF(value); + Py_DECREF(key); return -1; } find_empty_slot(mp, key, hash, &value_addr, &hashpos); @@ -1090,6 +1093,7 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value) mp->ma_keys->dk_usable--; mp->ma_keys->dk_nentries++; assert(mp->ma_keys->dk_usable >= 0); + Py_DECREF(key); return 0; } @@ -1101,6 +1105,7 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value) mp->ma_version_tag = DICT_NEXT_VERSION(); Py_DECREF(old_value); /* which **CAN** re-enter (see issue #22653) */ + Py_DECREF(key); return 0; } @@ -1110,6 +1115,7 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value) *value_addr = value; mp->ma_used++; mp->ma_version_tag = DICT_NEXT_VERSION(); + Py_DECREF(key); return 0; } @@ -1771,7 +1777,7 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value) if (PyDict_CheckExact(iterable)) { PyDictObject *mp = (PyDictObject *)d; PyObject *oldvalue; - Py_ssize_t pos = 0; + Py_ssize_t pos = 0, it_size = ((PyDictObject *)iterable)->ma_used; PyObject *key; Py_hash_t hash; @@ -1785,12 +1791,18 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value) Py_DECREF(d); return NULL; } + if (((PyDictObject *)iterable)->ma_used != it_size) { + Py_DECREF(d); + PyErr_SetString(PyExc_RuntimeError, + "dictionary changed size during iteration"); + return NULL; + } } return d; } if (PyAnySet_CheckExact(iterable)) { PyDictObject *mp = (PyDictObject *)d; - Py_ssize_t pos = 0; + Py_ssize_t pos = 0, it_size = ((PyDictObject *)iterable)->ma_used; PyObject *key; Py_hash_t hash; @@ -1804,6 +1816,12 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value) Py_DECREF(d); return NULL; } + if (((PyDictObject *)iterable)->ma_used != it_size) { + Py_DECREF(d); + PyErr_SetString(PyExc_RuntimeError, + "set changed size during iteration"); + return NULL; + } } return d; } @@ -2544,8 +2562,8 @@ dict_equal(PyDictObject *a, PyDictObject *b) bval = NULL; else bval = *vaddr; - Py_DECREF(key); if (bval == NULL) { + Py_DECREF(key); Py_DECREF(aval); if (PyErr_Occurred()) return -1; @@ -2553,6 +2571,7 @@ dict_equal(PyDictObject *a, PyDictObject *b) } cmp = PyObject_RichCompareBool(aval, bval, Py_EQ); Py_DECREF(aval); + Py_DECREF(key); if (cmp <= 0) /* error or not equal */ return cmp; } -- 1.9.4.msysgit.2