Index: Include/unicodeobject.h
===================================================================
--- Include/unicodeobject.h	(Revision 59424)
+++ Include/unicodeobject.h	(Arbeitskopie)
@@ -1330,6 +1330,14 @@
     const char *right
     );
 
+/* Compare two unicode objects for equality and return 1 if they are
+   equal, 0 otherwise.  The arguments are not type-checked; callers must
+   pass unicode objects. */
+PyAPI_FUNC(int) _PyUnicode_Eq(
+    PyObject *left,		/* Left string */
+    PyObject *right		/* Right string */
+    );
+
 /* Rich compare two strings and return one of the following:
 
    - NULL in case an exception was raised
Index: Objects/unicodeobject.c
===================================================================
--- Objects/unicodeobject.c	(Revision 59424)
+++ Objects/unicodeobject.c	(Arbeitskopie)
@@ -95,6 +95,9 @@
 extern "C" {
 #endif
 
+/* forward declarations */
+static long unicode_hash(PyUnicodeObject *);
+
 /* This dictionary holds all interned unicode strings.  Note that references
    to strings in this dictionary are *not* counted in the string's ob_refcnt.
    When the interned string reaches a refcnt of 0 the string deallocation
@@ -6349,6 +6352,23 @@
     return PyBool_FromLong(result);
 }
 
+/* Equality test for two unicode objects; used by the set lookup in
+   Objects/setobject.c.  The arguments are cast without a type check.
+   Returns 1 when both have the same length, the same hash and compare
+   equal, 0 otherwise. */
+int
+_PyUnicode_Eq(PyObject *o1, PyObject *o2)
+{
+    PyUnicodeObject *a = (PyUnicodeObject*) o1;
+    PyUnicodeObject *b = (PyUnicodeObject*) o2;
+    /* Read the length through the unicode accessor; Py_Size() is meant
+       for PyVarObject-based types. */
+    return PyUnicode_GET_SIZE(a) == PyUnicode_GET_SIZE(b)
+      && unicode_hash(a) == unicode_hash(b)
+      && unicode_compare(a, b) == 0;
+}
+
+
 int PyUnicode_Contains(PyObject *container,
 		       PyObject *element)
 {
Index: Objects/setobject.c
===================================================================
--- Objects/setobject.c	(Revision 59424)
+++ Objects/setobject.c	(Arbeitskopie)
@@ -50,6 +50,24 @@
 	INIT_NONZERO_SET_SLOTS(so);				\
     } while(0)
 
+/* Store the hash of `key` in `hash`.  The hash cached on an exact
+   unicode or string object is used when it is set; otherwise
+   PyObject_Hash() is called.  NOTE: on failure this macro executes
+   `return -1` in the enclosing function. */
+#define GET_HASH(hash, key) do {                                \
+    (hash) = -1;                                                \
+    if (PyUnicode_CheckExact(key)) {                            \
+        (hash) = ((PyUnicodeObject *)(key))->hash;              \
+    } else if (PyString_CheckExact(key)) {                      \
+        (hash) = ((PyStringObject *)(key))->ob_shash;           \
+    }                                                           \
+    if ((hash) == -1) {                                         \
+        (hash) = PyObject_Hash(key);                            \
+        if ((hash) == -1)                                       \
+            return -1;                                          \
+    }                                                           \
+    } while(0)
+
 /* Reuse scheme to save calls to malloc, free, and memset */
 #define MAXFREESETS 80
 static PySetObject *free_sets[MAXFREESETS];
@@ -144,12 +162,12 @@
 }
 
 /*
- * Hacked up version of set_lookkey which can assume keys are always strings;
- * This means we can always use _PyString_Eq directly and not have to check to
+ * Hacked up version of set_lookkey which can assume keys are always unicode;
+ * This means we can always use _PyUnicode_Eq directly and not have to check to
  * see if the comparison altered the table.
  */
 static setentry *
-set_lookkey_string(PySetObject *so, PyObject *key, register long hash)
+set_lookkey_unicode(PySetObject *so, PyObject *key, register long hash)
 {
 	register Py_ssize_t i;
 	register size_t perturb;
@@ -158,11 +176,11 @@
 	setentry *table = so->table;
 	register setentry *entry;
 
-	/* Make sure this function doesn't have to handle non-string keys,
+	/* Make sure this function doesn't have to handle non-unicode keys,
 	   including subclasses of str; e.g., one reason to subclass
 	   strings is to override __eq__, and for speed we don't cater to
 	   that here. */
-	if (!PyString_CheckExact(key)) {
+	if (!PyUnicode_CheckExact(key)) {
 		so->lookup = set_lookkey;
 		return set_lookkey(so, key, hash);
 	}
@@ -173,7 +191,7 @@
 		if (entry->key == dummy)
 			freeslot = entry;
 		else {
-			if (entry->hash == hash && _PyString_Eq(entry->key, key))
+			if (entry->hash == hash && _PyUnicode_Eq(entry->key, key))
 				return entry;
 			freeslot = NULL;
 		}
@@ -188,7 +206,7 @@
 		if (entry->key == key
 		    || (entry->hash == hash
 			&& entry->key != dummy
-			&& _PyString_Eq(entry->key, key)))
+			&& _PyUnicode_Eq(entry->key, key)))
 			return entry;
 		if (entry->key == dummy && freeslot == NULL)
 			freeslot = entry;
