diff -r 3be60a4c8c63 Include/pyerrors.h --- a/Include/pyerrors.h Fri Jan 20 11:01:06 2012 -0500 +++ b/Include/pyerrors.h Fri Jan 20 17:51:33 2012 -0500 @@ -207,6 +207,8 @@ PyAPI_DATA(PyObject *) PyExc_BytesWarning; PyAPI_DATA(PyObject *) PyExc_ResourceWarning; +PyAPI_DATA(PyObject *) PyExc_TooManyHashCollisions; + /* Convenience functions */ diff -r 3be60a4c8c63 Lib/test/test_dict.py --- a/Lib/test/test_dict.py Fri Jan 20 11:01:06 2012 -0500 +++ b/Lib/test/test_dict.py Fri Jan 20 17:51:33 2012 -0500 @@ -3,6 +3,8 @@ import collections, random, string import gc, weakref +import sys +import time class DictTest(unittest.TestCase): @@ -757,6 +759,192 @@ self._tracked(MyDict()) +# Support classes for HashCollisionTests: +class ChosenHash: + """ + Use this to create arbitrary collections of keys that are non-equal + but have equal hashes, without needing to include hostile data + within the test suite. + """ + def __init__(self, variability, hash): + self.variability = variability + self.hash = hash + + def __eq__(self, other): + # The variability field is used to handle non-equalness: + return self.variability == other.variability + + def __hash__(self): + return self.hash + + def __repr__(self): + return 'ChosenHash(%r, %r)' % (self.variability, + self.hash) + +class Timer: + """ + Simple way to measure time elapsed during a test case + """ + def __init__(self): + self.starttime = time.time() + + def get_elapsed_time(self): + """Get elapsed time in seconds as a float""" + curtime = time.time() + return curtime - self.starttime + + def elapsed_time_as_str(self): + """Get elapsed time as a string (with units)""" + return '%0.3f seconds' % self.get_elapsed_time() + +class TookTooLong(RuntimeError): + def __init__(self, timelimit, elapsed, itercount=None): + self.timelimit = timelimit + self.elapsed = elapsed + self.itercount = itercount + + def __str__(self): + result = 'took >= %s seconds' % self.timelimit + if self.itercount is not None: + result += (' (%0.3f seconds elapsed after %i iterations)' + % (self.elapsed, self.itercount)) + else: + result += ' (%0.3f seconds elapsed)' % self.elapsed + return result + +# Some of the tests involve constructing large dictionaries. How big +# should they be? +ITEM_COUNT = 1000000 + +# Arbitrary threshold (in seconds) for a "reasonable amount of time" +# that it should take to work with ITEM_COUNT items: +TIME_LIMIT = 5 + +class _FasterThanContext(object): + """ + A context manager for implementing assertFasterThan + """ + def __init__(self, test_case, **kwargs): + self.test_case = test_case + if 'seconds' in kwargs: + self.timelimit = kwargs['seconds'] + else: + raise ValueError() + + def __enter__(self): + self.timer = Timer() + return self + + def __exit__(self, exc_type, exc_value, tb): + if exc_type is not None: + # let unexpected exceptions pass through + return + + if 1: + print('timer within %s took %s' + % (self.test_case, self.timer.elapsed_time_as_str())) + + def handle(self, callable_obj, args, kwargs): + """ + If callable_obj is None, assertRaises/Warns is being used as a + context manager, so check for a 'msg' kwarg and return self. + If callable_obj is not None, call it passing args and kwargs. + """ + if callable_obj is None: + self.msg = kwargs.pop('msg', None) + return self + with self: + callable_obj(*args, **kwargs) + + def check_for_timeout(self, itercount): + """ + Allow directly checking for timeouts from within a loop, supplying + an iteration count. If the timer has elapsed, this will raise a + TookTooLong exception, indicating how many iterations were completed + when the time limit was reached. Otherwise, it does nothing. + """ + elapsed_time = self.timer.get_elapsed_time() + if elapsed_time > self.timelimit: + raise TookTooLong(self.timelimit, + elapsed_time, + itercount) + +@support.cpython_only +class HashCollisionTests(unittest.TestCase): + """ + Issue 13703: tests about the behavior of dicts in the face of hostile data + """ + + def assertFasterThan(self, callableObj=None, *args, **kwargs): + context = _FasterThanContext(self, *args, **kwargs) + return context.handle(callableObj, args, kwargs) + + def test_timings_with_benign_data(self): + # Verify that inserting many keys into a dict only takes a few seconds + d = dict() + with self.assertFasterThan(seconds=TIME_LIMIT) as cm: + for i in range(ITEM_COUNT): + d[i] = 0 + + # Verify that we can also retrieve the values quickly: + with self.assertFasterThan(seconds=TIME_LIMIT) as cm: + d[i] + + # Verify that we can quickly insert the same item many times + # (overwriting each time): + d = dict() + with self.assertFasterThan(seconds=TIME_LIMIT) as cm: + for i in range(ITEM_COUNT): + d[0] = 0 + + def test_not_reaching_limit(self): + # Ensure that we can insert equal-hashed keys up to (but not reaching) + # the collision climit: + with self.assertFasterThan(seconds=TIME_LIMIT) as cm: + d = dict() + for i in range(sys.max_dict_collisions - 1): + key = ChosenHash(i, 42) + d[key] = 0 + + def test_reaching_collision_limit(self): + """ + Ensure that too many non-equal keys with the same hash lead to a + TooManyCollisions exception + """ + with self.assertFasterThan(seconds=TIME_LIMIT) as cm: + with self.assertRaisesRegex(TooManyHashCollisions, + ('1001 hash collisions within dict at' + ' key ChosenHash\(999, 42\)' + ' with hash 42')): + d = dict() + for i in range(sys.max_dict_collisions): + key = ChosenHash(i, 42) + d[key] = 0 + + # Frank Sievertsen found scenarios in which the collision-counting + # scheme could be attacked: + # http://mail.python.org/pipermail/python-dev/2012-January/115726.html + + def test_scenario_b_from_Frank_Sievertsen(self): + d = dict() + + # Insert hash collisions up to (but not reaching) the limit: + with self.assertFasterThan(seconds=TIME_LIMIT) as cm: + for i in range(sys.max_dict_collisions -1 ): + key = ChosenHash(i, 42) + d[key] = 0 + + # Now try to add many equal values that collide + # with the hash, and see how long it takes + with self.assertFasterThan(seconds=TIME_LIMIT) as cm: + for i in range(ITEM_COUNT): + key = ChosenHash(0, 42) + d[key] = 0 + cm.check_for_timeout(i) + # FIXME: currently this reproduces the 2nd issue described in his + # post, by failing for me here ^^^ with a message like this: + # test.test_dict.TookTooLong: took >= 5 seconds (5.000 seconds elapsed after 18838 iterations) + from test import mapping_tests class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): @@ -771,6 +959,7 @@ def test_main(): support.run_unittest( DictTest, + HashCollisionTests, GeneralMappingTests, SubclassMappingTests, ) diff -r 3be60a4c8c63 Objects/dictobject.c --- a/Objects/dictobject.c Fri Jan 20 11:01:06 2012 -0500 +++ b/Objects/dictobject.c Fri Jan 20 17:51:33 2012 -0500 @@ -10,6 +10,8 @@ #include "Python.h" #include "stringlib/eq.h" +/* Maximum number of allowed hash collisions. */ +#define Py_MAX_DICT_COLLISIONS 1000 /* Set a key error with the specified argument, wrapping it in a * tuple automatically so that tuple keys are not unpacked as the @@ -25,6 +27,21 @@ Py_DECREF(tup); } +/* Set a TooManyHashCollisions error */ +static void +set_too_many_collisions_error(size_t collisions, PyObject *key, Py_hash_t hash) +{ + PyErr_Format(PyExc_TooManyHashCollisions, + ("%i hash collisions within dict" + " at key %R with hash %zd"), + collisions, key, hash); + /* Backporting notes: (FIXME) + %R is a Python 3-ism + %zd is for Py_ssize_t, which in Python 3 is the same as Py_hash_t + */ +} + + /* Define this out if you don't want conversion statistics on exit. */ #undef SHOW_CONVERSION_COUNTS @@ -326,6 +343,7 @@ register PyDictEntry *ep; register int cmp; PyObject *startkey; + size_t collisions; i = (size_t)hash & mask; ep = &ep0[i]; @@ -360,6 +378,7 @@ /* In the loop, me_key == dummy is by far (factor of 100s) the least likely outcome, so test for that last. */ + collisions = 1; for (perturb = hash; ; perturb >>= PERTURB_SHIFT) { i = (i << 2) + i + perturb + 1; ep = &ep0[i & mask]; @@ -386,6 +405,10 @@ */ return lookdict(mp, key, hash); } + if (++collisions > Py_MAX_DICT_COLLISIONS) { + set_too_many_collisions_error(collisions, key, hash); + return NULL; + } } else if (ep->me_key == dummy && freeslot == NULL) freeslot = ep; @@ -413,6 +436,7 @@ register size_t mask = (size_t)mp->ma_mask; PyDictEntry *ep0 = mp->ma_table; register PyDictEntry *ep; + size_t collisions; /* Make sure this function doesn't have to handle non-unicode keys, including subclasses of str; e.g., one reason to subclass @@ -439,6 +463,7 @@ /* In the loop, me_key == dummy is by far (factor of 100s) the least likely outcome, so test for that last. */ + collisions = 1; for (perturb = hash; ; perturb >>= PERTURB_SHIFT) { i = (i << 2) + i + perturb + 1; ep = &ep0[i & mask]; @@ -451,6 +476,10 @@ return ep; if (ep->me_key == dummy && freeslot == NULL) freeslot = ep; + if (++collisions > Py_MAX_DICT_COLLISIONS) { + set_too_many_collisions_error(collisions, key, hash); + return NULL; + } } assert(0); /* NOT REACHED */ return 0; diff -r 3be60a4c8c63 Objects/exceptions.c --- a/Objects/exceptions.c Fri Jan 20 11:01:06 2012 -0500 +++ b/Objects/exceptions.c Fri Jan 20 17:51:33 2012 -0500 @@ -2205,6 +2205,12 @@ SimpleExtendsException(PyExc_Warning, ResourceWarning, "Base class for warnings about resource usage."); +/* + * TooManyHashCollisions extends BaseException + */ +SimpleExtendsException(PyExc_BaseException, TooManyHashCollisions, + "Base class for warnings about computationally-infeasible data."); + /* Pre-computed RuntimeError instance for when recursion depth is reached. @@ -2318,6 +2324,7 @@ PRE_INIT(UnicodeWarning) PRE_INIT(BytesWarning) PRE_INIT(ResourceWarning) + PRE_INIT(TooManyHashCollisions) /* OSError subclasses */ PRE_INIT(ConnectionError); @@ -2399,6 +2406,7 @@ POST_INIT(UnicodeWarning) POST_INIT(BytesWarning) POST_INIT(ResourceWarning) + POST_INIT(TooManyHashCollisions) if (!errnomap) { errnomap = PyDict_New(); diff -r 3be60a4c8c63 Python/sysmodule.c --- a/Python/sysmodule.c Fri Jan 20 11:01:06 2012 -0500 +++ b/Python/sysmodule.c Fri Jan 20 17:51:33 2012 -0500 @@ -1619,6 +1619,10 @@ SET_SYS_FROM_STRING("thread_info", PyThread_GetInfo()); #endif + SET_SYS_FROM_STRING("max_dict_collisions", + PyLong_FromLong(1000)); // FIXME + + #undef SET_SYS_FROM_STRING if (PyErr_Occurred()) return NULL;