Index: Include/unicodeobject.h =================================================================== --- Include/unicodeobject.h (Revision 59441) +++ Include/unicodeobject.h (Arbeitskopie) @@ -1330,6 +1330,11 @@ const char *right ); +PyAPI_FUNC(int) _PyUnicode_Eq( + PyObject *left, /* Left string */ + PyObject *right /* Right string */ + ); + /* Rich compare two strings and return one of the following: - NULL in case an exception was raised Index: Include/object.h =================================================================== --- Include/object.h (Revision 59441) +++ Include/object.h (Arbeitskopie) @@ -473,6 +473,8 @@ /* Helpers for hash functions */ PyAPI_FUNC(long) _Py_HashDouble(double); PyAPI_FUNC(long) _Py_HashPointer(void*); +/* Optimized version for setobject.c and dictobject.c */ +PyAPI_FUNC(long) _PyObject_HashFast(PyObject *); /* Helper for passing objects to printf and the like */ #define PyObject_REPR(obj) PyUnicode_AsString(PyObject_Repr(obj)) Index: Objects/object.c =================================================================== --- Objects/object.c (Revision 59441) +++ Objects/object.c (Arbeitskopie) @@ -754,7 +754,22 @@ #endif } +long +_PyObject_HashFast(register PyObject *v) +{ + register long hash = -1; + if (PyUnicode_CheckExact(v)) + hash = ((PyUnicodeObject *) v)->hash; + else if (PyString_CheckExact(v)) + hash = ((PyStringObject *) v)->ob_shash; + + if (hash == -1) + return PyObject_Hash(v); + else + return hash; +} + long PyObject_Hash(PyObject *v) { Index: Objects/dictobject.c =================================================================== --- Objects/dictobject.c (Revision 59441) +++ Objects/dictobject.c (Arbeitskopie) @@ -327,32 +327,13 @@ return 0; } -/* Return 1 if two unicode objects are equal, 0 if not. */ -static int -unicode_eq(PyObject *aa, PyObject *bb) -{ - PyUnicodeObject *a = (PyUnicodeObject *)aa; - PyUnicodeObject *b = (PyUnicodeObject *)bb; - - if (a->length != b->length) - return 0; - if (a->length == 0) - return 1; - if (a->str[0] != b->str[0]) - return 0; - if (a->length == 1) - return 1; - return memcmp(a->str, b->str, a->length * sizeof(Py_UNICODE)) == 0; -} - - /* * Hacked up version of lookdict which can assume keys are always * unicodes; this assumption allows testing for errors during * PyObject_RichCompareBool() to be dropped; unicode-unicode * comparisons never raise exceptions. This also means we don't need * to go through PyObject_RichCompareBool(); we can always use - * unicode_eq() directly. + * _PyUnicode_Eq() directly. * * This is valuable because dicts with only unicode keys are very common. */ @@ -384,7 +365,7 @@ if (ep->me_key == dummy) freeslot = ep; else { - if (ep->me_hash == hash && unicode_eq(ep->me_key, key)) + if (ep->me_hash == hash && _PyUnicode_Eq(ep->me_key, key)) return ep; freeslot = NULL; } @@ -399,7 +380,7 @@ if (ep->me_key == key || (ep->me_hash == hash && ep->me_key != dummy - && unicode_eq(ep->me_key, key))) + && _PyUnicode_Eq(ep->me_key, key))) return ep; if (ep->me_key == dummy && freeslot == NULL) freeslot = ep; @@ -588,16 +569,11 @@ PyThreadState *tstate; if (!PyDict_Check(op)) return NULL; - if (!PyUnicode_CheckExact(key) || - (hash = ((PyUnicodeObject *) key)->hash) == -1) - { - hash = PyObject_Hash(key); - if (hash == -1) { - PyErr_Clear(); - return NULL; - } + hash = _PyObject_HashFast(key); + if (hash == -1) { + PyErr_Clear(); + return NULL; } - /* We can arrive here with a NULL tstate during initialization: try running "python -Wi" for an example related to string interning. Let's just hope that no exception occurs then... */ @@ -637,15 +613,10 @@ PyErr_BadInternalCall(); return NULL; } - if (!PyUnicode_CheckExact(key) || - (hash = ((PyUnicodeObject *) key)->hash) == -1) - { - hash = PyObject_Hash(key); - if (hash == -1) { - return NULL; - } + hash = _PyObject_HashFast(key); + if (hash == -1) { + return NULL; } - ep = (mp->ma_lookup)(mp, key, hash); if (ep == NULL) return NULL; @@ -672,12 +643,9 @@ assert(key); assert(value); mp = (PyDictObject *)op; - if (!PyUnicode_CheckExact(key) || - (hash = ((PyUnicodeObject *) key)->hash) == -1) - { - hash = PyObject_Hash(key); - if (hash == -1) - return -1; + hash = _PyObject_HashFast(key); + if (hash == -1) { + return -1; } assert(mp->ma_fill <= mp->ma_mask); /* at least one empty slot */ n_used = mp->ma_used; @@ -717,11 +685,9 @@ return -1; } assert(key); - if (!PyUnicode_CheckExact(key) || - (hash = ((PyUnicodeObject *) key)->hash) == -1) { - hash = PyObject_Hash(key); - if (hash == -1) - return -1; + hash = _PyObject_HashFast(key); + if (hash == -1) { + return -1; } mp = (PyDictObject *)op; ep = (mp->ma_lookup)(mp, key, hash); @@ -997,11 +963,9 @@ long hash; PyDictEntry *ep; assert(mp->ma_table != NULL); - if (!PyUnicode_CheckExact(key) || - (hash = ((PyUnicodeObject *) key)->hash) == -1) { - hash = PyObject_Hash(key); - if (hash == -1) - return NULL; + hash = _PyObject_HashFast(key); + if (hash == -1) { + return NULL; } ep = (mp->ma_lookup)(mp, key, hash); if (ep == NULL) @@ -1593,11 +1557,9 @@ long hash; PyDictEntry *ep; - if (!PyUnicode_CheckExact(key) || - (hash = ((PyUnicodeObject *) key)->hash) == -1) { - hash = PyObject_Hash(key); - if (hash == -1) - return NULL; + hash = _PyObject_HashFast(key); + if (hash == -1) { + return NULL; } ep = (mp->ma_lookup)(mp, key, hash); if (ep == NULL) @@ -1617,11 +1579,9 @@ if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &failobj)) return NULL; - if (!PyUnicode_CheckExact(key) || - (hash = ((PyUnicodeObject *) key)->hash) == -1) { - hash = PyObject_Hash(key); - if (hash == -1) - return NULL; + hash = _PyObject_HashFast(key); + if (hash == -1) { + return NULL; } ep = (mp->ma_lookup)(mp, key, hash); if (ep == NULL) @@ -1646,11 +1606,9 @@ if (!PyArg_UnpackTuple(args, "setdefault", 1, 2, &key, &failobj)) return NULL; - if (!PyUnicode_CheckExact(key) || - (hash = ((PyUnicodeObject *) key)->hash) == -1) { - hash = PyObject_Hash(key); - if (hash == -1) - return NULL; + hash = _PyObject_HashFast(key); + if (hash == -1) { + return NULL; } ep = (mp->ma_lookup)(mp, key, hash); if (ep == NULL) @@ -1692,11 +1650,9 @@ "pop(): dictionary is empty"); return NULL; } - if (!PyUnicode_CheckExact(key) || - (hash = ((PyUnicodeObject *) key)->hash) == -1) { - hash = PyObject_Hash(key); - if (hash == -1) - return NULL; + hash = _PyObject_HashFast(key); + if (hash == -1) { + return NULL; } ep = (mp->ma_lookup)(mp, key, hash); if (ep == NULL) @@ -1883,11 +1839,9 @@ PyDictObject *mp = (PyDictObject *)op; PyDictEntry *ep; - if (!PyUnicode_CheckExact(key) || - (hash = ((PyUnicodeObject *) key)->hash) == -1) { - hash = PyObject_Hash(key); - if (hash == -1) - return -1; + hash = _PyObject_HashFast(key); + if (hash == -1) { + return -1; } ep = (mp->ma_lookup)(mp, key, hash); return ep == NULL ? -1 : (ep->me_value != NULL); Index: Objects/unicodeobject.c =================================================================== --- Objects/unicodeobject.c (Revision 59441) +++ Objects/unicodeobject.c (Arbeitskopie) @@ -95,6 +95,9 @@ extern "C" { #endif +/* forward declarations */ +static long unicode_hash(PyUnicodeObject *); + /* This dictionary holds all interned unicode strings. Note that references to strings in this dictionary are *not* counted in the string's ob_refcnt. When the interned string reaches a refcnt of 0 the string deallocation @@ -6349,6 +6352,28 @@ return PyBool_FromLong(result); } +/* optimized version used by dictobject.c and setobject.c + * Return 1 if two unicode objects are equal, 0 if not. */ +int +_PyUnicode_Eq(PyObject *aa, PyObject *bb) +{ + PyUnicodeObject *a = (PyUnicodeObject *)aa; + PyUnicodeObject *b = (PyUnicodeObject *)bb; + + if (a->length != b->length) + return 0; + if (a->length == 0) + return 1; + if (a->str[0] != b->str[0]) + return 0; + if (a->length == 1) + return 1; + if (unicode_hash(a) != unicode_hash(b)) + return 0; + return memcmp(a->str, b->str, a->length * sizeof(Py_UNICODE)) == 0; +} + + int PyUnicode_Contains(PyObject *container, PyObject *element) { Index: Objects/setobject.c =================================================================== --- Objects/setobject.c (Revision 59441) +++ Objects/setobject.c (Arbeitskopie) @@ -144,12 +144,12 @@ } /* - * Hacked up version of set_lookkey which can assume keys are always strings; - * This means we can always use _PyString_Eq directly and not have to check to + * Hacked up version of set_lookkey which can assume keys are always unicode; + * This means we can always use _PyUnicode_Eq directly and not have to check to * see if the comparison altered the table. */ static setentry * -set_lookkey_string(PySetObject *so, PyObject *key, register long hash) +set_lookkey_unicode(PySetObject *so, PyObject *key, register long hash) { register Py_ssize_t i; register size_t perturb; @@ -158,11 +158,11 @@ setentry *table = so->table; register setentry *entry; - /* Make sure this function doesn't have to handle non-string keys, + /* Make sure this function doesn't have to handle non-unicode keys, including subclasses of str; e.g., one reason to subclass strings is to override __eq__, and for speed we don't cater to that here. */ - if (!PyString_CheckExact(key)) { + if (!PyUnicode_CheckExact(key)) { so->lookup = set_lookkey; return set_lookkey(so, key, hash); } @@ -173,7 +173,7 @@ if (entry->key == dummy) freeslot = entry; else { - if (entry->hash == hash && _PyString_Eq(entry->key, key)) + if (entry->hash == hash && _PyUnicode_Eq(entry->key, key)) return entry; freeslot = NULL; } @@ -188,7 +188,7 @@ if (entry->key == key || (entry->hash == hash && entry->key != dummy - && _PyString_Eq(entry->key, key))) + && _PyUnicode_Eq(entry->key, key))) return entry; if (entry->key == dummy && freeslot == NULL) freeslot = entry; @@ -375,11 +375,9 @@ register long hash; register Py_ssize_t n_used; - if (!PyString_CheckExact(key) || - (hash = ((PyStringObject *) key)->ob_shash) == -1) { - hash = PyObject_Hash(key); - if (hash == -1) - return -1; + hash = _PyObject_HashFast(key); + if (hash == -1) { + return -1; } assert(so->fill <= so->mask); /* at least one empty slot */ n_used = so->used; @@ -422,11 +420,10 @@ PyObject *old_key; assert (PyAnySet_Check(so)); - if (!PyString_CheckExact(key) || - (hash = ((PyStringObject *) key)->ob_shash) == -1) { - hash = PyObject_Hash(key); - if (hash == -1) - return -1; + + hash = _PyObject_HashFast(key); + if (hash == -1) { + return -1; } entry = (so->lookup)(so, key, hash); if (entry == NULL) @@ -668,11 +665,9 @@ long hash; setentry *entry; - if (!PyString_CheckExact(key) || - (hash = ((PyStringObject *) key)->ob_shash) == -1) { - hash = PyObject_Hash(key); - if (hash == -1) - return -1; + hash = _PyObject_HashFast(key); + if (hash == -1) { + return -1; } entry = (so->lookup)(so, key, hash); if (entry == NULL) @@ -989,7 +984,7 @@ INIT_NONZERO_SET_SLOTS(so); } - so->lookup = set_lookkey_string; + so->lookup = set_lookkey_unicode; so->weakreflist = NULL; if (iterable != NULL) { Index: Lib/test/test_set.py =================================================================== --- Lib/test/test_set.py (Revision 59441) +++ Lib/test/test_set.py (Arbeitskopie) @@ -7,6 +7,7 @@ import os from random import randrange, shuffle import sys +import warnings class PassThru(Exception): pass @@ -817,6 +818,44 @@ self.length = 3 self.repr = None +#------------------------------------------------------------------------------ + +class TestBasicOpsString(TestBasicOps): + def setUp(self): + self.case = "string set" + self.values = ["a", "b", "c"] + self.set = set(self.values) + self.dup = set(self.values) + self.length = 3 + self.repr = "{'a', 'c', 'b'}" + +#------------------------------------------------------------------------------ + +class TestBasicOpsBytes(TestBasicOps): + def setUp(self): + self.case = "string set" + self.values = [b"a", b"b", b"c"] + self.set = set(self.values) + self.dup = set(self.values) + self.length = 3 + self.repr = "{b'a', b'c', b'b'}" + +#------------------------------------------------------------------------------ + +class TestBasicOpsMixedStringBytes(TestBasicOps): + def setUp(self): + self.warning_filters = warnings.filters[:] + warnings.simplefilter('ignore', BytesWarning) + self.case = "string and bytes set" + self.values = ["a", "b", b"a", b"b"] + self.set = set(self.values) + self.dup = set(self.values) + self.length = 4 + self.repr = "{'a', b'a', 'b', b'b'}" + + def tearDown(self): + warnings.filters = self.warning_filters + #============================================================================== def baditer(): @@ -1581,6 +1620,9 @@ TestBasicOpsSingleton, TestBasicOpsTuple, TestBasicOpsTriple, + TestBasicOpsString, + TestBasicOpsBytes, + TestBasicOpsMixedStringBytes, TestBinaryOps, TestUpdateOps, TestMutate,