diff --git a/Lib/_abcoll.py b/Lib/_abcoll.py index 0957553..a1b9b3c 100644 --- a/Lib/_abcoll.py +++ b/Lib/_abcoll.py @@ -209,7 +209,7 @@ class Set(Sized, Iterable, Container): return cls(it) def __and__(self, other): - if not isinstance(other, Iterable): + if not isinstance(other, Set): return NotImplemented return self._from_iterable(value for value in other if value in self) @@ -220,26 +220,30 @@ class Set(Sized, Iterable, Container): return True def __or__(self, other): - if not isinstance(other, Iterable): + if not isinstance(other, Set): return NotImplemented chain = (e for s in (self, other) for e in s) return self._from_iterable(chain) def __sub__(self, other): if not isinstance(other, Set): - if not isinstance(other, Iterable): - return NotImplemented - other = self._from_iterable(other) + return NotImplemented return self._from_iterable(value for value in self if value not in other) def __xor__(self, other): if not isinstance(other, Set): - if not isinstance(other, Iterable): - return NotImplemented - other = self._from_iterable(other) + return NotImplemented return (self - other) | (other - self) + def __rsub__(self, other): + if not isinstance(other, Set): + return NotImplemented + return self._from_iterable(other) - self + __ror__ = __or__ + __rand__ = __and__ + __rxor__ = __xor__ + def _hash(self): """Compute the hash value of a set. @@ -311,34 +315,40 @@ class MutableSet(Set): except KeyError: pass - def __ior__(self, it: Iterable): - for value in it: + def __ior__(self, other: Set): + if not isinstance(other, Set): + return NotImplemented + for value in other: self.add(value) return self - def __iand__(self, it: Iterable): - for value in (self - it): + def __iand__(self, other: Set): + if not isinstance(other, Set): + return NotImplemented + for value in (self - other): self.discard(value) return self - def __ixor__(self, it: Iterable): - if it is self: + def __ixor__(self, other: Set): + if not isinstance(other, Set): + return NotImplemented + if other is self: self.clear() else: - if not isinstance(it, Set): - it = self._from_iterable(it) - for value in it: + for value in other: if value in self: self.discard(value) else: self.add(value) return self - def __isub__(self, it: Iterable): - if it is self: + def __isub__(self, other: Set): + if not isinstance(other, Set): + return NotImplemented + if other is self: self.clear() else: - for value in it: + for value in other: self.discard(value) return self diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 02b9dc3..7233cf6 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -476,6 +476,16 @@ class TestCollectionABCs(ABCTestCase): return iter([]) self.validate_comparison(MySet()) + def test_set_interoperability(self): + # Classes inherited from MutableSet should interoperate with + # regular sets. Issues #8743 and #2226. + w = WithSet('abracadabra') + s = set('simsalabim') + self.assertEqual(sorted(s&w), sorted(w&s)) + self.assertEqual(sorted(s|w), sorted(w|s)) + self.assertEqual(sorted(s^w), sorted(w^s)) + self.assertEqual(sorted(s-w), sorted(WithSet(s) - set(w))) + def test_hash_Set(self): class OneTwoThreeSet(Set): def __init__(self):