@@ -375,12 +393,8 @@
 	register long hash;
 	register Py_ssize_t n_used;
 
-	if (!PyString_CheckExact(key) ||
-	    (hash = ((PyStringObject *) key)->ob_shash) == -1) {
-		hash = PyObject_Hash(key);
-		if (hash == -1)
-			return -1;
-	}
+	GET_HASH(hash, key);
+
 	assert(so->fill <= so->mask);  /* at least one empty slot */
 	n_used = so->used;
 	Py_INCREF(key);
@@ -422,12 +436,9 @@
 	PyObject *old_key;
 
 	assert (PyAnySet_Check(so));
-	if (!PyString_CheckExact(key) ||
-	    (hash = ((PyStringObject *) key)->ob_shash) == -1) {
-		hash = PyObject_Hash(key);
-		if (hash == -1)
-			return -1;
-	}
+
+	GET_HASH(hash, key);
+
 	entry = (so->lookup)(so, key, hash);
 	if (entry == NULL)
 		return -1;
@@ -668,12 +679,8 @@
 	long hash;
 	setentry *entry;
 
-	if (!PyString_CheckExact(key) ||
-	    (hash = ((PyStringObject *) key)->ob_shash) == -1) {
-		hash = PyObject_Hash(key);
-		if (hash == -1)
-			return -1;
-	}
+	GET_HASH(hash, key);
+
 	entry = (so->lookup)(so, key, hash);
 	if (entry == NULL)
 		return -1;
@@ -989,7 +996,7 @@
 		INIT_NONZERO_SET_SLOTS(so);
 	}
 
-	so->lookup = set_lookkey_string;
+	so->lookup = set_lookkey_unicode;
 	so->weakreflist = NULL;
 
 	if (iterable != NULL) {
Index: Lib/test/test_set.py
===================================================================
--- Lib/test/test_set.py	(Revision 59424)
+++ Lib/test/test_set.py	(Arbeitskopie)
@@ -7,6 +7,7 @@
 import os
 from random import randrange, shuffle
 import sys
+import warnings
 
 class PassThru(Exception):
     pass
@@ -817,6 +818,47 @@
         self.length = 3
         self.repr = None
 
+#------------------------------------------------------------------------------
+
+# All keys are exact str objects, so lookups stay on set_lookkey_unicode.
+class TestBasicOpsString(TestBasicOps):
+    def setUp(self):
+        self.case = "string set"
+        self.values = ["a", "b", "c"]
+        self.set = set(self.values)
+        self.dup = set(self.values)
+        self.length = 3
+        self.repr = None  # element order depends on the hash values
+
+#------------------------------------------------------------------------------
+
+# bytes keys are not exact str, so the set falls back to set_lookkey.
+class TestBasicOpsBytes(TestBasicOps):
+    def setUp(self):
+        self.case = "bytes set"
+        self.values = [b"a", b"b", b"c"]
+        self.set = set(self.values)
+        self.dup = set(self.values)
+        self.length = 3
+        self.repr = None  # element order depends on the hash values
+
+#------------------------------------------------------------------------------
+
+# A str and the bytes object with the same letter must be distinct members.
+class TestBasicOpsMixedStringBytes(TestBasicOps):
+    def setUp(self):
+        self.warning_filters = warnings.filters[:]
+        warnings.simplefilter('ignore', BytesWarning)
+        self.case = "string and bytes set"
+        self.values = ["a", "b", b"a", b"b"]
+        self.set = set(self.values)
+        self.dup = set(self.values)
+        self.length = 4
+        self.repr = None  # element order depends on the hash values
+
+    def tearDown(self):
+        warnings.filters[:] = self.warning_filters  # restore in place
+
 #==============================================================================
 
 def baditer():
@@ -1581,6 +1623,9 @@
         TestBasicOpsSingleton,
         TestBasicOpsTuple,
         TestBasicOpsTriple,
+        TestBasicOpsString,
+        TestBasicOpsBytes,
+        TestBasicOpsMixedStringBytes,
         TestBinaryOps,
         TestUpdateOps,
         TestMutate,