diff --git a/Lib/functools.py b/Lib/functools.py --- a/Lib/functools.py +++ b/Lib/functools.py @@ -55,23 +55,28 @@ convert = { '__lt__': [('__gt__', lambda self, other: not (self < other or self == other)), ('__le__', lambda self, other: self < other or self == other), + ('__ne__', lambda self, other: not self == other), ('__ge__', lambda self, other: not self < other)], '__le__': [('__ge__', lambda self, other: not self <= other or self == other), ('__lt__', lambda self, other: self <= other and not self == other), + ('__ne__', lambda self, other: not self == other), ('__gt__', lambda self, other: not self <= other)], '__gt__': [('__lt__', lambda self, other: not (self > other or self == other)), ('__ge__', lambda self, other: self > other or self == other), + ('__ne__', lambda self, other: not self == other), ('__le__', lambda self, other: not self > other)], '__ge__': [('__le__', lambda self, other: (not self >= other) or self == other), ('__gt__', lambda self, other: self >= other and not self == other), + ('__ne__', lambda self, other: not self == other), ('__lt__', lambda self, other: not self >= other)] } - roots = set(dir(cls)) & set(convert) + defined_methods = set(dir(cls)) + roots = defined_methods & set(convert) if not roots: raise ValueError('must define at least one ordering operation: < > <= >=') root = max(roots) # prefer __lt__ to __le__ to __gt__ to __ge__ for opname, opfunc in convert[root]: - if opname not in roots: + if opname not in defined_methods: opfunc.__name__ = opname opfunc.__doc__ = getattr(int, opname).__doc__ setattr(cls, opname, opfunc) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -474,6 +474,33 @@ with self.assertRaises(TypeError): TestTO(8) <= () + def test_bug_25732(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __gt__(self, other): + return self.value > other.value + def __eq__(self, other): + return self.value == other.value + self.assertTrue(A(1) != A(2)) + self.assertFalse(A(1) != A(1)) + + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __gt__(self, other): + return self.value > other.value + def __eq__(self, other): + return self.value == other.value + def __ne__(self, other): + raise RuntimeError(self, other) + with self.assertRaises(RuntimeError): + A(1) != A(2) + with self.assertRaises(RuntimeError): + A(1) != A(1) + def test_main(verbose=None): test_classes = ( TestPartial,