From 84eb3e24e80a9e7ec468aaae046d4dde4534d277 Mon Sep 17 00:00:00 2001
From: Hristo Venev
Date: Thu, 28 Jan 2021 20:53:35 +0200
Subject: [PATCH] Don't downgrade unicode-only dicts to mixed on non-unicode lookups

---
Reviewer note (text in this region is not part of the commit message):
lookdict_unicode() and lookdict_unicode_nodummy() no longer switch
dk_lookup to lookdict when they are handed a key that is not an exact
str; they only delegate that one call to lookdict().  The switch is now
done by MAYBE_NONUNICODE(), invoked in insertdict() and
PyDict_SetDefault() right before a new slot is filled with such a key.

 Lib/test/test_dict.py | 77 +++++++++++++++++++++++++++++++++++++++++++
 Objects/dictobject.c  | 11 +++++--
 2 files changed, 85 insertions(+), 3 deletions(-)

diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py
index 6b8596fff6..b662edb782 100644
--- a/Lib/test/test_dict.py
+++ b/Lib/test/test_dict.py
@@ -1422,6 +1422,83 @@ class DictTest(unittest.TestCase):
         d = CustomReversedDict(pairs)
         self.assertEqual(pairs[::-1], list(dict(d).items()))
 
+    def test_str_nonstr(self):
+        # cpython uses a different lookup function if the dict only contains
+        # `str` keys. Make sure the unoptimized path is used when a non-`str`
+        # key appears.
+
+        class StrSub(str):
+            pass
+
+        eq_count = 0
+        # This class compares equal to the string 'key3'
+        class Key3:
+            def __hash__(self):
+                return hash('key3')
+
+            def __eq__(self, other):
+                nonlocal eq_count
+                if isinstance(other, Key3) or other == 'key3':
+                    eq_count += 1
+                    return True
+                return False
+
+        key3_1 = StrSub('key3')
+        key3_2 = Key3()
+        key3_3 = Key3()
+
+        dicts = []
+
+        for key3 in (key3_1, key3_2):
+            d = {'key1': 42, 'key2': 43}
+            d[key3] = 44
+            dicts.append(d)
+
+            d = {'key1': 42, 'key2': 43}
+            self.assertEqual(d.setdefault(key3, 44), 44)
+            dicts.append(d)
+
+            d = {'key1': 42, 'key2': 43}
+            d.update({key3: 44})
+            dicts.append(d)
+
+            d = {'key1': 42, 'key2': 43}
+            d |= {key3: 44}
+            dicts.append(d)
+
+            def make_pairs():
+                yield ('key1', 42)
+                yield ('key2', 43)
+                yield (key3, 44)
+            d = dict(make_pairs())
+            dicts.append(d)
+
+            d = d.copy()
+            dicts.append(d)
+
+            d = {key: 42 + i for i,key in enumerate(['key1', 'key2', key3])}
+            dicts.append(d)
+
+        for d in dicts:
+            self.assertEqual(d.get('key1'), 42)
+
+            noninterned_key1 = 'ke'
+            noninterned_key1 += 'y1'
+            if support.check_impl_detail(cpython=True):
+                # suppress a SyntaxWarning
+                interned_key1 = 'key1'
+                self.assertFalse(noninterned_key1 is interned_key1)
+            self.assertEqual(d.get(noninterned_key1), 42)
+
+            self.assertEqual(d.get('key3'), 44)
+            self.assertEqual(d.get(key3_1), 44)
+            self.assertEqual(d.get(key3_2), 44)
+
+            # make sure __eq__ was called
+            eq_count = 0
+            self.assertEqual(d.get(key3_3), 44)
+            self.assertGreaterEqual(eq_count, 1)
+
 
 class CAPITest(unittest.TestCase):
 
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index faee6bc901..96a1bf5e9b 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -438,6 +438,12 @@ dictkeys_set_index(PyDictKeysObject *keys, Py_ssize_t i, Py_ssize_t ix)
         (d)->ma_keys->dk_lookup = lookdict_unicode; \
     }
 
+#define MAYBE_NONUNICODE(d, key) \
+    if (!PyUnicode_CheckExact(key) && (d)->ma_keys->dk_lookup != lookdict) { \
+        assert((d)->ma_keys->dk_lookup == lookdict_unicode || (d)->ma_keys->dk_lookup == lookdict_unicode_nodummy); \
+        (d)->ma_keys->dk_lookup = lookdict; \
+    }
+
 /* This immutable, empty PyDictKeysObject is used for PyDict_Clear()
  * (which cannot fail and thus can do no allocation).
  */
@@ -835,7 +841,6 @@ lookdict_unicode(PyDictObject *mp, PyObject *key,
        unicodes is to override __eq__, and for speed we don't cater to
        that here. */
     if (!PyUnicode_CheckExact(key)) {
-        mp->ma_keys->dk_lookup = lookdict;
         return lookdict(mp, key, hash, value_addr);
     }
 
@@ -878,7 +883,6 @@ lookdict_unicode_nodummy(PyDictObject *mp, PyObject *key,
        unicodes is to override __eq__, and for speed we don't cater to
        that here. */
     if (!PyUnicode_CheckExact(key)) {
-        mp->ma_keys->dk_lookup = lookdict;
         return lookdict(mp, key, hash, value_addr);
     }
 
@@ -1062,7 +1066,6 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value)
     if (ix == DKIX_ERROR)
         goto Fail;
 
-    assert(PyUnicode_CheckExact(key) || mp->ma_keys->dk_lookup == lookdict);
     MAINTAIN_TRACKING(mp, key, value);
 
     /* When insertion order is different from shared key, we can't share
@@ -1084,6 +1087,7 @@ insertdict(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject *value)
             if (insertion_resize(mp) < 0)
                 goto Fail;
         }
+        MAYBE_NONUNICODE(mp, key);
         Py_ssize_t hashpos = find_empty_slot(mp->ma_keys, hash);
         ep = &DK_ENTRIES(mp->ma_keys)[mp->ma_keys->dk_nentries];
         dictkeys_set_index(mp->ma_keys, hashpos, mp->ma_keys->dk_nentries);
@@ -2965,6 +2969,7 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
                 return NULL;
             }
         }
+        MAYBE_NONUNICODE(mp, key);
         Py_ssize_t hashpos = find_empty_slot(mp->ma_keys, hash);
         ep0 = DK_ENTRIES(mp->ma_keys);
         ep = &ep0[mp->ma_keys->dk_nentries];
-- 
2.29.2