diff -r 15f46b850257 Lib/ipaddress.py --- a/Lib/ipaddress.py Sun Jan 18 17:40:17 2015 +0100 +++ b/Lib/ipaddress.py Sun Jan 18 22:01:20 2015 +0200 @@ -386,40 +386,7 @@ def get_mixed_type_key(obj): return NotImplemented -class _TotalOrderingMixin: - # Helper that derives the other comparison operations from - # __lt__ and __eq__ - # We avoid functools.total_ordering because it doesn't handle - # NotImplemented correctly yet (http://bugs.python.org/issue10042) - def __eq__(self, other): - raise NotImplementedError - def __ne__(self, other): - equal = self.__eq__(other) - if equal is NotImplemented: - return NotImplemented - return not equal - def __lt__(self, other): - raise NotImplementedError - def __le__(self, other): - less = self.__lt__(other) - if less is NotImplemented or not less: - return self.__eq__(other) - return less - def __gt__(self, other): - less = self.__lt__(other) - if less is NotImplemented: - return NotImplemented - equal = self.__eq__(other) - if equal is NotImplemented: - return NotImplemented - return not (less or equal) - def __ge__(self, other): - less = self.__lt__(other) - if less is NotImplemented: - return NotImplemented - return not less - -class _IPAddressBase(_TotalOrderingMixin): +class _IPAddressBase: """The mother class.""" @@ -449,6 +416,13 @@ class _IPAddressBase(_TotalOrderingMixin msg = '%200s has no version specified' % (type(self),) raise NotImplementedError(msg) + # XXX Default __ne__ returns True if __eq__ returns NotImplemented + def __ne__(self, other): + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not equal + def _check_int_address(self, address): if address < 0: msg = "%d (< 0) is not permitted as an IPv%d address" @@ -568,6 +542,7 @@ class _IPAddressBase(_TotalOrderingMixin cls._report_invalid_netmask(ip_str) +@functools.total_ordering class _BaseAddress(_IPAddressBase): """A generic IP object. @@ -592,12 +567,11 @@ class _BaseAddress(_IPAddressBase): return NotImplemented def __lt__(self, other): + if not isinstance(other, _BaseAddress): + return NotImplemented if self._version != other._version: raise TypeError('%s and %s are not of the same version' % ( self, other)) - if not isinstance(other, _BaseAddress): - raise TypeError('%s and %s are not of the same type' % ( - self, other)) if self._ip != other._ip: return self._ip < other._ip return False @@ -627,6 +601,7 @@ class _BaseAddress(_IPAddressBase): return (self._version, self) +@functools.total_ordering class _BaseNetwork(_IPAddressBase): """A generic IP network object. @@ -676,12 +651,11 @@ class _BaseNetwork(_IPAddressBase): return self._address_class(broadcast + n) def __lt__(self, other): + if not isinstance(other, _BaseNetwork): + return NotImplemented if self._version != other._version: raise TypeError('%s and %s are not of the same version' % ( self, other)) - if not isinstance(other, _BaseNetwork): - raise TypeError('%s and %s are not of the same type' % ( - self, other)) if self.network_address != other.network_address: return self.network_address < other.network_address if self.netmask != other.netmask: diff -r 15f46b850257 Lib/test/test_ipaddress.py --- a/Lib/test/test_ipaddress.py Sun Jan 18 17:40:17 2015 +0100 +++ b/Lib/test/test_ipaddress.py Sun Jan 18 22:01:20 2015 +0200 @@ -7,6 +7,7 @@ import unittest import re import contextlib +import functools import operator import ipaddress @@ -528,6 +529,20 @@ class FactoryFunctionErrors(BaseTestCase self.assertFactoryError(ipaddress.ip_network, "network") +@functools.total_ordering +class LargestObject: + def __eq__(self, other): + return isinstance(other, LargestObject) + def __lt__(self, other): + return False + +@functools.total_ordering +class SmallestObject: + def __eq__(self, other): + return isinstance(other, SmallestObject) + def __gt__(self, other): + return False + class ComparisonTests(unittest.TestCase): v4addr = ipaddress.IPv4Address(1) @@ -581,6 +596,28 @@ class ComparisonTests(unittest.TestCase) self.assertRaises(TypeError, lambda: lhs <= rhs) self.assertRaises(TypeError, lambda: lhs >= rhs) + def test_foreign_type_ordering(self): + other = object() + smallest = SmallestObject() + largest = LargestObject() + for obj in self.objects: + with self.assertRaises(TypeError): + obj < other + with self.assertRaises(TypeError): + obj > other + with self.assertRaises(TypeError): + obj <= other + with self.assertRaises(TypeError): + obj >= other + self.assertTrue(obj < largest) + self.assertFalse(obj > largest) + self.assertTrue(obj <= largest) + self.assertFalse(obj >= largest) + self.assertFalse(obj < smallest) + self.assertTrue(obj > smallest) + self.assertFalse(obj <= smallest) + self.assertTrue(obj >= smallest) + def test_mixed_type_key(self): # with get_mixed_type_key, you can sort addresses and network. v4_ordered = [self.v4addr, self.v4net, self.v4intf] @@ -601,7 +638,7 @@ class ComparisonTests(unittest.TestCase) v4addr = ipaddress.ip_address('1.1.1.1') v4net = ipaddress.ip_network('1.1.1.1') v6addr = ipaddress.ip_address('::1') - v6net = ipaddress.ip_address('::1') + v6net = ipaddress.ip_network('::1') self.assertRaises(TypeError, v4addr.__lt__, v6addr) self.assertRaises(TypeError, v4addr.__gt__, v6addr) @@ -1362,10 +1399,10 @@ class IpaddrUnitTest(unittest.TestCase): unsorted = [ip4, ip1, ip3, ip2] unsorted.sort() self.assertEqual(sorted, unsorted) - self.assertRaises(TypeError, ip1.__lt__, - ipaddress.ip_address('10.10.10.0')) - self.assertRaises(TypeError, ip2.__lt__, - ipaddress.ip_address('10.10.10.0')) + self.assertIs(ip1.__lt__(ipaddress.ip_address('10.10.10.0')), + NotImplemented) + self.assertIs(ip2.__lt__(ipaddress.ip_address('10.10.10.0')), + NotImplemented) # <=, >= self.assertTrue(ipaddress.ip_network('1.1.1.1') <=