Index: Objects/setobject.c =================================================================== --- Objects/setobject.c (revision 64056) +++ Objects/setobject.c (working copy) @@ -955,15 +955,23 @@ } static PyObject * -set_update(PySetObject *so, PyObject *other) +set_update(PySetObject *so, PyObject *others) { - if (set_update_internal(so, other) == -1) - return NULL; + Py_ssize_t i; + Py_ssize_t nothers; + + nothers = PyTuple_GET_SIZE(others); + for (i = 0; i < nothers; i++) { + PyObject *other = PyTuple_GET_ITEM(others, i); + if (set_update_internal(so, other) == -1) { + return NULL; + } + } Py_RETURN_NONE; } PyDoc_STRVAR(update_doc, -"Update a set with the union of itself and another."); +"Update a set with the union of itself and any number of iterables."); static PyObject * make_new_set(PyTypeObject *type, PyObject *iterable) @@ -1144,35 +1152,46 @@ PyDoc_STRVAR(clear_doc, "Remove all elements from this set."); static PyObject * -set_union(PySetObject *so, PyObject *other) +set_union(PySetObject *so, PyObject *others) { PySetObject *result; - + PyObject *arg; + Py_ssize_t nothers; + Py_ssize_t i; + result = (PySetObject *)set_copy(so); if (result == NULL) return NULL; - if ((PyObject *)so == other) - return (PyObject *)result; - if (set_update_internal(result, other) == -1) { - Py_DECREF(result); - return NULL; + nothers = PyTuple_GET_SIZE(others); + for (i = 0; i < nothers; i++) { + arg = PyTuple_GET_ITEM(others, i); + if ((PyObject *)so != arg && set_update_internal(result, arg) == -1) { + Py_DECREF(result); + return NULL; + } } return (PyObject *)result; } PyDoc_STRVAR(union_doc, - "Return the union of two sets as a new set.\n\ + "Return the union of any number of sets as a new set.\n\ \n\ -(i.e. all elements that are in either set.)"); +(i.e. all elements that are at least one set.)"); static PyObject * set_or(PySetObject *so, PyObject *other) { + PySetObject *result; if (!PyAnySet_Check(so) || !PyAnySet_Check(other)) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } - return set_union(so, other); + result = (PySetObject *)set_copy(so); + if (result == NULL) + return NULL; + if (set_update_internal(result, other) == -1) + return NULL; + return (PyObject *)result; } static PyObject * @@ -1189,7 +1208,7 @@ } static PyObject * -set_intersection(PySetObject *so, PyObject *other) +set_intersection2(PySetObject *so, PyObject *other) { PySetObject *result; PyObject *key, *it, *tmp; @@ -1271,17 +1290,190 @@ return (PyObject *)result; } +static PyObject * +set_intersection2_update(PySetObject *so, PyObject *other) +{ + PyObject *tmp; + + tmp = set_intersection2(so, other); + if (tmp == NULL) + return NULL; + set_swap_bodies(so, (PySetObject *)tmp); + Py_DECREF(tmp); + Py_RETURN_NONE; +} + +static PyObject * +set_intersection(PySetObject *so, PyObject *others) +{ + PySetObject *result; + PyObject *iterables = NULL; + Py_ssize_t size, niterables, i, j; + Py_ssize_t nhashtables; + PyObject *other; + + size = PyTuple_GET_SIZE(others); + if (size == 0) { + return set_copy(so); + } + + if (size == 1) { + return set_intersection2(so, PyTuple_GET_ITEM(others, 0)); + } + + /* Count how many items are sets, frozensets or dicts */ + nhashtables = 1; + for (i = 0; i < size; i++) { + other = PyTuple_GET_ITEM(others, i); + if (PyAnySet_Check(other) || PyDict_Check(other)) { + nhashtables++; + } + } + niterables = size+1 - nhashtables; + + if (nhashtables > 1) { + PyObject *sort_tuple, *small; + int is_set; + Py_ssize_t pos = 0; + PyObject *hashtables = PyList_New(nhashtables); + + /* Put all hashtables in a decorated list and sort them by + * size. Elements of the list are of the form: + * (len(S), index, S) + * where all indices are distinct. + * index is needed because comparing two sets of equal size + * throws an exception now. */ + if (hashtables == NULL) + return NULL; + if (niterables) { + iterables = PyTuple_New(niterables); + if (iterables == NULL) { + Py_DECREF(hashtables); + return NULL; + } + } + sort_tuple = Py_BuildValue("niO", PySet_GET_SIZE(so), 0, so); + if (sort_tuple == NULL) { + Py_DECREF(hashtables); + Py_XDECREF(iterables); + return NULL; + } + PyList_SET_ITEM(hashtables, 0, sort_tuple); + for (i = 0, j = 1; i < size; i++) { + other = PyTuple_GET_ITEM(others, i); + if (PyAnySet_Check(other) || PyDict_Check(other)) { + sort_tuple = Py_BuildValue("niO", PyObject_Size(other), j, other); + if (sort_tuple == NULL) { + Py_DECREF(hashtables); + Py_XDECREF(iterables); + return NULL; + } + PyList_SET_ITEM(hashtables, j++, sort_tuple); + } + else if (niterables) { + PyTuple_SET_ITEM(iterables, i+1 - j, other); + Py_INCREF(other); + } + } + + if (PyList_Sort(hashtables) == -1) { + Py_DECREF(hashtables); + Py_XDECREF(iterables); + return NULL; + } + + result = (PySetObject *)make_new_set(Py_TYPE(so), NULL); + if (result == NULL) { + Py_DECREF(hashtables); + Py_XDECREF(iterables); + return NULL; + } + + /* build the set of elements in the smallest hashtable which also + * belong to all the other hashtables */ + sort_tuple = PyList_GET_ITEM(hashtables, 0); + small = PyTuple_GET_ITEM(sort_tuple, 2); + is_set = PyAnySet_Check(small); + while (1) { + setentry *entry; + setentry de; +next_element_in_small: + if (is_set) { + if (!set_next((PySetObject *)small, &pos, &entry)) + break; + } + else { + if (!_PyDict_Next(small, &pos, &de.key, NULL, &de.hash)) + break; + entry = &de; + } + for (j = 1; j < nhashtables; j++) { + int rv; + sort_tuple = PyList_GET_ITEM(hashtables, j); + other = PyTuple_GET_ITEM(sort_tuple, 2); + if (PyAnySet_Check(other)) + rv = set_contains_entry((PySetObject *)other, entry); + else /* other is a dict */ + rv = _PyDict_Contains(other, entry->key, entry->hash); + if (rv == -1) { + Py_DECREF(result); + Py_DECREF(hashtables); + Py_XDECREF(iterables); + return NULL; + } + if (!rv) { + goto next_element_in_small; + } + } + if (set_add_entry(result, entry) == -1) { + Py_DECREF(result); + Py_DECREF(hashtables); + Py_XDECREF(iterables); + return NULL; + } + } + Py_DECREF(hashtables); + } + else /* nhashtables == 1 (it must be 'so'). So niterables > 0 */ + { + iterables = others; + result = so; + /* These will be DECREF'ed in the next part */ + Py_INCREF(iterables); + Py_INCREF(so); + } + + if (niterables) { + /* Intersect with iterables which are not hashtables */ + for (i = 0; i < niterables; i++) { + other = PyTuple_GET_ITEM(iterables, i); + so = result; + result = (PySetObject *)set_intersection2(so, other); + if (result == NULL) { + Py_DECREF(iterables); + Py_DECREF(so); + return NULL; + } + Py_DECREF(so); + } + Py_DECREF(iterables); + } + + return (PyObject *)result; +} + PyDoc_STRVAR(intersection_doc, -"Return the intersection of two sets as a new set.\n\ +"Return the intersection of this set with any number of iterables\n\ +as a new set.\n\ \n\ -(i.e. all elements that are in both sets.)"); +(i.e. all elements that are in the set and all iterables.)"); static PyObject * -set_intersection_update(PySetObject *so, PyObject *other) +set_intersection_update(PySetObject *so, PyObject *others) { PyObject *tmp; - tmp = set_intersection(so, other); + tmp = set_intersection(so, others); if (tmp == NULL) return NULL; set_swap_bodies(so, (PySetObject *)tmp); @@ -1290,7 +1482,7 @@ } PyDoc_STRVAR(intersection_update_doc, -"Update a set with the intersection of itself and another."); +"Update a set with the intersection of itself one or more others."); static PyObject * set_and(PySetObject *so, PyObject *other) @@ -1299,7 +1491,7 @@ Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } - return set_intersection(so, other); + return set_intersection2(so, other); } static PyObject * @@ -1311,7 +1503,7 @@ Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } - result = set_intersection_update(so, other); + result = set_intersection2_update(so, other); if (result == NULL) return NULL; Py_DECREF(result); @@ -1907,9 +2099,9 @@ difference_doc}, {"difference_update", (PyCFunction)set_difference_update, METH_O, difference_update_doc}, - {"intersection",(PyCFunction)set_intersection, METH_O, + {"intersection",(PyCFunction)set_intersection, METH_VARARGS, intersection_doc}, - {"intersection_update",(PyCFunction)set_intersection_update, METH_O, + {"intersection_update",(PyCFunction)set_intersection_update, METH_VARARGS, intersection_update_doc}, {"isdisjoint", (PyCFunction)set_isdisjoint, METH_O, isdisjoint_doc}, @@ -1931,9 +2123,9 @@ {"test_c_api", (PyCFunction)test_c_api, METH_NOARGS, test_c_api_doc}, #endif - {"union", (PyCFunction)set_union, METH_O, + {"union", (PyCFunction)set_union, METH_VARARGS, union_doc}, - {"update", (PyCFunction)set_update, METH_O, + {"update", (PyCFunction)set_update, METH_VARARGS, update_doc}, {NULL, NULL} /* sentinel */ }; @@ -2032,7 +2224,7 @@ copy_doc}, {"difference", (PyCFunction)set_difference, METH_O, difference_doc}, - {"intersection",(PyCFunction)set_intersection, METH_O, + {"intersection",(PyCFunction)set_intersection, METH_VARARGS, intersection_doc}, {"isdisjoint", (PyCFunction)set_isdisjoint, METH_O, isdisjoint_doc}, @@ -2044,7 +2236,7 @@ reduce_doc}, {"symmetric_difference",(PyCFunction)set_symmetric_difference, METH_O, symmetric_difference_doc}, - {"union", (PyCFunction)set_union, METH_O, + {"union", (PyCFunction)set_union, METH_VARARGS, union_doc}, {NULL, NULL} /* sentinel */ }; Index: Lib/test/test_set.py =================================================================== --- Lib/test/test_set.py (revision 64056) +++ Lib/test/test_set.py (working copy) @@ -79,7 +79,12 @@ self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg')) self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc')) self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef')) - + for C1, C2, C3 in ((set, dict.fromkeys, frozenset), + (dict.fromkeys, frozenset, tuple), + (str, list, tuple)): + self.assertEqual(self.thetype('abcdaba').union( + C1('bcdb'), C2('ebce'), C3('dbf')), set('abcdef')) + def test_or(self): i = self.s.union(self.otherword) self.assertEqual(self.s | set(self.otherword), i) @@ -103,6 +108,11 @@ self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set('')) self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc')) self.assertEqual(self.thetype('abcba').intersection(C('ef')), set('')) + for C1, C2, C3 in ((set, dict.fromkeys, frozenset), + (dict.fromkeys, frozenset, tuple), + (str, list, tuple)): + self.assertEqual(self.thetype('abcdaba').intersection( + C1('bcdb'), C2('ebce'), C3('dbf')), set('b')) def test_isdisjoint(self): def f(s1, s2): @@ -410,6 +420,12 @@ s = self.thetype('abcba') self.assertEqual(s.update(C(p)), None) self.assertEqual(s, set(q)) + s = self.thetype('abcba') + self.assertEqual(s.update(), None) + self.assertEqual(s, set('abc')) + s = self.thetype('abcba') + self.assertEqual(s.update('abda', 'eaa', 'dfd'), None) + self.assertEqual(s, set('abcdef')) def test_ior(self): self.s |= set(self.otherword) @@ -431,7 +447,13 @@ s = self.thetype('abcba') self.assertEqual(s.intersection_update(C(p)), None) self.assertEqual(s, set(q)) - + s = self.thetype('abcba') + self.assertEqual(s.intersection_update(), None) + self.assertEqual(s, set('abc')) + s = self.thetype('abcba') + self.assertEqual(s.intersection_update('befbdc', 'bbca', 'abd'), None) + self.assertEqual(s, set('b')) + def test_iand(self): self.s &= set(self.otherword) for c in (self.word + self.otherword):