diff --git a/Lib/socket.py b/Lib/socket.py index 96f8ed0..24f896b 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -48,6 +48,7 @@ import _socket from _socket import * import os, sys, io +from enum import IntEnum try: import errno @@ -60,6 +61,26 @@ EWOULDBLOCK = getattr(errno, 'EWOULDBLOCK', 11) __all__ = ["getfqdn", "create_connection"] __all__.extend(os._get_exports_list(_socket)) +# Set up the socket.AF_* constants as members of an IntEnum for nicer string +# representations. +# Note that _socket only knows about the integer values. The public interface +# in this module understands the enums and translates them back from integers +# where needed (e.g. .family property of a socket object). +_moduledict = globals() +AddressFamily = IntEnum('AddressFamily', + {name: value for name, value in _moduledict.items() + if name.isupper() and name.startswith('AF_')}) +_moduledict.update(AddressFamily.__members__) + +def _family_converter(family): + """Convert a numeric family value fo AddressFamily member. + + If it's not a known member, return the numeric value itself. + """ + try: + return AddressFamily(family) + except ValueError: + return family _realsocket = socket @@ -91,6 +112,10 @@ class socket(_socket.socket): __slots__ = ["__weakref__", "_io_refs", "_closed"] def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None): + # For user code address family values are members of AddressFamily, but + # for the underlying _socket.socket they're just integers. The + # constructor of _socket.socket converts the given family argument to an + # integer automatically. _socket.socket.__init__(self, family, type, proto, fileno) self._io_refs = 0 self._closed = False @@ -229,6 +254,13 @@ class socket(_socket.socket): self._closed = True return super().detach() + @property + def family(self): + """Read-only access to the address family for this socket. + """ + return _family_converter(super().family) + + def fromfd(fd, family, type, proto=0): """ fromfd(fd, family, type[, proto]) -> socket object @@ -454,3 +486,29 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, raise err else: raise error("getaddrinfo returns an empty list") + +# Save the C-imported getaddrinfo, as we're going to override it next. +_getaddrinfo = getaddrinfo + +def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): + """Resolve host and port into list of address info entries. + + Translate the host/port argument into a sequence of 5-tuples that contain + all the necessary arguments for creating a socket connected to that service. + host is a domain name, a string representation of an IPv4/v6 address or + None. port is a string service name such as 'http', a numeric port number or + None. By passing None as the value of host and port, you can pass NULL to + the underlying C API. + + The family, type and proto arguments can be optionally specified in order to + narrow the list of addresses returned. Passing zero as a value for each of + these arguments selects the full range of results. + """ + # We override this function since we want to translate the numeric family + # and socket type values to AddressFamily constants. + addrlist = [] + for res in _getaddrinfo(host, port, family, type, proto, flags): + af, socktype, proto, canonname, sa = res + addrlist.append((_family_converter(af), socktype, proto, canonname, sa)) + return addrlist + diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 54fb9a1..ee9c7c5 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1166,6 +1166,8 @@ class GeneralModuleTests(unittest.TestCase): infos = socket.getaddrinfo(HOST, None, socket.AF_INET) for family, _, _, _, _ in infos: self.assertEqual(family, socket.AF_INET) + # Also check that it's an actual enum + self.assertEqual(str(family), 'AddressFamily.AF_INET') infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM) for _, socktype, _, _, _ in infos: self.assertEqual(socktype, socket.SOCK_STREAM) @@ -1321,6 +1323,20 @@ class GeneralModuleTests(unittest.TestCase): with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: self.assertRaises(OverflowError, s.bind, ('::1', 0, -10)) + def test_str_for_enums(self): + # Make sure that the AF_* constants have enum-like string reprs. + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + self.assertEqual(str(s.family), 'AddressFamily.AF_INET') + + def test_uknown_socket_family_repr(self): + # Test that when created with a family that's not one of the known + # AF_* constants, socket.family just returns the number. + # To do this we fool socket.socket into believing it already has an + # open fd because on this path it doesn't actually verify that the + # family is valid in the constructor. + fd, _ = tempfile.mkstemp() + with socket.socket(family=42424, fileno=fd) as s: + self.assertEqual(s.family, 42424) @unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.') class BasicCANTest(unittest.TestCase):