diff --git a/Doc/library/collections.abc.rst b/Doc/library/collections.abc.rst --- a/Doc/library/collections.abc.rst +++ b/Doc/library/collections.abc.rst @@ -55,13 +55,14 @@ :class:`Set` :class:`Sized`, ``__contains__``, ``__le__``, ``__lt__``, ``__eq__``, ``__ne__``, :class:`Iterable`, ``__iter__``, ``__gt__``, ``__ge__``, ``__and__``, ``__or__``, - :class:`Container` ``__len__`` ``__sub__``, ``__xor__``, and ``isdisjoint`` + :class:`Container` ``__len__`` ``__sub__``, ``__xor__``, ``intersection``, + ``union``, ``difference``, and ``isdisjoint`` :class:`MutableSet` :class:`Set` ``__contains__``, Inherited :class:`Set` methods and ``__iter__``, ``clear``, ``pop``, ``remove``, ``__ior__``, - ``__len__``, ``__iand__``, ``__ixor__``, and ``__isub__`` - ``add``, - ``discard`` + ``__len__``, ``__iand__``, ``__ixor__``, ``__isub__``, + ``add``, ``update``, ``intersection_update``, + ``discard`` and ``difference_update`` :class:`Mapping` :class:`Sized`, ``__getitem__``, ``__contains__``, ``keys``, ``items``, ``values``, :class:`Iterable`, ``__iter__``, ``get``, ``__eq__``, and ``__ne__`` diff --git a/Lib/_collections_abc.py b/Lib/_collections_abc.py --- a/Lib/_collections_abc.py +++ b/Lib/_collections_abc.py @@ -236,11 +236,27 @@ ''' return cls(it) + def copy(self): + """Return a new set with a shallow copy of s.""" + return self._from_iterable(self) + def __and__(self, other): if not isinstance(other, Iterable): return NotImplemented return self._from_iterable(value for value in other if value in self) + def _bin_op(self, bin_op, iterables): + from functools import reduce + ret = reduce(bin_op, iterables, self.copy()) + if ret == NotImplemented: + raise TypeError('%r is not implemented for %r' % ( + bin_op, list(map(type, iterables)))) + return ret + + def intersection(self, *iterables): + """Return a new set with elements common to the set and all others.""" + return self._bin_op(self.__class__.__and__, iterables) + __rand__ = __and__ def isdisjoint(self, other): @@ -256,6 +272,10 @@ chain = (e for s in (self, other) for e in s) return self._from_iterable(chain) + def union(self, *iterables): + """Return a new set with elements from the set and all others.""" + return self._bin_op(self.__class__.__or__, iterables) + __ror__ = __or__ def __sub__(self, other): @@ -266,6 +286,12 @@ return self._from_iterable(value for value in self if value not in other) + def difference(self, *iterables): + """Return a new set with elements in the set that are not in the others. + + """ + return self._bin_op(self.__class__.__sub__, iterables) + def __rsub__(self, other): if not isinstance(other, Set): if not isinstance(other, Iterable): @@ -371,11 +397,24 @@ self.add(value) return self + def _inplace_bin_op(self, inplace_bin_op, iterables): + from functools import reduce + reduce(inplace_bin_op, iterables, self) + # return None + + def update(self, *iterables): + """Update the set, adding elements from all others.""" + return self._inplace_bin_op(self.__class__.__ior__, iterables) + def __iand__(self, it): for value in (self - it): self.discard(value) return self + def intersection_update(self, *iterables): + """Update the set, keeping only elements found in it and all others.""" + return self._inplace_bin_op(self.__class__.__iand__, iterables) + def __ixor__(self, it): if it is self: self.clear() @@ -397,6 +436,10 @@ self.discard(value) return self + def difference_update(self, *iterables): + """Update the set, removing elements found in others.""" + return self._inplace_bin_op(self.__class__.__isub__, iterables) + MutableSet.register(set) diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -19,6 +19,12 @@ from collections.abc import Mapping, MutableMapping, KeysView, ItemsView from collections.abc import Sequence, MutableSequence from collections.abc import ByteString +from itertools import combinations, permutations +from random import Random + +def powerset(seq, key=lambda x: x): # based on itertools recipe + return [key(comb) + for r in range(len(seq) + 1) for comb in combinations(seq, r)] ################################################################################ @@ -1053,6 +1059,61 @@ mss.clear() self.assertEqual(len(mss), 0) + +class TestSetMixin: + # test collections.abc.Set + def setUp(self): + self.test_cases = powerset(powerset('abc'), list) + self.random = Random() + self.random.seed(1234567890) # reproducable tests + self.random.shuffle(self.test_cases) + + def _test_op(self, opname): + # check that self.Set and set produce the same result + op, set_op = getattr(self.Set, opname), getattr(set, opname) + with self.assertRaises(TypeError): + op(self.Set(), 1) # non-iterable + + for sets in self.test_cases: + if support.is_resource_enabled('cpu'): + it = permutations(sets) + else: + self.random.shuffle(sets) + it = [sets[i:i+1]+sets[:i]+sets[i+1:] for i in range(len(sets))] + + for p in it: + if not p: + continue + self_set, *iterables = p + a, b = self.Set(self_set), set(self_set) + self.assertEqual(op(a, *iterables), set_op(b, *iterables)) + self.assertEqual(a, b) + + def test_difference(self): + self._test_op('difference') + + def test_intersection(self): + self._test_op('intersection') + + def test_union(self): + self._test_op('union') + + +class TestMutableSetMixin(TestSetMixin): + # test collections.abc.MutableSet + def test_difference_update(self): + self._test_op('difference_update') + + def test_intersection_update(self): + self._test_op('intersection_update') + + def test_update(self): + self._test_op('update') + +class TestMutableSet_WithSet(TestMutableSetMixin, unittest.TestCase): + Set = WithSet + + ################################################################################ ### Counter ################################################################################ @@ -1611,6 +1672,7 @@ NamedTupleDocs = doctest.DocTestSuite(module=collections) test_classes = [TestNamedTuple, NamedTupleDocs, TestOneTrickPonyABCs, TestCollectionABCs, TestCounter, TestChainMap, + TestMutableSet_WithSet, TestOrderedDict, GeneralMappingTests, SubclassMappingTests] support.run_unittest(*test_classes) support.run_doctest(collections, verbose)