diff -r d7ac90c0463a Lib/_collections_abc.py --- a/Lib/_collections_abc.py Sat Feb 01 22:49:59 2014 +0100 +++ b/Lib/_collections_abc.py Sun Feb 02 18:11:02 2014 +1000 @@ -232,7 +232,7 @@ return cls(it) def __and__(self, other): - if not isinstance(other, Iterable): + if not isinstance(other, Set): return NotImplemented return self._from_iterable(value for value in other if value in self) @@ -244,26 +244,31 @@ return True def __or__(self, other): - if not isinstance(other, Iterable): + if not isinstance(other, Set): return NotImplemented chain = (e for s in (self, other) for e in s) return self._from_iterable(chain) def __sub__(self, other): if not isinstance(other, Set): - if not isinstance(other, Iterable): - return NotImplemented - other = self._from_iterable(other) + return NotImplemented return self._from_iterable(value for value in self if value not in other) def __xor__(self, other): if not isinstance(other, Set): - if not isinstance(other, Iterable): - return NotImplemented - other = self._from_iterable(other) + return NotImplemented return (self - other) | (other - self) + def __rsub__(self, other): + if not isinstance(other, Set): + return NotImplemented + return self._from_iterable(value for value in other + if value not in self) + __ror__ = __or__ + __rand__ = __and__ + __rxor__ = __xor__ + def _hash(self): """Compute the hash value of a set. @@ -347,34 +352,42 @@ except KeyError: pass - def __ior__(self, it): - for value in it: + def __ior__(self, other): + if not isinstance(other, Iterable): + return NotImplemented + for value in other: self.add(value) return self - def __iand__(self, it): - for value in (self - it): + def __iand__(self, other): + if not isinstance(other, Iterable): + return NotImplemented + if not isinstance(other, Set): + other = self._from_iterable(other) + for value in (self - other): self.discard(value) return self - def __ixor__(self, it): - if it is self: + def __ixor__(self, other): + if not isinstance(other, Iterable): + return NotImplemented + if other is self: self.clear() else: - if not isinstance(it, Set): - it = self._from_iterable(it) - for value in it: + for value in other: if value in self: self.discard(value) else: self.add(value) return self - def __isub__(self, it): - if it is self: + def __isub__(self, other): + if not isinstance(other, Iterable): + return NotImplemented + if other is self: self.clear() else: - for value in it: + for value in other: self.discard(value) return self diff -r d7ac90c0463a Lib/test/test_collections.py --- a/Lib/test/test_collections.py Sat Feb 01 22:49:59 2014 +0100 +++ b/Lib/test/test_collections.py Sun Feb 02 18:11:02 2014 +1000 @@ -625,6 +625,82 @@ return iter([]) self.validate_comparison(MySet()) + def test_set_interoperability(self): + # Classes inherited from Set/MutableSet should interoperate with + # regular sets. Issues #8743 and #2226. + # Check the regular Set methods + w = WithSet('abracadabra') + s = set('simsalabim') + expected_intersection = sorted(set(w) & s) + self.assertEqual(sorted(s&w), expected_intersection) + self.assertEqual(sorted(w&s), expected_intersection) + expected_union = sorted(set(w) | s) + self.assertEqual(sorted(s|w), expected_union) + self.assertEqual(sorted(w|s), expected_union) + expected_disjunction = sorted(set(w) ^ s) + self.assertEqual(sorted(s^w), expected_disjunction) + self.assertEqual(sorted(w^s), expected_disjunction) + expected_difference_w_as_lhs = sorted(set(w) - s) + self.assertEqual(sorted(w-s), expected_difference_w_as_lhs) + expected_difference_s_as_lhs = sorted(s - set(w)) + self.assertEqual(sorted(s-w), expected_difference_s_as_lhs) + # Check the MutableSet in-place methods interoperate not only + # with builtin sets, but also arbitrary iterables + w2 = WithSet(w) + w2 &= s + self.assertEqual(sorted(w2), expected_intersection) + w2 = WithSet(w) + w2 &= iter(s) + self.assertEqual(sorted(w2), expected_intersection) + s2 = set(s) + s2 &= w + self.assertEqual(sorted(s2), expected_intersection) + w2 = WithSet(w) + w2 |= s + self.assertEqual(sorted(w2), expected_union) + w2 = WithSet(w) + w2 |= iter(s) + self.assertEqual(sorted(w2), expected_union) + s2 = set(s) + s2 |= w + self.assertEqual(sorted(s2), expected_union) + w2 = WithSet(w) + w2 ^= s + self.assertEqual(sorted(w2), expected_disjunction) + w2 = WithSet(w) + w2 ^= iter(s) + self.assertEqual(sorted(w2), expected_disjunction) + s2 = set(s) + s2 ^= w + self.assertEqual(sorted(s2), expected_disjunction) + w2 = WithSet(w) + w2 -= s + self.assertEqual(sorted(w2), expected_difference_w_as_lhs) + w2 = WithSet(w) + w2 -= iter(s) + self.assertEqual(sorted(w2), expected_difference_w_as_lhs) + s2 = set(s) + s2 -= w + self.assertEqual(sorted(s2), expected_difference_s_as_lhs) + + def test_unhandled_operand(self): + # Classes inherited from Set/MutableSet should interoperate + # with other operands by returning NotImplemented when appropriate + w = WithSet('abracadabra') + self.assertIs(w.__or__(1), NotImplemented) + self.assertIs(w.__and__(1), NotImplemented) + self.assertIs(w.__sub__(1), NotImplemented) + self.assertIs(w.__xor__(1), NotImplemented) + self.assertIs(w.__ror__(1), NotImplemented) + self.assertIs(w.__rand__(1), NotImplemented) + self.assertIs(w.__rsub__(1), NotImplemented) + self.assertIs(w.__rxor__(1), NotImplemented) + self.assertIs(w.__ior__(1), NotImplemented) + self.assertIs(w.__iand__(1), NotImplemented) + self.assertIs(w.__isub__(1), NotImplemented) + self.assertIs(w.__ixor__(1), NotImplemented) + + def test_hash_Set(self): class OneTwoThreeSet(Set): def __init__(self):