diff --git a/Lib/_abcoll.py b/Lib/_abcoll.py --- a/Lib/_abcoll.py +++ b/Lib/_abcoll.py @@ -165,12 +165,17 @@ def __gt__(self, other): if not isinstance(other, Set): return NotImplemented - return other.__lt__(self) + return len(self) > len(other) and self.__ge__(other) def __ge__(self, other): if not isinstance(other, Set): return NotImplemented - return other.__le__(self) + if len(self) < len(other): + return False + for elem in other: + if elem not in self: + return False + return True def __eq__(self, other): if not isinstance(other, Set): @@ -194,6 +199,8 @@ return NotImplemented return self._from_iterable(value for value in other if value in self) + __rand__ = __and__ + def isdisjoint(self, other): 'Return True if two sets have a null intersection.' for value in other: @@ -207,6 +214,8 @@ chain = (e for s in (self, other) for e in s) return self._from_iterable(chain) + __ror__ = __or__ + def __sub__(self, other): if not isinstance(other, Set): if not isinstance(other, Iterable): @@ -215,6 +224,14 @@ return self._from_iterable(value for value in self if value not in other) + def __rsub__(self, other): + if not isinstance(other, Set): + if not isinstance(other, Iterable): + return NotImplemented + other = self._from_iterable(other) + return self._from_iterable(value for value in other + if value not in self) + def __xor__(self, other): if not isinstance(other, Set): if not isinstance(other, Iterable): @@ -222,6 +239,8 @@ other = self._from_iterable(other) return (self - other) | (other - self) + __rxor__ = __xor__ + # Sets are not hashable by default, but subclasses can change this __hash__ = None diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -618,10 +618,160 @@ cs = MyComparableSet() ncs = MyNonComparableSet() - self.assertFalse(ncs < cs) - self.assertFalse(ncs <= cs) - self.assertFalse(cs > ncs) - self.assertFalse(cs >= ncs) + + # Run all the variants to make sure they don't mutually recurse + ncs < cs + ncs <= cs + ncs > cs + ncs >= cs + cs < ncs + cs <= ncs + cs > ncs + cs >= ncs + + def assertSameSet(self, s1, s2): + # coerce both to a real set then check equality + self.assertEqual(set(s1), set(s2)) + + def test_Set_interoperability_with_real_sets(self): + # Issue: 8743 + class ListSet(Set): + def __init__(self, elements=()): + self.data = [] + for elem in elements: + if elem not in self.data: + self.data.append(elem) + def __contains__(self, elem): + return elem in self.data + def __iter__(self): + return iter(self.data) + def __len__(self): + return len(self.data) + def __repr__(self): + return 'Set({!r})'.format(self.data) + + r1 = set('abc') + r2 = set('bcd') + r3 = set('abcde') + f1 = ListSet('abc') + f2 = ListSet('bcd') + f3 = ListSet('abcde') + l1 = list('abccba') + l2 = list('bcddcb') + l3 = list('abcdeedcba') + + target = r1 & r2 + self.assertSameSet(f1 & f2, target) + self.assertSameSet(f1 & r2, target) + self.assertSameSet(r2 & f1, target) + self.assertSameSet(f1 & l2, target) + + target = r1 | r2 + self.assertSameSet(f1 | f2, target) + self.assertSameSet(f1 | r2, target) + self.assertSameSet(r2 | f1, target) + self.assertSameSet(f1 | l2, target) + + fwd_target = r1 - r2 + rev_target = r2 - r1 + self.assertSameSet(f1 - f2, fwd_target) + self.assertSameSet(f2 - f1, rev_target) + self.assertSameSet(f1 - r2, fwd_target) + self.assertSameSet(f2 - r1, rev_target) + self.assertSameSet(r1 - f2, fwd_target) + self.assertSameSet(r2 - f1, rev_target) + self.assertSameSet(f1 - l2, fwd_target) + self.assertSameSet(f2 - l1, rev_target) + + target = r1 ^ r2 + self.assertSameSet(f1 ^ f2, target) + self.assertSameSet(f1 ^ r2, target) + self.assertSameSet(r2 ^ f1, target) + self.assertSameSet(f1 ^ l2, target) + + # proper subset + self.assertTrue(f1 < f3) + self.assertFalse(f1 < f1) + self.assertFalse(f1 < f2) + self.assertTrue(r1 < f3) + self.assertFalse(r1 < f1) + self.assertFalse(r1 < f2) + self.assertTrue(r1 < r3) + self.assertFalse(r1 < r1) + self.assertFalse(r1 < r2) + # python 2 only, cross-type compares will succeed + f1 < l3 + f1 < l1 + f1 < l2 + + # any subset + self.assertTrue(f1 <= f3) + self.assertTrue(f1 <= f1) + self.assertFalse(f1 <= f2) + self.assertTrue(r1 <= f3) + self.assertTrue(r1 <= f1) + self.assertFalse(r1 <= f2) + self.assertTrue(r1 <= r3) + self.assertTrue(r1 <= r1) + self.assertFalse(r1 <= r2) + # python 2 only, cross-type compares will succeed + f1 <= l3 + f1 <= l1 + f1 <= l2 + + # proper superset + self.assertTrue(f3 > f1) + self.assertFalse(f1 > f1) + self.assertFalse(f2 > f1) + self.assertTrue(r3 > r1) + self.assertFalse(f1 > r1) + self.assertFalse(f2 > r1) + self.assertTrue(r3 > r1) + self.assertFalse(r1 > r1) + self.assertFalse(r2 > r1) + # python 2 only, cross-type compares will succeed + f1 > l3 + f1 > l1 + f1 > l2 + + # any superset + self.assertTrue(f3 >= f1) + self.assertTrue(f1 >= f1) + self.assertFalse(f2 >= f1) + self.assertTrue(r3 >= r1) + self.assertTrue(f1 >= r1) + self.assertFalse(f2 >= r1) + self.assertTrue(r3 >= r1) + self.assertTrue(r1 >= r1) + self.assertFalse(r2 >= r1) + # python 2 only, cross-type compares will succeed + f1 >= l3 + f1 >=l1 + f1 >= l2 + + # equality + self.assertTrue(f1 == f1) + self.assertTrue(r1 == f1) + self.assertTrue(f1 == r1) + self.assertFalse(f1 == f3) + self.assertFalse(r1 == f3) + self.assertFalse(f1 == r3) + # python 2 only, cross-type compares will succeed + f1 == l3 + f1 == l1 + f1 == l2 + + # inequality + self.assertFalse(f1 != f1) + self.assertFalse(r1 != f1) + self.assertFalse(f1 != r1) + self.assertTrue(f1 != f3) + self.assertTrue(r1 != f3) + self.assertTrue(f1 != r3) + # python 2 only, cross-type compares will succeed + f1 != l3 + f1 != l1 + f1 != l2 def test_Mapping(self): for sample in [dict]: diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -1017,8 +1017,6 @@ # without calling __cmp__. self.assertEqual(cmp(a, a), 0) - self.assertRaises(TypeError, cmp, a, 12) - self.assertRaises(TypeError, cmp, "abc", a) #============================================================================== @@ -1269,17 +1267,6 @@ self.assertEqual(self.other != self.set, True) self.assertEqual(self.set != self.other, True) - def test_ge_gt_le_lt(self): - self.assertRaises(TypeError, lambda: self.set < self.other) - self.assertRaises(TypeError, lambda: self.set <= self.other) - self.assertRaises(TypeError, lambda: self.set > self.other) - self.assertRaises(TypeError, lambda: self.set >= self.other) - - self.assertRaises(TypeError, lambda: self.other < self.set) - self.assertRaises(TypeError, lambda: self.other <= self.set) - self.assertRaises(TypeError, lambda: self.other > self.set) - self.assertRaises(TypeError, lambda: self.other >= self.set) - def test_update_operator(self): try: self.set |= self.other @@ -1392,18 +1379,6 @@ #------------------------------------------------------------------------------ -class TestOnlySetsOperator(TestOnlySetsInBinaryOps): - def setUp(self): - self.set = set((1, 2, 3)) - self.other = operator.add - self.otherIsIterable = False - - def test_ge_gt_le_lt(self): - with test_support.check_py3k_warnings(): - super(TestOnlySetsOperator, self).test_ge_gt_le_lt() - -#------------------------------------------------------------------------------ - class TestOnlySetsTuple(TestOnlySetsInBinaryOps): def setUp(self): self.set = set((1, 2, 3)) @@ -1801,7 +1776,6 @@ TestSubsetNonOverlap, TestOnlySetsNumeric, TestOnlySetsDict, - TestOnlySetsOperator, TestOnlySetsTuple, TestOnlySetsString, TestOnlySetsGenerator, diff --git a/Objects/setobject.c b/Objects/setobject.c --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -1796,12 +1796,8 @@ PyObject *r1, *r2; if(!PyAnySet_Check(w)) { - if (op == Py_EQ) - Py_RETURN_FALSE; - if (op == Py_NE) - Py_RETURN_TRUE; - PyErr_SetString(PyExc_TypeError, "can only compare to a set"); - return NULL; + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; } switch (op) { case Py_EQ: