import unittest


def _derived(root, combine):
    """Build a comparison method on top of the user-defined method *root*.

    The returned method calls ``type(self).<root>(self, other)`` directly
    rather than going back through the comparison operators.  If the root
    method returns ``NotImplemented`` that value is passed straight through,
    so Python can try the reflected operation on *other* (and ultimately
    raise ``TypeError``).  Re-entering the operators instead would recurse
    without bound when both operands are decorated classes that do not
    recognise each other.

    :param root: name of the user-defined comparison, e.g. ``'__lt__'``.
    :param combine: ``combine(self, other, root_result)`` turning the root
        result into the derived comparison's result.
    :returns: a function suitable for installing as a comparison method.
    """
    def method(self, other):
        root_result = getattr(type(self), root)(self, other)
        if root_result is NotImplemented:
            return root_result
        return combine(self, other, root_result)
    return method


def total_ordering(cls):
    """Class decorator that fills in missing ordering methods.

    *cls* must define ``__eq__`` and at least one of ``__lt__``, ``__le__``,
    ``__gt__`` or ``__ge__``.  The remaining ordering methods are derived
    from one root method (preferring ``__lt__``, then ``__le__``, then
    ``__gt__``, then ``__ge__``) plus ``__eq__``.  Methods the class already
    defines are never overwritten.

    :param cls: the class to complete.
    :returns: *cls*, modified in place.
    :raises ValueError: if *cls* defines none of the four ordering methods.
    """
    # For each possible root: (missing method, how to compute it from the
    # root's result).  Each formula assumes a total order.
    convert = {
        '__lt__': [('__gt__', lambda self, other, lt: not lt and not self == other),
                   ('__le__', lambda self, other, lt: lt or self == other),
                   ('__ge__', lambda self, other, lt: not lt)],
        '__le__': [('__ge__', lambda self, other, le: not le or self == other),
                   ('__lt__', lambda self, other, le: le and not self == other),
                   ('__gt__', lambda self, other, le: not le)],
        '__gt__': [('__lt__', lambda self, other, gt: not gt and not self == other),
                   ('__ge__', lambda self, other, gt: gt or self == other),
                   ('__le__', lambda self, other, gt: not gt)],
        '__ge__': [('__le__', lambda self, other, ge: not ge or self == other),
                   ('__gt__', lambda self, other, ge: ge and not self == other),
                   ('__lt__', lambda self, other, ge: not ge)],
    }
    # Find comparisons not inherited from object.
    roots = [op for op in convert if getattr(cls, op) is not getattr(object, op)]
    if not roots:
        raise ValueError('must define at least one ordering operation: < > <= >=')
    # Prefer __lt__ to __le__ to __gt__ to __ge__ (explicit, rather than
    # relying on how the dunder names happen to sort as strings).
    root = next(op for op in ('__lt__', '__le__', '__gt__', '__ge__') if op in roots)
    for opname, combine in convert[root]:
        if opname not in roots:
            opfunc = _derived(root, combine)
            opfunc.__name__ = opname
            opfunc.__doc__ = getattr(int, opname).__doc__
            setattr(cls, opname, opfunc)
    return cls


@total_ordering
class TestTotalOrderingLT:
    """Test fixture ordered via ``__lt__``; other types raise ``TypeError``."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if isinstance(other, TestTotalOrderingLT):
            return self.value == other.value
        return False

    def __lt__(self, other):
        if isinstance(other, TestTotalOrderingLT):
            return self.value < other.value
        raise TypeError


@total_ordering
class TestTotalOrderingLE:
    """Test fixture ordered via ``__le__``; other types raise ``TypeError``."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if isinstance(other, TestTotalOrderingLE):
            return self.value == other.value
        return False

    def __le__(self, other):
        if isinstance(other, TestTotalOrderingLE):
            return self.value <= other.value
        raise TypeError
@total_ordering
class TestTotalOrderingGT:
    """Fixture whose only ordering method is ``__gt__``; foreign types raise ``TypeError``."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Objects of any other type simply compare unequal.
        return isinstance(other, TestTotalOrderingGT) and self.value == other.value

    def __gt__(self, other):
        if not isinstance(other, TestTotalOrderingGT):
            raise TypeError
        return self.value > other.value


@total_ordering
class TestTotalOrderingGE:
    """Fixture whose only ordering method is ``__ge__``; foreign types raise ``TypeError``."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Objects of any other type simply compare unequal.
        return isinstance(other, TestTotalOrderingGE) and self.value == other.value

    def __ge__(self, other):
        if not isinstance(other, TestTotalOrderingGE):
            raise TypeError
        return self.value >= other.value


class TestTotalOrderingMixin:
    """Shared ordering checks; subclasses set ``theclass`` to the fixture under test."""

    def test_good_comparison(self):
        low = self.theclass(2)
        low_twin = self.theclass(2)
        high = self.theclass(6)
        # Relations that must hold.
        self.assertTrue(low < high)
        self.assertTrue(low <= high)
        self.assertTrue(low <= low_twin)
        self.assertTrue(high > low)
        self.assertTrue(high >= low)
        self.assertTrue(low >= low_twin)
        # Relations that must not hold.
        self.assertFalse(low > high)
        self.assertFalse(low > low_twin)
        self.assertFalse(low >= high)
        self.assertFalse(high < low)
        self.assertFalse(low < low_twin)
        self.assertFalse(high <= low)

    def test_bad_comparison(self):
        low = self.theclass(2)
        # Every operator, including the derived ones, must reject a foreign type.
        attempts = (
            lambda: low < (),
            lambda: low <= (),
            lambda: low > (),
            lambda: low >= (),
        )
        for attempt in attempts:
            with self.assertRaises(TypeError):
                attempt()


class TestTotalOrdering_LT(TestTotalOrderingMixin, unittest.TestCase):
    theclass = TestTotalOrderingLT


class TestTotalOrdering_LE(TestTotalOrderingMixin, unittest.TestCase):
    theclass = TestTotalOrderingLE


class TestTotalOrdering_GT(TestTotalOrderingMixin, unittest.TestCase):
    theclass = TestTotalOrderingGT


class TestTotalOrdering_GE(TestTotalOrderingMixin, unittest.TestCase):
    theclass = TestTotalOrderingGE


if __name__ == "__main__":
    unittest.main()