# HG changeset patch # User David Watson , Heiko Wundram # Date 1307303051 -3600 # Branch 2.7 # Node ID 3788da99ca221287c595729e1301988815e5a943 # Parent 6a4b67e8530972470e41a18b7dcf035b101ee26c Issue #6560: support sendmsg() and recvmsg() on sockets. Patch by David Watson and Heiko Wundram. diff --git a/Doc/library/socket.rst b/Doc/library/socket.rst --- a/Doc/library/socket.rst +++ b/Doc/library/socket.rst @@ -165,6 +165,7 @@ The module :mod:`socket` exports the fol SOMAXCONN MSG_* SOL_* + SCM_* IPPROTO_* IPPORT_* INADDR_* @@ -486,6 +487,49 @@ The module :mod:`socket` exports the fol .. versionadded:: 2.3 +.. + XXX: Are sendmsg(), recvmsg() and CMSG_*() available on any + non-Unix platforms? The old (obsolete?) 4.2BSD form of the + interface, in which struct msghdr has no msg_control or + msg_controllen members, is not currently supported. + +.. function:: CMSG_LEN(length) + + Return the total length, without trailing padding, of an ancillary + data item with associated data of the given *length*. This value + can often be used as the buffer size for :meth:`~socket.recvmsg` to + receive a single item of ancillary data, but :rfc:`3542` requires + portable applications to use :func:`CMSG_SPACE` and thus include + space for padding, even when the item will be the last in the + buffer. Raises :exc:`OverflowError` if *length* is outside the + permissible range of values. + + Availability: most Unix platforms, possibly others. + + .. versionadded:: XXX + + +.. function:: CMSG_SPACE(length) + + Return the buffer size needed for :meth:`~socket.recvmsg` to + receive an ancillary data item with associated data of the given + *length*, along with any trailing padding. The buffer space needed + to receive multiple items is the sum of the :func:`CMSG_SPACE` + values for their associated data lengths. Raises + :exc:`OverflowError` if *length* is outside the permissible range + of values. + + Note that some systems might support ancillary data without + providing this function. Also note that setting the buffer size + using the results of this function may not precisely limit the + amount of ancillary data that can be received, since additional + data may be able to fit into the padding area. + + Availability: most Unix platforms, possibly others. + + .. versionadded:: XXX + + .. function:: getdefaulttimeout() Return the default timeout in floating seconds for new socket objects. A value @@ -689,6 +733,109 @@ correspond to Unix system calls applicab to zero. (The format of *address* depends on the address family --- see above.) +.. method:: socket.recvmsg(bufsize[, ancbufsize[, flags]]) + + Receive normal data (up to *bufsize* bytes) and ancillary data from + the socket. The *ancbufsize* argument sets the size in bytes of + the internal buffer used to receive the ancillary data; it defaults + to 0, meaning that no ancillary data will be received. Appropriate + buffer sizes for ancillary data can be calculated using + :func:`CMSG_SPACE` or :func:`CMSG_LEN`, and items which do not fit + into the buffer might be truncated or discarded. The *flags* + argument defaults to 0 and has the same meaning as for + :meth:`recv`. + + The return value is a 4-tuple: ``(data, ancdata, msg_flags, + address)``. The *data* item is a :class:`bytes` object holding the + non-ancillary data received. The *ancdata* item is a list of zero + or more tuples ``(cmsg_level, cmsg_type, cmsg_data)`` representing + the ancillary data (control messages) received: *cmsg_level* and + *cmsg_type* are integers specifying the protocol level and + protocol-specific type respectively, and *cmsg_data* is a + :class:`bytes` object holding the associated data. The *msg_flags* + item is the bitwise OR of various flags indicating conditions on + the received message; see your system documentation for details. + If the receiving socket is unconnected, *address* is the address of + the sending socket, if available; otherwise, its value is + unspecified. + + On some systems, :meth:`sendmsg` and :meth:`recvmsg` can be used to + pass file descriptors between processes over an :const:`AF_UNIX` + socket. When this facility is used (it is often restricted to + :const:`SOCK_STREAM` sockets), :meth:`recvmsg` will return, in its + ancillary data, items of the form ``(socket.SOL_SOCKET, + socket.SCM_RIGHTS, fds)``, where *fds* is a :class:`bytes` object + representing the new file descriptors as a binary array of the + native C :ctype:`int` type. If :meth:`recvmsg` raises an exception + after the system call returns, it will first attempt to close any + file descriptors received via this mechanism. + + Some systems do not indicate the truncated length of ancillary data + items which have been only partially received. If an item appears + to extend beyond the end of the buffer, :meth:`recvmsg` will issue + a :exc:`RuntimeWarning`, and will return the part of it which is + inside the buffer provided it has not been truncated before the + start of its associated data. + + On systems which support the :const:`SCM_RIGHTS` mechanism, the + following function will receive up to *maxfds* file descriptors, + returning the message data and a list containing the descriptors + (while ignoring unexpected conditions such as unrelated control + messages being received). See also :meth:`sendmsg`. :: + + import socket, array + + def recv_fds(sock, msglen, maxfds): + fds = array.array("i") # Array of ints + msg, ancdata, flags, addr = sock.recvmsg(msglen, socket.CMSG_LEN(maxfds * fds.itemsize)) + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if (cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS): + # Append data, ignoring any truncated integers at the end. + fds.fromstring(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + return msg, list(fds) + + Availability: most Unix platforms, possibly others. + + .. versionadded:: XXX + + +.. method:: socket.recvmsg_into(buffers[, ancbufsize[, flags]]) + + Receive normal data and ancillary data from the socket, behaving as + :meth:`recvmsg` would, but scatter the non-ancillary data into a + series of buffers instead of returning a new bytes object. The + *buffers* argument must be an iterable of objects that export + writable buffers (e.g. :class:`bytearray` objects); these will be + filled with successive chunks of the non-ancillary data until it + has all been written or there are no more buffers. The operating + system may set a limit (:func:`~os.sysconf` value ``SC_IOV_MAX``) + on the number of buffers that can be used. The *ancbufsize* and + *flags* arguments have the same meaning as for :meth:`recvmsg`. + + The return value is a 4-tuple: ``(nbytes, ancdata, msg_flags, + address)``, where *nbytes* is the total number of bytes of + non-ancillary data written into the buffers, and *ancdata*, + *msg_flags* and *address* are the same as for :meth:`recvmsg`. + + Example:: + + >>> import socket + >>> s1, s2 = socket.socketpair() + >>> b1 = bytearray(b'----') + >>> b2 = bytearray(b'0123456789') + >>> b3 = bytearray(b'--------------') + >>> s1.send(b'Mary had a little lamb') + 22 + >>> s2.recvmsg_into([b1, memoryview(b2)[2:9], b3]) + (22, [], 0, None) + >>> [b1, b2, b3] + [bytearray(b'Mary'), bytearray(b'01 had a 9'), bytearray(b'little lamb---')] + + Availability: most Unix platforms, possibly others. + + .. versionadded:: XXX + + .. method:: socket.recvfrom_into(buffer[, nbytes[, flags]]) Receive data from the socket, writing it into *buffer* instead of creating a @@ -740,6 +887,41 @@ correspond to Unix system calls applicab above.) +.. method:: socket.sendmsg(buffers[, ancdata[, flags[, address]]]) + + Send normal and ancillary data to the socket, gathering the + non-ancillary data from a series of buffers and concatenating it + into a single message. The *buffers* argument specifies the + non-ancillary data as an iterable of buffer-compatible objects + (e.g. :class:`bytes` objects); the operating system may set a limit + (:func:`~os.sysconf` value ``SC_IOV_MAX``) on the number of buffers + that can be used. The *ancdata* argument specifies the ancillary + data (control messages) as an iterable of zero or more tuples + ``(cmsg_level, cmsg_type, cmsg_data)``, where *cmsg_level* and + *cmsg_type* are integers specifying the protocol level and + protocol-specific type respectively, and *cmsg_data* is a + buffer-compatible object holding the associated data. Note that + some systems (in particular, systems without :func:`CMSG_SPACE`) + might support sending only one control message per call. The + *flags* argument defaults to 0 and has the same meaning as for + :meth:`send`. If *address* is supplied and not ``None``, it sets a + destination address for the message. The return value is the + number of bytes of non-ancillary data sent. + + The following function sends the list of file descriptors *fds* + over an :const:`AF_UNIX` socket, on systems which support the + :const:`SCM_RIGHTS` mechanism. See also :meth:`recvmsg`. :: + + import socket, array + + def send_fds(sock, msg, fds): + return sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))]) + + Availability: most Unix platforms, possibly others. + + .. versionadded:: XXX + + .. method:: socket.setblocking(flag) Set blocking or non-blocking mode of the socket: if *flag* is 0, the socket is diff --git a/Lib/socket.py b/Lib/socket.py --- a/Lib/socket.py +++ b/Lib/socket.py @@ -162,14 +162,15 @@ if sys.platform == "riscos": # All the method names that must be delegated to either the real socket # object or the _closedsocket object. _delegate_methods = ("recv", "recvfrom", "recv_into", "recvfrom_into", - "send", "sendto") + "send", "sendto", "sendmsg", "recvmsg", "recvmsg_into") class _closedsocket(object): __slots__ = [] def _dummy(*args): raise error(EBADF, 'Bad file descriptor') # All _delegate_methods must also be initialized here. - send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy + send = recv = recv_into = sendto = recvfrom = recvfrom_into = sendmsg = \ + recvmsg = recvmsg_into = _dummy __getattr__ = _dummy # Wrapper around platform socket objects. This implements diff --git a/Lib/ssl.py b/Lib/ssl.py --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -208,6 +208,13 @@ class SSLSocket(socket): else: return self._sock.sendto(data, flags_or_addr, addr) + def sendmsg(self, *args, **kwargs): + if self._sslobj: + raise ValueError("sendmsg not allowed on instances of %s" % + self.__class__) + else: + return self._sock.sendmsg(*args, **kwargs) + def sendall(self, data, flags=0): if self._sslobj: if flags != 0: @@ -264,6 +271,20 @@ class SSLSocket(socket): else: return self._sock.recvfrom_into(buffer, nbytes, flags) + def recvmsg(self, *args, **kwargs): + if self._sslobj: + raise ValueError("recvmsg not allowed on instances of %s" % + self.__class__) + else: + return self._sock.recvmsg(*args, **kwargs) + + def recvmsg_into(self, *args, **kwargs): + if self._sslobj: + raise ValueError("recvmsg_into not allowed on instances of %s" % + self.__class__) + else: + return self._sock.recvmsg_into(*args, **kwargs) + def pending(self): if self._sslobj: return self._sslobj.pending() diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -6,6 +6,8 @@ from test import test_support import errno import socket import select +import tempfile +import _testcapi import time import traceback import Queue @@ -43,6 +45,9 @@ except ImportError: HOST = test_support.HOST MSG = 'Michael Gilfix was here\n' +# Size in bytes of the int type +SIZEOF_INT = array.array("i").itemsize + class SocketTCPTest(unittest.TestCase): def setUp(self): @@ -64,6 +69,28 @@ class SocketUDPTest(unittest.TestCase): self.serv.close() self.serv = None +class ThreadSafeCleanupTestCase(unittest.TestCase): + """Subclass of unittest.TestCase with thread-safe cleanup methods. + + This subclass protects the addCleanup() and doCleanups() methods + with a recursive lock. + """ + + if threading: + def __init__(self, *args, **kwargs): + super(ThreadSafeCleanupTestCase, self).__init__(*args, **kwargs) + self._cleanup_lock = threading.RLock() + + def addCleanup(self, *args, **kwargs): + with self._cleanup_lock: + return super(ThreadSafeCleanupTestCase, self).addCleanup( + *args, **kwargs) + + def doCleanups(self, *args, **kwargs): + with self._cleanup_lock: + return super(ThreadSafeCleanupTestCase, self).doCleanups( + *args, **kwargs) + class ThreadableTest: """Threadable Test class @@ -241,6 +268,243 @@ class SocketPairTest(unittest.TestCase, ThreadableTest.clientTearDown(self) +# The following classes are used by the sendmsg()/recvmsg() tests. +# Combining, for instance, ConnectedStreamTestMixin and TCPTestBase +# gives a drop-in replacement for SocketConnectedTest, but different +# address families can be used, and the attributes serv_addr and +# cli_addr will be set to the addresses of the endpoints. + +class SocketTestBase(unittest.TestCase): + """A base class for socket tests. + + Subclasses must provide methods newSocket() to return a new socket + and bindSock(sock) to bind it to an unused address. + + Creates a socket self.serv and sets self.serv_addr to its address. + """ + + def setUp(self): + self.serv = self.newSocket() + self.bindServer() + + def bindServer(self): + """Bind server socket and set self.serv_addr to its address.""" + self.bindSock(self.serv) + self.serv_addr = self.serv.getsockname() + + def tearDown(self): + self.serv.close() + self.serv = None + + +class SocketListeningTestMixin(SocketTestBase): + """Mixin to listen on the server socket.""" + + def setUp(self): + super(SocketListeningTestMixin, self).setUp() + self.serv.listen(1) + + +class ThreadedSocketTestMixin(ThreadSafeCleanupTestCase, SocketTestBase, + ThreadableTest): + """Mixin to add client socket and allow client/server tests. + + Client socket is self.cli and its address is self.cli_addr. See + ThreadableTest for usage information. + """ + + def __init__(self, *args, **kwargs): + super(ThreadedSocketTestMixin, self).__init__(*args, **kwargs) + ThreadableTest.__init__(self) + + def clientSetUp(self): + self.cli = self.newClientSocket() + self.bindClient() + + def newClientSocket(self): + """Return a new socket for use as client.""" + return self.newSocket() + + def bindClient(self): + """Bind client socket and set self.cli_addr to its address.""" + self.bindSock(self.cli) + self.cli_addr = self.cli.getsockname() + + def clientTearDown(self): + self.cli.close() + self.cli = None + ThreadableTest.clientTearDown(self) + + +class ConnectedStreamTestMixin(SocketListeningTestMixin, + ThreadedSocketTestMixin): + """Mixin to allow client/server stream tests with connected client. + + Server's socket representing connection to client is self.cli_conn + and client's connection to server is self.serv_conn. (Based on + SocketConnectedTest.) + """ + + def setUp(self): + super(ConnectedStreamTestMixin, self).setUp() + # Indicate explicitly we're ready for the client thread to + # proceed and then perform the blocking call to accept + self.serverExplicitReady() + conn, addr = self.serv.accept() + self.cli_conn = conn + + def tearDown(self): + self.cli_conn.close() + self.cli_conn = None + super(ConnectedStreamTestMixin, self).tearDown() + + def clientSetUp(self): + super(ConnectedStreamTestMixin, self).clientSetUp() + self.cli.connect(self.serv_addr) + self.serv_conn = self.cli + + def clientTearDown(self): + self.serv_conn.close() + self.serv_conn = None + super(ConnectedStreamTestMixin, self).clientTearDown() + + +class UnixSocketTestBase(SocketTestBase): + """Base class for Unix-domain socket tests.""" + + # This class is used for file descriptor passing tests, so we + # create the sockets in a private directory so that other users + # can't send anything that might be problematic for a privileged + # user running the tests. + + def setUp(self): + self.dir_path = tempfile.mkdtemp() + self.addCleanup(os.rmdir, self.dir_path) + super(UnixSocketTestBase, self).setUp() + + def bindSock(self, sock): + path = tempfile.mktemp(dir=self.dir_path) + sock.bind(path) + self.addCleanup(test_support.unlink, path) + +class UnixStreamBase(UnixSocketTestBase): + """Base class for Unix-domain SOCK_STREAM tests.""" + + def newSocket(self): + return socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + +class InetTestBase(SocketTestBase): + """Base class for IPv4 socket tests.""" + + host = HOST + + def setUp(self): + super(InetTestBase, self).setUp() + self.port = self.serv_addr[1] + + def bindSock(self, sock): + test_support.bind_port(sock, host=self.host) + +class TCPTestBase(InetTestBase): + """Base class for TCP-over-IPv4 tests.""" + + def newSocket(self): + return socket.socket(socket.AF_INET, socket.SOCK_STREAM) + +class UDPTestBase(InetTestBase): + """Base class for UDP-over-IPv4 tests.""" + + def newSocket(self): + return socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + +class SCTPStreamBase(InetTestBase): + """Base class for SCTP tests in one-to-one (SOCK_STREAM) mode.""" + + def newSocket(self): + return socket.socket(socket.AF_INET, socket.SOCK_STREAM, + socket.IPPROTO_SCTP) + + +class Inet6TestBase(InetTestBase): + """Base class for IPv6 socket tests.""" + + # Don't use "localhost" here - it may not have an IPv6 address + # assigned to it by default (e.g. in /etc/hosts), and if someone + # has assigned it an IPv4-mapped address, then it's unlikely to + # work with the full IPv6 API. + host = "::1" + +class UDP6TestBase(Inet6TestBase): + """Base class for UDP-over-IPv6 tests.""" + + def newSocket(self): + return socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + + +# Test-skipping decorators for use with ThreadableTest. + +def skipWithClientIf(condition, reason): + """Skip decorated test if condition is true, add client_skip decorator. + + If the decorated object is not a class, sets its attribute + "client_skip" to a decorator which will return an empty function + if the test is to be skipped, or the original function if it is + not. This can be used to avoid running the client part of a + skipped test when using ThreadableTest. + """ + def client_pass(*args, **kwargs): + pass + def skipdec(obj): + retval = unittest.skip(reason)(obj) + if not isinstance(obj, type): + retval.client_skip = lambda f: client_pass + return retval + def noskipdec(obj): + if not (isinstance(obj, type) or hasattr(obj, "client_skip")): + obj.client_skip = lambda f: f + return obj + return skipdec if condition else noskipdec + + +def requireAttrs(obj, *attributes): + """Skip decorated test if obj is missing any of the given attributes. + + Sets client_skip attribute as skipWithClientIf() does. + """ + missing = [name for name in attributes if not hasattr(obj, name)] + return skipWithClientIf( + missing, "don't have " + ", ".join(name for name in missing)) + + +def requireSocket(*args): + """Skip decorated test if a socket cannot be created with given arguments. + + When an argument is given as a string, will use the value of that + attribute of the socket module, or skip the test if it doesn't + exist. Sets client_skip attribute as skipWithClientIf() does. + """ + err = None + missing = [obj for obj in args if + isinstance(obj, str) and not hasattr(socket, obj)] + if missing: + err = "don't have " + ", ".join(name for name in missing) + else: + callargs = [getattr(socket, obj) if isinstance(obj, str) else obj + for obj in args] + try: + s = socket.socket(*callargs) + except socket.error as e: + # XXX: check errno? + err = str(e) + else: + s.close() + return skipWithClientIf( + err is not None, + "can't create socket({0}): {1}".format( + ", ".join(str(o) for o in args), err)) + + ####################################################################### ## Begin Tests @@ -829,6 +1093,1840 @@ class BasicUDPTest(ThreadedUDPSocketTest def _testRecvFromNegative(self): self.cli.sendto(MSG, 0, (HOST, self.port)) + +# Tests for the sendmsg()/recvmsg() interface. Where possible, the +# same test code is used with different families and types of socket +# (e.g. stream, datagram), and tests using recvmsg() are repeated +# using recvmsg_into(). +# +# The generic test classes such as SendmsgTests and +# RecvmsgGenericTests inherit from SendrecvmsgBase and expect to be +# supplied with sockets cli_sock and serv_sock representing the +# client's and the server's end of the connection respectively, and +# attributes cli_addr and serv_addr holding their (numeric where +# appropriate) addresses. +# +# The final concrete test classes combine these with subclasses of +# SocketTestBase which set up client and server sockets of a specific +# type, and with subclasses of SendrecvmsgBase such as +# SendrecvmsgDgramBase and SendrecvmsgConnectedBase which map these +# sockets to cli_sock and serv_sock and override the methods and +# attributes of SendrecvmsgBase to fill in destination addresses if +# needed when sending, check for specific flags in msg_flags, etc. +# +# RecvmsgIntoMixin provides a version of doRecvmsg() implemented using +# recvmsg_into(). + +# XXX: like the other datagram (UDP) tests in this module, the code +# here assumes that datagram delivery on the local machine will be +# reliable. + +class SendrecvmsgBase(ThreadSafeCleanupTestCase): + # Base class for sendmsg()/recvmsg() tests. + + # Time in seconds to wait before considering a test failed, or + # None for no timeout. Not all tests actually set a timeout. + fail_timeout = 3.0 + + def setUp(self): + self.misc_event = threading.Event() + super(SendrecvmsgBase, self).setUp() + + def sendToServer(self, msg): + # Send msg to the server. + return self.cli_sock.send(msg) + + # Tuple of alternative default arguments for sendmsg() when called + # via sendmsgToServer() (e.g. to include a destination address). + sendmsg_to_server_defaults = () + + def sendmsgToServer(self, *args): + # Call sendmsg() on self.cli_sock with the given arguments, + # filling in any arguments which are not supplied with the + # corresponding items of self.sendmsg_to_server_defaults, if + # any. + return self.cli_sock.sendmsg( + *(args + self.sendmsg_to_server_defaults[len(args):])) + + def doRecvmsg(self, sock, bufsize, *args): + # Call recvmsg() on sock with given arguments and return its + # result. Should be used for tests which can use either + # recvmsg() or recvmsg_into() - RecvmsgIntoMixin overrides + # this method with one which emulates it using recvmsg_into(), + # thus allowing the same test to be used for both methods. + result = sock.recvmsg(bufsize, *args) + self.registerRecvmsgResult(result) + return result + + def registerRecvmsgResult(self, result): + # Called by doRecvmsg() with the return value of recvmsg() or + # recvmsg_into(). Can be overridden to arrange cleanup based + # on the returned ancillary data, for instance. + pass + + def checkRecvmsgAddress(self, addr1, addr2): + # Called to compare the received address with the address of + # the peer. + self.assertEqual(addr1, addr2) + + # Flags that are normally unset in msg_flags + msg_flags_common_unset = 0 + for name in ("MSG_CTRUNC", "MSG_OOB"): + msg_flags_common_unset |= getattr(socket, name, 0) + + # Flags that are normally set + msg_flags_common_set = 0 + + # Flags set when a complete record has been received (e.g. MSG_EOR + # for SCTP) + msg_flags_eor_indicator = 0 + + # Flags set when a complete record has not been received + # (e.g. MSG_TRUNC for datagram sockets) + msg_flags_non_eor_indicator = 0 + + def checkFlags(self, flags, eor=None, checkset=0, checkunset=0, ignore=0): + # Method to check the value of msg_flags returned by recvmsg[_into](). + # + # Checks that all bits in msg_flags_common_set attribute are + # set in "flags" and all bits in msg_flags_common_unset are + # unset. + # + # The "eor" argument specifies whether the flags should + # indicate that a full record (or datagram) has been received. + # If "eor" is None, no checks are done; otherwise, checks + # that: + # + # * if "eor" is true, all bits in msg_flags_eor_indicator are + # set and all bits in msg_flags_non_eor_indicator are unset + # + # * if "eor" is false, all bits in msg_flags_non_eor_indicator + # are set and all bits in msg_flags_eor_indicator are unset + # + # If "checkset" and/or "checkunset" are supplied, they require + # the given bits to be set or unset respectively, overriding + # what the attributes require for those bits. + # + # If any bits are set in "ignore", they will not be checked, + # regardless of the other inputs. + # + # Will raise Exception if the inputs require a bit to be both + # set and unset, and it is not ignored. + + defaultset = self.msg_flags_common_set + defaultunset = self.msg_flags_common_unset + + if eor: + defaultset |= self.msg_flags_eor_indicator + defaultunset |= self.msg_flags_non_eor_indicator + elif eor is not None: + defaultset |= self.msg_flags_non_eor_indicator + defaultunset |= self.msg_flags_eor_indicator + + # Function arguments override defaults + defaultset &= ~checkunset + defaultunset &= ~checkset + + # Merge arguments with remaining defaults, and check for conflicts + checkset |= defaultset + checkunset |= defaultunset + inboth = checkset & checkunset & ~ignore + if inboth: + raise Exception("contradictory set, unset requirements for flags " + "{0:#x}".format(inboth)) + + # Compare with given msg_flags value + mask = (checkset | checkunset) & ~ignore + self.assertEqual(flags & mask, checkset & mask) + + +class RecvmsgIntoMixin(SendrecvmsgBase): + # Mixin to implement doRecvmsg() using recvmsg_into(). + + def doRecvmsg(self, sock, bufsize, *args): + buf = bytearray(bufsize) + result = sock.recvmsg_into([buf], *args) + self.registerRecvmsgResult(result) + self.assertGreaterEqual(result[0], 0) + self.assertLessEqual(result[0], bufsize) + return (bytes(buf[:result[0]]),) + result[1:] + + +class SendrecvmsgDgramFlagsBase(SendrecvmsgBase): + # Defines flags to be checked in msg_flags for datagram sockets. + + @property + def msg_flags_non_eor_indicator(self): + return (super(SendrecvmsgDgramFlagsBase, self) + .msg_flags_non_eor_indicator) | socket.MSG_TRUNC + + +class SendrecvmsgSCTPFlagsBase(SendrecvmsgBase): + # Defines flags to be checked in msg_flags for SCTP sockets. + + @property + def msg_flags_eor_indicator(self): + return super(SendrecvmsgSCTPFlagsBase, self).msg_flags_eor_indicator | socket.MSG_EOR + + +class SendrecvmsgConnectionlessBase(SendrecvmsgBase): + # Base class for tests on connectionless-mode sockets. Users must + # supply sockets on attributes cli and serv to be mapped to + # cli_sock and serv_sock respectively. + + @property + def serv_sock(self): + return self.serv + + @property + def cli_sock(self): + return self.cli + + @property + def sendmsg_to_server_defaults(self): + return ([], [], 0, self.serv_addr) + + def sendToServer(self, msg): + return self.cli_sock.sendto(msg, self.serv_addr) + + +class SendrecvmsgConnectedBase(SendrecvmsgBase): + # Base class for tests on connected sockets. Users must supply + # sockets on attributes serv_conn and cli_conn (representing the + # connections *to* the server and the client), to be mapped to + # cli_sock and serv_sock respectively. + + @property + def serv_sock(self): + return self.cli_conn + + @property + def cli_sock(self): + return self.serv_conn + + def checkRecvmsgAddress(self, addr1, addr2): + # Address is currently "unspecified" for a connected socket, + # so we don't examine it + pass + + +class SendrecvmsgServerTimeoutBase(SendrecvmsgBase): + # Base class to set a timeout on server's socket. + + def setUp(self): + super(SendrecvmsgServerTimeoutBase, self).setUp() + self.serv_sock.settimeout(self.fail_timeout) + + +class SendmsgTests(SendrecvmsgServerTimeoutBase): + # Tests for sendmsg() which can use any socket type and do not + # involve recvmsg() or recvmsg_into(). + + def testSendmsg(self): + # Send a simple message with sendmsg(). + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsg(self): + self.assertEqual(self.sendmsgToServer([MSG]), len(MSG)) + + def testSendmsgDataGenerator(self): + # Send from buffer obtained from a generator (not a sequence). + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsgDataGenerator(self): + self.assertEqual(self.sendmsgToServer((o for o in [MSG])), + len(MSG)) + + def testSendmsgAncillaryGenerator(self): + # Gather (empty) ancillary data from a generator. + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsgAncillaryGenerator(self): + self.assertEqual(self.sendmsgToServer([MSG], (o for o in [])), + len(MSG)) + + def testSendmsgArray(self): + # Send data from an array instead of the usual bytes object. + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsgArray(self): + self.assertEqual(self.sendmsgToServer([array.array("B", MSG)]), + len(MSG)) + + def testSendmsgGather(self): + # Send message data from more than one buffer (gather write). + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsgGather(self): + self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG)) + + def testSendmsgBadArgs(self): + # Check that sendmsg() rejects invalid arguments. + self.assertEqual(self.serv_sock.recv(1000), b"done") + + def _testSendmsgBadArgs(self): + self.assertRaises(TypeError, self.cli_sock.sendmsg) + self.assertRaises(TypeError, self.sendmsgToServer, + bytearray(b"not in an iterable")) + self.assertRaises(TypeError, self.sendmsgToServer, + object()) + self.assertRaises(TypeError, self.sendmsgToServer, + [object()]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG, object()]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], object()) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [], object()) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [], 0, object()) + self.sendToServer(b"done") + + def testSendmsgBadCmsg(self): + # Check that invalid ancillary data items are rejected. + self.assertEqual(self.serv_sock.recv(1000), b"done") + + def _testSendmsgBadCmsg(self): + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [object()]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(object(), 0, b"data")]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(0, object(), b"data")]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(0, 0, object())]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(0, 0)]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(0, 0, b"data", 42)]) + self.sendToServer(b"done") + + @requireAttrs(socket, "CMSG_SPACE") + def testSendmsgBadMultiCmsg(self): + # Check that invalid ancillary data items are rejected when + # more than one item is present. + self.assertEqual(self.serv_sock.recv(1000), b"done") + + @testSendmsgBadMultiCmsg.client_skip + def _testSendmsgBadMultiCmsg(self): + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [0, 0, b""]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(0, 0, b""), object()]) + self.sendToServer(b"done") + + def testSendmsgExcessCmsgReject(self): + # Check that sendmsg() rejects excess ancillary data items + # when the number that can be sent is limited. + self.assertEqual(self.serv_sock.recv(1000), b"done") + + def _testSendmsgExcessCmsgReject(self): + if not hasattr(socket, "CMSG_SPACE"): + # Can only send one item + with self.assertRaises(socket.error) as cm: + self.sendmsgToServer([MSG], [(0, 0, b""), (0, 0, b"")]) + self.assertIsNone(cm.exception.errno) + self.sendToServer(b"done") + + def testSendmsgAfterClose(self): + # Check that sendmsg() fails on a closed socket. + pass + + def _testSendmsgAfterClose(self): + self.cli_sock.close() + self.assertRaises(socket.error, self.sendmsgToServer, [MSG]) + + +class SendmsgStreamTests(SendmsgTests): + # Tests for sendmsg() which require a stream socket and do not + # involve recvmsg() or recvmsg_into(). + + def testSendmsgExplicitNoneAddr(self): + # Check that peer address can be specified as None. + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsgExplicitNoneAddr(self): + self.assertEqual(self.sendmsgToServer([MSG], [], 0, None), len(MSG)) + + def testSendmsgTimeout(self): + # Check that timeout works with sendmsg(). + self.assertEqual(self.serv_sock.recv(512), b"a"*512) + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + + def _testSendmsgTimeout(self): + try: + self.cli_sock.settimeout(0.03) + with self.assertRaises(socket.timeout): + while True: + self.sendmsgToServer([b"a"*512]) + finally: + self.misc_event.set() + + # XXX: would be nice to have more tests for sendmsg flags argument. + + # Linux supports MSG_DONTWAIT when sending, but in general, it + # only works when receiving. Could add other platforms if they + # support it too. + @skipWithClientIf(sys.platform not in {"linux2"}, + "MSG_DONTWAIT not known to work on this platform when " + "sending") + def testSendmsgDontWait(self): + # Check that MSG_DONTWAIT in flags causes non-blocking behaviour. + self.assertEqual(self.serv_sock.recv(512), b"a"*512) + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + + @testSendmsgDontWait.client_skip + def _testSendmsgDontWait(self): + try: + with self.assertRaises(socket.error) as cm: + while True: + self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT) + self.assertIn(cm.exception.errno, + (errno.EAGAIN, errno.EWOULDBLOCK)) + finally: + self.misc_event.set() + + +class SendmsgConnectionlessTests(SendmsgTests): + # Tests for sendmsg() which require a connectionless-mode + # (e.g. datagram) socket, and do not involve recvmsg() or + # recvmsg_into(). + + def testSendmsgNoDestAddr(self): + # Check that sendmsg() fails when no destination address is + # given for unconnected socket. + pass + + def _testSendmsgNoDestAddr(self): + self.assertRaises(socket.error, self.cli_sock.sendmsg, + [MSG]) + self.assertRaises(socket.error, self.cli_sock.sendmsg, + [MSG], [], 0, None) + + +class RecvmsgGenericTests(SendrecvmsgBase): + # Tests for recvmsg() which can also be emulated using + # recvmsg_into(), and can use any socket type. + + def testRecvmsg(self): + # Receive a simple message with recvmsg[_into](). + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG)) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsg(self): + self.sendToServer(MSG) + + def testRecvmsgExplicitDefaults(self): + # Test recvmsg[_into]() with default arguments provided explicitly. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 0, 0) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgExplicitDefaults(self): + self.sendToServer(MSG) + + def testRecvmsgShorter(self): + # Receive a message smaller than buffer. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG) + 42) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgShorter(self): + self.sendToServer(MSG) + + def testRecvmsgTrunc(self): + # Receive part of message, check for truncation indicators. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG) - 3) + self.assertEqual(msg, MSG[:-3]) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=False) + + def _testRecvmsgTrunc(self): + self.sendToServer(MSG) + + def testRecvmsgShortAncillaryBuf(self): + # Test ancillary data buffer too small to hold any ancillary data. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 1) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgShortAncillaryBuf(self): + self.sendToServer(MSG) + + def testRecvmsgLongAncillaryBuf(self): + # Test large ancillary data buffer. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 10240) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgLongAncillaryBuf(self): + self.sendToServer(MSG) + + def testRecvmsgAfterClose(self): + # Check that recvmsg[_into]() fails on a closed socket. + self.serv_sock.close() + self.assertRaises(socket.error, self.doRecvmsg, self.serv_sock, 1024) + + def _testRecvmsgAfterClose(self): + pass + + def testRecvmsgTimeout(self): + # Check that timeout works. + try: + self.serv_sock.settimeout(0.03) + self.assertRaises(socket.timeout, + self.doRecvmsg, self.serv_sock, len(MSG)) + finally: + self.misc_event.set() + + def _testRecvmsgTimeout(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + + @requireAttrs(socket, "MSG_PEEK") + def testRecvmsgPeek(self): + # Check that MSG_PEEK in flags enables examination of pending + # data without consuming it. + + # Receive part of data with MSG_PEEK. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG) - 3, 0, + socket.MSG_PEEK) + self.assertEqual(msg, MSG[:-3]) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + # Ignoring MSG_TRUNC here (so this test is the same for stream + # and datagram sockets). Some wording in POSIX seems to + # suggest that it needn't be set when peeking, but that may + # just be a slip. + self.checkFlags(flags, eor=False, + ignore=getattr(socket, "MSG_TRUNC", 0)) + + # Receive all data with MSG_PEEK. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 0, + socket.MSG_PEEK) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + # Check that the same data can still be received normally. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG)) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + @testRecvmsgPeek.client_skip + def _testRecvmsgPeek(self): + self.sendToServer(MSG) + + @requireAttrs(socket.socket, "sendmsg") + def testRecvmsgFromSendmsg(self): + # Test receiving with recvmsg[_into]() when message is sent + # using sendmsg(). + self.serv_sock.settimeout(self.fail_timeout) + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG)) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + @testRecvmsgFromSendmsg.client_skip + def _testRecvmsgFromSendmsg(self): + self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG)) + + +class RecvmsgGenericStreamTests(RecvmsgGenericTests): + # Tests which require a stream socket and can use either recvmsg() + # or recvmsg_into(). + + def testRecvmsgEOF(self): + # Receive end-of-stream indicator (b"", peer socket closed). + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, 1024) + self.assertEqual(msg, b"") + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=None) # Might not have end-of-record marker + + def _testRecvmsgEOF(self): + self.cli_sock.close() + + def testRecvmsgOverflow(self): + # Receive a message in more than one chunk. + seg1, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG) - 3) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=False) + + seg2, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, 1024) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + msg = seg1 + seg2 + self.assertEqual(msg, MSG) + + def _testRecvmsgOverflow(self): + self.sendToServer(MSG) + + +class RecvmsgTests(RecvmsgGenericTests): + # Tests for recvmsg() which can use any socket type. + + def testRecvmsgBadArgs(self): + # Check that recvmsg() rejects invalid arguments. + self.assertRaises(TypeError, self.serv_sock.recvmsg) + self.assertRaises(ValueError, self.serv_sock.recvmsg, + -1, 0, 0) + self.assertRaises(ValueError, self.serv_sock.recvmsg, + len(MSG), -1, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg, + [bytearray(10)], 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg, + object(), 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg, + len(MSG), object(), 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg, + len(MSG), 0, object()) + + msg, ancdata, flags, addr = self.serv_sock.recvmsg(len(MSG), 0, 0) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgBadArgs(self): + self.sendToServer(MSG) + + +class RecvmsgIntoTests(RecvmsgIntoMixin, RecvmsgGenericTests): + # Tests for recvmsg_into() which can use any socket type. + + def testRecvmsgIntoBadArgs(self): + # Check that recvmsg_into() rejects invalid arguments. + buf = bytearray(len(MSG)) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + len(MSG), 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + buf, 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + [object()], 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + [b"I'm not writable"], 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + [buf, object()], 0, 0) + self.assertRaises(ValueError, self.serv_sock.recvmsg_into, + [buf], -1, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + [buf], object(), 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + [buf], 0, object()) + + nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into([buf], 0, 0) + self.assertEqual(nbytes, len(MSG)) + self.assertEqual(buf, bytearray(MSG)) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgIntoBadArgs(self): + self.sendToServer(MSG) + + def testRecvmsgIntoGenerator(self): + # Receive into buffer obtained from a generator (not a sequence). + buf = bytearray(len(MSG)) + nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into( + (o for o in [buf])) + self.assertEqual(nbytes, len(MSG)) + self.assertEqual(buf, bytearray(MSG)) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgIntoGenerator(self): + self.sendToServer(MSG) + + def testRecvmsgIntoArray(self): + # Receive into an array rather than the usual bytearray. + buf = array.array("B", [0] * len(MSG)) + nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into([buf]) + self.assertEqual(nbytes, len(MSG)) + self.assertEqual(buf.tostring(), MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgIntoArray(self): + self.sendToServer(MSG) + + def testRecvmsgIntoScatter(self): + # Receive into multiple buffers (scatter write). + b1 = bytearray(b"----") + b2 = bytearray(b"0123456789") + b3 = bytearray(b"--------------") + nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into( + [b1, memoryview(b2)[2:9], b3]) + self.assertEqual(nbytes, len(b"Mary had a little lamb")) + self.assertEqual(b1, bytearray(b"Mary")) + self.assertEqual(b2, bytearray(b"01 had a 9")) + self.assertEqual(b3, bytearray(b"little lamb---")) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgIntoScatter(self): + self.sendToServer(b"Mary had a little lamb") + + +class CmsgMacroTests(unittest.TestCase): + # Test the functions CMSG_LEN() and CMSG_SPACE(). Tests + # assumptions used by sendmsg() and recvmsg[_into](), which share + # code with these functions. + + # Match the definition in socketmodule.c + socklen_t_limit = min(0x7fffffff, _testcapi.INT_MAX) + + @requireAttrs(socket, "CMSG_LEN") + def testCMSG_LEN(self): + # Test CMSG_LEN() with various valid and invalid values, + # checking the assumptions used by recvmsg() and sendmsg(). + toobig = self.socklen_t_limit - socket.CMSG_LEN(0) + 1 + values = list(range(257)) + list(range(toobig - 257, toobig)) + + # struct cmsghdr has at least three members, two of which are ints + self.assertGreater(socket.CMSG_LEN(0), array.array("i").itemsize * 2) + for n in values: + ret = socket.CMSG_LEN(n) + # This is how recvmsg() calculates the data size + self.assertEqual(ret - socket.CMSG_LEN(0), n) + self.assertLessEqual(ret, self.socklen_t_limit) + + self.assertRaises(OverflowError, socket.CMSG_LEN, -1) + # sendmsg() shares code with these functions, and requires + # that it reject values over the limit. + self.assertRaises(OverflowError, socket.CMSG_LEN, toobig) + self.assertRaises(OverflowError, socket.CMSG_LEN, sys.maxsize) + + @requireAttrs(socket, "CMSG_SPACE") + def testCMSG_SPACE(self): + # Test CMSG_SPACE() with various valid and invalid values, + # checking the assumptions used by sendmsg(). + toobig = self.socklen_t_limit - socket.CMSG_SPACE(1) + 1 + values = list(range(257)) + list(range(toobig - 257, toobig)) + + last = socket.CMSG_SPACE(0) + # struct cmsghdr has at least three members, two of which are ints + self.assertGreater(last, array.array("i").itemsize * 2) + for n in values: + ret = socket.CMSG_SPACE(n) + self.assertGreaterEqual(ret, last) + self.assertGreaterEqual(ret, socket.CMSG_LEN(n)) + self.assertGreaterEqual(ret, n + socket.CMSG_LEN(0)) + self.assertLessEqual(ret, self.socklen_t_limit) + last = ret + + self.assertRaises(OverflowError, socket.CMSG_SPACE, -1) + # sendmsg() shares code with these functions, and requires + # that it reject values over the limit. + self.assertRaises(OverflowError, socket.CMSG_SPACE, toobig) + self.assertRaises(OverflowError, socket.CMSG_SPACE, sys.maxsize) + + +class SCMRightsTest(SendrecvmsgServerTimeoutBase): + # Tests for file descriptor passing on Unix-domain sockets. + + # Invalid file descriptor value that's unlikely to evaluate to a + # real FD even if one of its bytes is replaced with a different + # value (which shouldn't actually happen). + badfd = -0x5555 + + def newFDs(self, n): + # Return a list of n file descriptors for newly-created files + # containing their list indices as ASCII numbers. + fds = [] + for i in range(n): + fd, path = tempfile.mkstemp() + self.addCleanup(os.unlink, path) + self.addCleanup(os.close, fd) + os.write(fd, str(i).encode()) + fds.append(fd) + return fds + + def checkFDs(self, fds): + # Check that the file descriptors in the given list contain + # their correct list indices as ASCII numbers. + for n, fd in enumerate(fds): + os.lseek(fd, 0, os.SEEK_SET) + self.assertEqual(os.read(fd, 1024), str(n).encode()) + + def registerRecvmsgResult(self, result): + self.addCleanup(self.closeRecvmsgFDs, result) + + def closeRecvmsgFDs(self, recvmsg_result): + # Close all file descriptors specified in the ancillary data + # of the given return value from recvmsg() or recvmsg_into(). + for cmsg_level, cmsg_type, cmsg_data in recvmsg_result[1]: + if (cmsg_level == socket.SOL_SOCKET and + cmsg_type == socket.SCM_RIGHTS): + fds = array.array("i") + fds.fromstring(cmsg_data[: + len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + for fd in fds: + os.close(fd) + + def createAndSendFDs(self, n): + # Send n new file descriptors created by newFDs() to the + # server, with the constant MSG as the non-ancillary data. + self.assertEqual( + self.sendmsgToServer([MSG], + [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", self.newFDs(n)))]), + len(MSG)) + + def checkRecvmsgFDs(self, numfds, result, maxcmsgs=1, ignoreflags=0): + # Check that constant MSG was received with numfds file + # descriptors in a maximum of maxcmsgs control messages (which + # must contain only complete integers). By default, check + # that MSG_CTRUNC is unset, but ignore any flags in + # ignoreflags. + msg, ancdata, flags, addr = result + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC, + ignore=ignoreflags) + + self.assertIsInstance(ancdata, list) + self.assertLessEqual(len(ancdata), maxcmsgs) + fds = array.array("i") + for item in ancdata: + self.assertIsInstance(item, tuple) + cmsg_level, cmsg_type, cmsg_data = item + self.assertEqual(cmsg_level, socket.SOL_SOCKET) + self.assertEqual(cmsg_type, socket.SCM_RIGHTS) + self.assertIsInstance(cmsg_data, bytes) + self.assertEqual(len(cmsg_data) % SIZEOF_INT, 0) + fds.fromstring(cmsg_data) + + self.assertEqual(len(fds), numfds) + self.checkFDs(fds) + + def testFDPassSimple(self): + # Pass a single FD (array read from bytes object). + self.checkRecvmsgFDs(1, self.doRecvmsg(self.serv_sock, + len(MSG), 10240)) + + def _testFDPassSimple(self): + self.assertEqual( + self.sendmsgToServer( + [MSG], + [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", self.newFDs(1)).tostring())]), + len(MSG)) + + def testMultipleFDPass(self): + # Pass multiple FDs in a single array. + self.checkRecvmsgFDs(4, self.doRecvmsg(self.serv_sock, + len(MSG), 10240)) + + def _testMultipleFDPass(self): + self.createAndSendFDs(4) + + @requireAttrs(socket, "CMSG_SPACE") + def testFDPassCMSG_SPACE(self): + # Test using CMSG_SPACE() to calculate ancillary buffer size. + self.checkRecvmsgFDs( + 4, self.doRecvmsg(self.serv_sock, len(MSG), + socket.CMSG_SPACE(4 * SIZEOF_INT))) + + @testFDPassCMSG_SPACE.client_skip + def _testFDPassCMSG_SPACE(self): + self.createAndSendFDs(4) + + def testFDPassCMSG_LEN(self): + # Test using CMSG_LEN() to calculate ancillary buffer size. + self.checkRecvmsgFDs(1, + self.doRecvmsg(self.serv_sock, len(MSG), + socket.CMSG_LEN(4 * SIZEOF_INT)), + # RFC 3542 says implementations may set + # MSG_CTRUNC if there isn't enough space + # for trailing padding. + ignoreflags=socket.MSG_CTRUNC) + + def _testFDPassCMSG_LEN(self): + self.createAndSendFDs(1) + + @requireAttrs(socket, "CMSG_SPACE") + def testFDPassSeparate(self): + # Pass two FDs in two separate arrays. Arrays may be combined + # into a single control message by the OS. + self.checkRecvmsgFDs(2, + self.doRecvmsg(self.serv_sock, len(MSG), 10240), + maxcmsgs=2) + + @testFDPassSeparate.client_skip + def _testFDPassSeparate(self): + fd0, fd1 = self.newFDs(2) + self.assertEqual( + self.sendmsgToServer([MSG], [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [fd0])), + (socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [fd1]))]), + len(MSG)) + + @requireAttrs(socket, "CMSG_SPACE") + def testFDPassSeparateMinSpace(self): + # Pass two FDs in two separate arrays, receiving them into the + # minimum space for two arrays. + self.checkRecvmsgFDs(2, + self.doRecvmsg(self.serv_sock, len(MSG), + socket.CMSG_SPACE(SIZEOF_INT) + + socket.CMSG_LEN(SIZEOF_INT)), + maxcmsgs=2, ignoreflags=socket.MSG_CTRUNC) + + @testFDPassSeparateMinSpace.client_skip + def _testFDPassSeparateMinSpace(self): + fd0, fd1 = self.newFDs(2) + self.assertEqual( + self.sendmsgToServer([MSG], [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [fd0])), + (socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [fd1]))]), + len(MSG)) + + def sendAncillaryIfPossible(self, msg, ancdata): + # Try to send msg and ancdata to server, but if the system + # call fails, just send msg with no ancillary data. + try: + nbytes = self.sendmsgToServer([msg], ancdata) + except socket.error as e: + # Check that it was the system call that failed + self.assertIsInstance(e.errno, int) + nbytes = self.sendmsgToServer([msg]) + self.assertEqual(nbytes, len(msg)) + + def testFDPassEmpty(self): + # Try to pass an empty FD array. Can receive either no array + # or an empty array. + self.checkRecvmsgFDs(0, self.doRecvmsg(self.serv_sock, + len(MSG), 10240), + ignoreflags=socket.MSG_CTRUNC) + + def _testFDPassEmpty(self): + self.sendAncillaryIfPossible(MSG, [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + b"")]) + + def testFDPassPartialInt(self): + # Try to pass a truncated FD array. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 10240) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC) + self.assertLessEqual(len(ancdata), 1) + for cmsg_level, cmsg_type, cmsg_data in ancdata: + self.assertEqual(cmsg_level, socket.SOL_SOCKET) + self.assertEqual(cmsg_type, socket.SCM_RIGHTS) + self.assertLess(len(cmsg_data), SIZEOF_INT) + + def _testFDPassPartialInt(self): + self.sendAncillaryIfPossible( + MSG, + [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [self.badfd]).tostring()[:-1])]) + + @requireAttrs(socket, "CMSG_SPACE") + def testFDPassPartialIntInMiddle(self): + # Try to pass two FD arrays, the first of which is truncated. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 10240) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC) + self.assertLessEqual(len(ancdata), 2) + fds = array.array("i") + # Arrays may have been combined in a single control message + for cmsg_level, cmsg_type, cmsg_data in ancdata: + self.assertEqual(cmsg_level, socket.SOL_SOCKET) + self.assertEqual(cmsg_type, socket.SCM_RIGHTS) + fds.fromstring(cmsg_data[: + len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + self.assertLessEqual(len(fds), 2) + self.checkFDs(fds) + + @testFDPassPartialIntInMiddle.client_skip + def _testFDPassPartialIntInMiddle(self): + fd0, fd1 = self.newFDs(2) + self.sendAncillaryIfPossible( + MSG, + [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [fd0, self.badfd]).tostring()[:-1]), + (socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [fd1]))]) + + def checkTruncatedHeader(self, result, ignoreflags=0): + # Check that no ancillary data items are returned when data is + # truncated inside the cmsghdr structure. + msg, ancdata, flags, addr = result + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC, + ignore=ignoreflags) + + def testCmsgTruncNoBufSize(self): + # Check that no ancillary data is received when no buffer size + # is specified. + self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG)), + # BSD seems to set MSG_CTRUNC only + # if an item has been partially + # received. + ignoreflags=socket.MSG_CTRUNC) + + def _testCmsgTruncNoBufSize(self): + self.createAndSendFDs(1) + + def testCmsgTrunc0(self): + # Check that no ancillary data is received when buffer size is 0. + self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 0), + ignoreflags=socket.MSG_CTRUNC) + + def _testCmsgTrunc0(self): + self.createAndSendFDs(1) + + # Check that no ancillary data is returned for various non-zero + # (but still too small) buffer sizes. + + def testCmsgTrunc1(self): + self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 1)) + + def _testCmsgTrunc1(self): + self.createAndSendFDs(1) + + def testCmsgTrunc2Int(self): + # The cmsghdr structure has at least three members, two of + # which are ints, so we still shouldn't see any ancillary + # data. + self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), + SIZEOF_INT * 2)) + + def _testCmsgTrunc2Int(self): + self.createAndSendFDs(1) + + def testCmsgTruncLen0Minus1(self): + self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), + socket.CMSG_LEN(0) - 1)) + + def _testCmsgTruncLen0Minus1(self): + self.createAndSendFDs(1) + + # The following tests try to truncate the control message in the + # middle of the FD array. + + def checkTruncatedArray(self, ancbuf, maxdata, mindata=0): + # Check that file descriptor data is truncated to between + # mindata and maxdata bytes when received with buffer size + # ancbuf, and that any complete file descriptor numbers are + # valid. + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), ancbuf) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC) + + if mindata == 0 and ancdata == []: + return + self.assertEqual(len(ancdata), 1) + cmsg_level, cmsg_type, cmsg_data = ancdata[0] + self.assertEqual(cmsg_level, socket.SOL_SOCKET) + self.assertEqual(cmsg_type, socket.SCM_RIGHTS) + self.assertGreaterEqual(len(cmsg_data), mindata) + self.assertLessEqual(len(cmsg_data), maxdata) + fds = array.array("i") + fds.fromstring(cmsg_data[: + len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + self.checkFDs(fds) + + def testCmsgTruncLen0(self): + self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0), maxdata=0) + + def _testCmsgTruncLen0(self): + self.createAndSendFDs(1) + + def testCmsgTruncLen0Plus1(self): + self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0) + 1, maxdata=1) + + def _testCmsgTruncLen0Plus1(self): + self.createAndSendFDs(2) + + def testCmsgTruncLen1(self): + self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(SIZEOF_INT), + maxdata=SIZEOF_INT) + + def _testCmsgTruncLen1(self): + self.createAndSendFDs(2) + + def testCmsgTruncLen2Minus1(self): + self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(2 * SIZEOF_INT) - 1, + maxdata=(2 * SIZEOF_INT) - 1) + + def _testCmsgTruncLen2Minus1(self): + self.createAndSendFDs(2) + + +class RFC3542AncillaryTest(SendrecvmsgServerTimeoutBase): + # Test sendmsg() and recvmsg[_into]() using the ancillary data + # features of the RFC 3542 Advanced Sockets API for IPv6. + # Currently we can only handle certain data items (e.g. traffic + # class, hop limit, MTU discovery and fragmentation settings) + # without resorting to unportable means such as the struct module, + # but the tests here are aimed at testing the ancillary data + # handling in sendmsg() and recvmsg() rather than the IPv6 API + # itself. + + # Test value to use when setting hop limit of packet + hop_limit = 2 + + # Test value to use when setting traffic class of packet. + # -1 means "use kernel default". + traffic_class = -1 + + def ancillaryMapping(self, ancdata): + # Given ancillary data list ancdata, return a mapping from + # pairs (cmsg_level, cmsg_type) to corresponding cmsg_data. + # Check that no (level, type) pair appears more than once. + d = {} + for cmsg_level, cmsg_type, cmsg_data in ancdata: + self.assertNotIn((cmsg_level, cmsg_type), d) + d[(cmsg_level, cmsg_type)] = cmsg_data + return d + + def checkHopLimit(self, ancbufsize, maxhop=255, ignoreflags=0): + # Receive hop limit into ancbufsize bytes of ancillary data + # space. Check that data is MSG, ancillary data is not + # truncated (but ignore any flags in ignoreflags), and hop + # limit is between 0 and maxhop inclusive. + self.serv_sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_RECVHOPLIMIT, 1) + self.misc_event.set() + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), ancbufsize) + + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC, + ignore=ignoreflags) + + self.assertEqual(len(ancdata), 1) + self.assertIsInstance(ancdata[0], tuple) + cmsg_level, cmsg_type, cmsg_data = ancdata[0] + self.assertEqual(cmsg_level, socket.IPPROTO_IPV6) + self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT) + self.assertIsInstance(cmsg_data, bytes) + self.assertEqual(len(cmsg_data), SIZEOF_INT) + a = array.array("i") + a.fromstring(cmsg_data) + self.assertGreaterEqual(a[0], 0) + self.assertLessEqual(a[0], maxhop) + + @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT") + def testRecvHopLimit(self): + # Test receiving the packet hop limit as ancillary data. + self.checkHopLimit(ancbufsize=10240) + + @testRecvHopLimit.client_skip + def _testRecvHopLimit(self): + # Need to wait until server has asked to receive ancillary + # data, as implementations are not required to buffer it + # otherwise. + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT") + def testRecvHopLimitCMSG_SPACE(self): + # Test receiving hop limit, using CMSG_SPACE to calculate buffer size. + self.checkHopLimit(ancbufsize=socket.CMSG_SPACE(SIZEOF_INT)) + + @testRecvHopLimitCMSG_SPACE.client_skip + def _testRecvHopLimitCMSG_SPACE(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + # Could test receiving into buffer sized using CMSG_LEN, but RFC + # 3542 says portable applications must provide space for trailing + # padding. Implementations may set MSG_CTRUNC if there isn't + # enough space for the padding. + + @requireAttrs(socket.socket, "sendmsg") + @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT") + def testSetHopLimit(self): + # Test setting hop limit on outgoing packet and receiving it + # at the other end. + self.checkHopLimit(ancbufsize=10240, maxhop=self.hop_limit) + + @testSetHopLimit.client_skip + def _testSetHopLimit(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.assertEqual( + self.sendmsgToServer([MSG], + [(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT, + array.array("i", [self.hop_limit]))]), + len(MSG)) + + def checkTrafficClassAndHopLimit(self, ancbufsize, maxhop=255, + ignoreflags=0): + # Receive traffic class and hop limit into ancbufsize bytes of + # ancillary data space. Check that data is MSG, ancillary + # data is not truncated (but ignore any flags in ignoreflags), + # and traffic class and hop limit are in range (hop limit no + # more than maxhop). + self.serv_sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_RECVHOPLIMIT, 1) + self.serv_sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_RECVTCLASS, 1) + self.misc_event.set() + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), ancbufsize) + + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC, + ignore=ignoreflags) + self.assertEqual(len(ancdata), 2) + ancmap = self.ancillaryMapping(ancdata) + + tcdata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_TCLASS)] + self.assertEqual(len(tcdata), SIZEOF_INT) + a = array.array("i") + a.fromstring(tcdata) + self.assertGreaterEqual(a[0], 0) + self.assertLessEqual(a[0], 255) + + hldata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT)] + self.assertEqual(len(hldata), SIZEOF_INT) + a = array.array("i") + a.fromstring(hldata) + self.assertGreaterEqual(a[0], 0) + self.assertLessEqual(a[0], maxhop) + + @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT", + "IPV6_RECVTCLASS", "IPV6_TCLASS") + def testRecvTrafficClassAndHopLimit(self): + # Test receiving traffic class and hop limit as ancillary data. + self.checkTrafficClassAndHopLimit(ancbufsize=10240) + + @testRecvTrafficClassAndHopLimit.client_skip + def _testRecvTrafficClassAndHopLimit(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT", + "IPV6_RECVTCLASS", "IPV6_TCLASS") + def testRecvTrafficClassAndHopLimitCMSG_SPACE(self): + # Test receiving traffic class and hop limit, using + # CMSG_SPACE() to calculate buffer size. + self.checkTrafficClassAndHopLimit( + ancbufsize=socket.CMSG_SPACE(SIZEOF_INT) * 2) + + @testRecvTrafficClassAndHopLimitCMSG_SPACE.client_skip + def _testRecvTrafficClassAndHopLimitCMSG_SPACE(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket.socket, "sendmsg") + @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT", + "IPV6_RECVTCLASS", "IPV6_TCLASS") + def testSetTrafficClassAndHopLimit(self): + # Test setting traffic class and hop limit on outgoing packet, + # and receiving them at the other end. + self.checkTrafficClassAndHopLimit(ancbufsize=10240, + maxhop=self.hop_limit) + + @testSetTrafficClassAndHopLimit.client_skip + def _testSetTrafficClassAndHopLimit(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.assertEqual( + self.sendmsgToServer([MSG], + [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS, + array.array("i", [self.traffic_class])), + (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT, + array.array("i", [self.hop_limit]))]), + len(MSG)) + + @requireAttrs(socket.socket, "sendmsg") + @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT", + "IPV6_RECVTCLASS", "IPV6_TCLASS") + def testOddCmsgSize(self): + # Try to send ancillary data with first item one byte too + # long. Fall back to sending with correct size if this fails, + # and check that second item was handled correctly. + self.checkTrafficClassAndHopLimit(ancbufsize=10240, + maxhop=self.hop_limit) + + @testOddCmsgSize.client_skip + def _testOddCmsgSize(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + try: + nbytes = self.sendmsgToServer( + [MSG], + [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS, + array.array("i", [self.traffic_class]).tostring() + b"\x00"), + (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT, + array.array("i", [self.hop_limit]))]) + except socket.error as e: + self.assertIsInstance(e.errno, int) + nbytes = self.sendmsgToServer( + [MSG], + [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS, + array.array("i", [self.traffic_class])), + (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT, + array.array("i", [self.hop_limit]))]) + self.assertEqual(nbytes, len(MSG)) + + # Tests for proper handling of truncated ancillary data + + def checkHopLimitTruncatedHeader(self, ancbufsize, ignoreflags=0): + # Receive hop limit into ancbufsize bytes of ancillary data + # space, which should be too small to contain the ancillary + # data header (if ancbufsize is None, pass no second argument + # to recvmsg()). Check that data is MSG, MSG_CTRUNC is set + # (unless included in ignoreflags), and no ancillary data is + # returned. + self.serv_sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_RECVHOPLIMIT, 1) + self.misc_event.set() + args = () if ancbufsize is None else (ancbufsize,) + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), *args) + + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC, + ignore=ignoreflags) + + @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT") + def testCmsgTruncNoBufSize(self): + # Check that no ancillary data is received when no ancillary + # buffer size is provided. + self.checkHopLimitTruncatedHeader(ancbufsize=None, + # BSD seems to set + # MSG_CTRUNC only if an item + # has been partially + # received. + ignoreflags=socket.MSG_CTRUNC) + + @testCmsgTruncNoBufSize.client_skip + def _testCmsgTruncNoBufSize(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT") + def testSingleCmsgTrunc0(self): + # Check that no ancillary data is received when ancillary + # buffer size is zero. + self.checkHopLimitTruncatedHeader(ancbufsize=0, + ignoreflags=socket.MSG_CTRUNC) + + @testSingleCmsgTrunc0.client_skip + def _testSingleCmsgTrunc0(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + # Check that no ancillary data is returned for various non-zero + # (but still too small) buffer sizes. + + @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT") + def testSingleCmsgTrunc1(self): + self.checkHopLimitTruncatedHeader(ancbufsize=1) + + @testSingleCmsgTrunc1.client_skip + def _testSingleCmsgTrunc1(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT") + def testSingleCmsgTrunc2Int(self): + self.checkHopLimitTruncatedHeader(ancbufsize=2 * SIZEOF_INT) + + @testSingleCmsgTrunc2Int.client_skip + def _testSingleCmsgTrunc2Int(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT") + def testSingleCmsgTruncLen0Minus1(self): + self.checkHopLimitTruncatedHeader(ancbufsize=socket.CMSG_LEN(0) - 1) + + @testSingleCmsgTruncLen0Minus1.client_skip + def _testSingleCmsgTruncLen0Minus1(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT") + def testSingleCmsgTruncInData(self): + # Test truncation of a control message inside its associated + # data. The message may be returned with its data truncated, + # or not returned at all. + self.serv_sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_RECVHOPLIMIT, 1) + self.misc_event.set() + msg, ancdata, flags, addr = self.doRecvmsg( + self.serv_sock, len(MSG), socket.CMSG_LEN(SIZEOF_INT) - 1) + + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC) + + self.assertLessEqual(len(ancdata), 1) + if ancdata: + cmsg_level, cmsg_type, cmsg_data = ancdata[0] + self.assertEqual(cmsg_level, socket.IPPROTO_IPV6) + self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT) + self.assertLess(len(cmsg_data), SIZEOF_INT) + + @testSingleCmsgTruncInData.client_skip + def _testSingleCmsgTruncInData(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + def checkTruncatedSecondHeader(self, ancbufsize, ignoreflags=0): + # Receive traffic class and hop limit into ancbufsize bytes of + # ancillary data space, which should be large enough to + # contain the first item, but too small to contain the header + # of the second. Check that data is MSG, MSG_CTRUNC is set + # (unless included in ignoreflags), and only one ancillary + # data item is returned. + self.serv_sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_RECVHOPLIMIT, 1) + self.serv_sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_RECVTCLASS, 1) + self.misc_event.set() + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), ancbufsize) + + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC, + ignore=ignoreflags) + + self.assertEqual(len(ancdata), 1) + cmsg_level, cmsg_type, cmsg_data = ancdata[0] + self.assertEqual(cmsg_level, socket.IPPROTO_IPV6) + self.assertIn(cmsg_type, {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT}) + self.assertEqual(len(cmsg_data), SIZEOF_INT) + a = array.array("i") + a.fromstring(cmsg_data) + self.assertGreaterEqual(a[0], 0) + self.assertLessEqual(a[0], 255) + + # Try the above test with various buffer sizes. + + @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT", + "IPV6_RECVTCLASS", "IPV6_TCLASS") + def testSecondCmsgTrunc0(self): + self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT), + ignoreflags=socket.MSG_CTRUNC) + + @testSecondCmsgTrunc0.client_skip + def _testSecondCmsgTrunc0(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT", + "IPV6_RECVTCLASS", "IPV6_TCLASS") + def testSecondCmsgTrunc1(self): + self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) + 1) + + @testSecondCmsgTrunc1.client_skip + def _testSecondCmsgTrunc1(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT", + "IPV6_RECVTCLASS", "IPV6_TCLASS") + def testSecondCmsgTrunc2Int(self): + self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) + + 2 * SIZEOF_INT) + + @testSecondCmsgTrunc2Int.client_skip + def _testSecondCmsgTrunc2Int(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT", + "IPV6_RECVTCLASS", "IPV6_TCLASS") + def testSecondCmsgTruncLen0Minus1(self): + self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) + + socket.CMSG_LEN(0) - 1) + + @testSecondCmsgTruncLen0Minus1.client_skip + def _testSecondCmsgTruncLen0Minus1(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT", + "IPV6_RECVTCLASS", "IPV6_TCLASS") + def testSecomdCmsgTruncInData(self): + # Test truncation of the second of two control messages inside + # its associated data. + self.serv_sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_RECVHOPLIMIT, 1) + self.serv_sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_RECVTCLASS, 1) + self.misc_event.set() + msg, ancdata, flags, addr = self.doRecvmsg( + self.serv_sock, len(MSG), + socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_LEN(SIZEOF_INT) - 1) + + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC) + + cmsg_types = {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT} + + cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0) + self.assertEqual(cmsg_level, socket.IPPROTO_IPV6) + cmsg_types.remove(cmsg_type) + self.assertEqual(len(cmsg_data), SIZEOF_INT) + a = array.array("i") + a.fromstring(cmsg_data) + self.assertGreaterEqual(a[0], 0) + self.assertLessEqual(a[0], 255) + + if ancdata: + cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0) + self.assertEqual(cmsg_level, socket.IPPROTO_IPV6) + cmsg_types.remove(cmsg_type) + self.assertLess(len(cmsg_data), SIZEOF_INT) + + self.assertEqual(ancdata, []) + + @testSecomdCmsgTruncInData.client_skip + def _testSecomdCmsgTruncInData(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + self.sendToServer(MSG) + + +# Derive concrete test classes for different socket types. + +class SendrecvmsgUDPTestBase(SendrecvmsgDgramFlagsBase, + SendrecvmsgConnectionlessBase, + ThreadedSocketTestMixin, UDPTestBase): + pass + +@requireAttrs(socket.socket, "sendmsg") +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgUDPTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgUDPTest(RecvmsgTests, SendrecvmsgUDPTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgIntoUDPTest(RecvmsgIntoTests, SendrecvmsgUDPTestBase): + pass + + +class SendrecvmsgUDP6TestBase(SendrecvmsgDgramFlagsBase, + SendrecvmsgConnectionlessBase, + ThreadedSocketTestMixin, UDP6TestBase): + pass + +@requireAttrs(socket.socket, "sendmsg") +@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support") +@requireSocket("AF_INET6", "SOCK_DGRAM") +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendmsgUDP6Test(SendmsgConnectionlessTests, SendrecvmsgUDP6TestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support") +@requireSocket("AF_INET6", "SOCK_DGRAM") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgUDP6Test(RecvmsgTests, SendrecvmsgUDP6TestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support") +@requireSocket("AF_INET6", "SOCK_DGRAM") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgIntoUDP6Test(RecvmsgIntoTests, SendrecvmsgUDP6TestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support") +@requireAttrs(socket, "IPPROTO_IPV6") +@requireSocket("AF_INET6", "SOCK_DGRAM") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgRFC3542AncillaryUDP6Test(RFC3542AncillaryTest, + SendrecvmsgUDP6TestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support") +@requireAttrs(socket, "IPPROTO_IPV6") +@requireSocket("AF_INET6", "SOCK_DGRAM") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin, + RFC3542AncillaryTest, + SendrecvmsgUDP6TestBase): + pass + + +class SendrecvmsgTCPTestBase(SendrecvmsgConnectedBase, + ConnectedStreamTestMixin, TCPTestBase): + pass + +@requireAttrs(socket.socket, "sendmsg") +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendmsgTCPTest(SendmsgStreamTests, SendrecvmsgTCPTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgTCPTest(RecvmsgTests, RecvmsgGenericStreamTests, + SendrecvmsgTCPTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgIntoTCPTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, + SendrecvmsgTCPTestBase): + pass + + +class SendrecvmsgSCTPStreamTestBase(SendrecvmsgSCTPFlagsBase, + SendrecvmsgConnectedBase, + ConnectedStreamTestMixin, SCTPStreamBase): + pass + +@requireAttrs(socket.socket, "sendmsg") +@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgSCTPStreamTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests, + SendrecvmsgSCTPStreamTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, + SendrecvmsgSCTPStreamTestBase): + pass + + +class SendrecvmsgUnixStreamTestBase(SendrecvmsgConnectedBase, + ConnectedStreamTestMixin, UnixStreamBase): + pass + +@requireAttrs(socket.socket, "sendmsg") +@requireAttrs(socket, "AF_UNIX") +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgUnixStreamTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@requireAttrs(socket, "AF_UNIX") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgUnixStreamTest(RecvmsgTests, RecvmsgGenericStreamTests, + SendrecvmsgUnixStreamTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@requireAttrs(socket, "AF_UNIX") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgIntoUnixStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, + SendrecvmsgUnixStreamTestBase): + pass + +@requireAttrs(socket.socket, "sendmsg", "recvmsg") +@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgSCMRightsStreamTest(SCMRightsTest, SendrecvmsgUnixStreamTestBase): + pass + +@requireAttrs(socket.socket, "sendmsg", "recvmsg_into") +@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS") +@unittest.skipUnless(thread, 'Threading required for this test.') +class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest, + SendrecvmsgUnixStreamTestBase): + pass + + +# Test interrupting the interruptible send/receive methods with a +# signal when a timeout is set. These tests avoid having multiple +# threads alive during the test so that the OS cannot deliver the +# signal to the wrong one. + +class InterruptedTimeoutBase(unittest.TestCase): + # Base class for interrupted send/receive tests. Installs an + # empty handler for SIGALRM and removes it on teardown, along with + # any scheduled alarms. + + def setUp(self): + super(InterruptedTimeoutBase, self).setUp() + orig_alrm_handler = signal.signal(signal.SIGALRM, + lambda signum, frame: None) + self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler) + self.addCleanup(self.setAlarm, 0) + + # Timeout for socket operations + timeout = 4.0 + + # Provide setAlarm() method to schedule delivery of SIGALRM after + # given number of seconds, or cancel it if zero, and an + # appropriate time value to use. Use setitimer() if available. + if hasattr(signal, "setitimer"): + alarm_time = 0.05 + + def setAlarm(self, seconds): + signal.setitimer(signal.ITIMER_REAL, seconds) + else: + # Old systems may deliver the alarm up to one second early + alarm_time = 2 + + def setAlarm(self, seconds): + signal.alarm(seconds) + + +# Require siginterrupt() in order to ensure that system calls are +# interrupted by default. +@requireAttrs(signal, "siginterrupt") +@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"), + "Don't have signal.alarm or signal.setitimer") +class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase): + # Test interrupting the recv*() methods with signals when a + # timeout is set. + + def setUp(self): + super(InterruptedRecvTimeoutTest, self).setUp() + self.serv.settimeout(self.timeout) + + def checkInterruptedRecv(self, func, *args, **kwargs): + # Check that func(*args, **kwargs) raises socket.error with an + # errno of EINTR when interrupted by a signal. + self.setAlarm(self.alarm_time) + with self.assertRaises(socket.error) as cm: + func(*args, **kwargs) + self.assertNotIsInstance(cm.exception, socket.timeout) + self.assertEqual(cm.exception.errno, errno.EINTR) + + def testInterruptedRecvTimeout(self): + self.checkInterruptedRecv(self.serv.recv, 1024) + + def testInterruptedRecvIntoTimeout(self): + self.checkInterruptedRecv(self.serv.recv_into, bytearray(1024)) + + def testInterruptedRecvfromTimeout(self): + self.checkInterruptedRecv(self.serv.recvfrom, 1024) + + def testInterruptedRecvfromIntoTimeout(self): + self.checkInterruptedRecv(self.serv.recvfrom_into, bytearray(1024)) + + @requireAttrs(socket.socket, "recvmsg") + def testInterruptedRecvmsgTimeout(self): + self.checkInterruptedRecv(self.serv.recvmsg, 1024) + + @requireAttrs(socket.socket, "recvmsg_into") + def testInterruptedRecvmsgIntoTimeout(self): + self.checkInterruptedRecv(self.serv.recvmsg_into, [bytearray(1024)]) + + +# Require siginterrupt() in order to ensure that system calls are +# interrupted by default. +@requireAttrs(signal, "siginterrupt") +@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"), + "Don't have signal.alarm or signal.setitimer") +@unittest.skipUnless(thread, 'Threading required for this test.') +class InterruptedSendTimeoutTest(InterruptedTimeoutBase, + ThreadSafeCleanupTestCase, + SocketListeningTestMixin, TCPTestBase): + # Test interrupting the interruptible send*() methods with signals + # when a timeout is set. + + def setUp(self): + super(InterruptedSendTimeoutTest, self).setUp() + self.serv_conn = self.newSocket() + self.addCleanup(self.serv_conn.close) + # Use a thread to complete the connection, but wait for it to + # terminate before running the test, so that there is only one + # thread to accept the signal. + cli_thread = threading.Thread(target=self.doConnect) + cli_thread.start() + self.cli_conn, addr = self.serv.accept() + self.addCleanup(self.cli_conn.close) + cli_thread.join() + self.serv_conn.settimeout(self.timeout) + + def doConnect(self): + self.serv_conn.connect(self.serv_addr) + + def checkInterruptedSend(self, func, *args, **kwargs): + # Check that func(*args, **kwargs), run in a loop, raises + # socket.error with an errno of EINTR when interrupted by a + # signal. + with self.assertRaises(socket.error) as cm: + while True: + self.setAlarm(self.alarm_time) + func(*args, **kwargs) + self.assertNotIsInstance(cm.exception, socket.timeout) + self.assertEqual(cm.exception.errno, errno.EINTR) + + def testInterruptedSendTimeout(self): + self.checkInterruptedSend(self.serv_conn.send, b"a"*512) + + def testInterruptedSendtoTimeout(self): + # Passing an actual address here as Python's wrapper for + # sendto() doesn't allow passing a zero-length one; POSIX + # requires that the address is ignored since the socket is + # connection-mode, however. + self.checkInterruptedSend(self.serv_conn.sendto, b"a"*512, + self.serv_addr) + + @requireAttrs(socket.socket, "sendmsg") + def testInterruptedSendmsgTimeout(self): + self.checkInterruptedSend(self.serv_conn.sendmsg, [b"a"*512]) + + @unittest.skipUnless(thread, 'Threading required for this test.') class TCPCloserTest(ThreadedTCPSocketTest): @@ -1609,6 +3707,31 @@ def test_main(): if isTipcAvailable(): tests.append(TIPCTest) tests.append(TIPCThreadableTest) + tests.extend([ + CmsgMacroTests, + SendmsgUDPTest, + RecvmsgUDPTest, + RecvmsgIntoUDPTest, + SendmsgUDP6Test, + RecvmsgUDP6Test, + RecvmsgRFC3542AncillaryUDP6Test, + RecvmsgIntoRFC3542AncillaryUDP6Test, + RecvmsgIntoUDP6Test, + SendmsgTCPTest, + RecvmsgTCPTest, + RecvmsgIntoTCPTest, + SendmsgSCTPStreamTest, + RecvmsgSCTPStreamTest, + RecvmsgIntoSCTPStreamTest, + SendmsgUnixStreamTest, + RecvmsgUnixStreamTest, + RecvmsgIntoUnixStreamTest, + RecvmsgSCMRightsStreamTest, + RecvmsgIntoSCMRightsStreamTest, + # These are slow when setitimer() is not available + InterruptedRecvTimeoutTest, + InterruptedSendTimeoutTest, + ]) thread_info = test_support.threading_setup() test_support.run_unittest(*tests) diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -188,8 +188,11 @@ class BasicSocketTests(unittest.TestCase self.assertRaises(socket.error, ss.recv_into, bytearray(b'x')) self.assertRaises(socket.error, ss.recvfrom, 1) self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1) + self.assertRaises(socket.error, ss.recvmsg, 1) + self.assertRaises(socket.error, ss.recvmsg_into, [bytearray(b'x')]) self.assertRaises(socket.error, ss.send, b'x') self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0)) + self.assertRaises(socket.error, ss.sendmsg, [b'x']) class NetworkedTests(unittest.TestCase): @@ -1200,17 +1203,30 @@ else: count, addr = s.recvfrom_into(b) return b[:count] + def _recvmsg(*args, **kwargs): + return s.recvmsg(*args, **kwargs)[0] + + def _recvmsg_into(bufsize, *args, **kwargs): + b = bytearray(bufsize) + return bytes(b[:s.recvmsg_into([b], *args, **kwargs)[0]]) + + def _sendmsg(msg, *args, **kwargs): + return s.sendmsg([msg]) + # (name, method, whether to expect success, *args) send_methods = [ ('send', s.send, True, []), ('sendto', s.sendto, False, ["some.address"]), + ('sendmsg', _sendmsg, False, []), ('sendall', s.sendall, True, []), ] recv_methods = [ ('recv', s.recv, True, []), ('recvfrom', s.recvfrom, False, ["some.address"]), + ('recvmsg', _recvmsg, False, [100]), ('recv_into', _recv_into, True, []), ('recvfrom_into', _recvfrom_into, False, []), + ('recvmsg_into', _recvmsg_into, False, [100]), ] data_prefix = u"PREFIX_" diff --git a/Modules/socketmodule.c b/Modules/socketmodule.c --- a/Modules/socketmodule.c +++ b/Modules/socketmodule.c @@ -251,6 +251,7 @@ shutdown(how) -- shut down traffic in on #ifdef HAVE_SYS_TYPES_H #include #endif +#include /* Generic socket object definitions and includes */ #define PySocket_BUILDING_SOCKET @@ -456,6 +457,17 @@ static PyTypeObject sock_type; #include #endif +/* Largest value to try to store in a socklen_t (used when handling + ancillary data). POSIX requires socklen_t to hold at least + (2**31)-1 and recommends against storing larger values, but + socklen_t was originally int in the BSD interface, so to be on the + safe side we use the smaller of (2**31)-1 and INT_MAX. */ +#if INT_MAX > 0x7fffffff +#define SOCKLEN_T_LIMIT 0x7fffffff +#else +#define SOCKLEN_T_LIMIT INT_MAX +#endif + #ifdef Py_SOCKET_FD_CAN_BE_GE_FD_SETSIZE /* Platform can select file descriptors beyond FD_SETSIZE */ #define IS_SELECTABLE(s) 1 @@ -1627,6 +1639,117 @@ getsockaddrlen(PySocketSockObject *s, so } +/* Support functions for the sendmsg() and recvmsg[_into]() methods. + Currently, these methods are only compiled if the RFC 2292/3542 + CMSG_LEN() macro is available. Older systems seem to have used + sizeof(struct cmsghdr) + (length) where CMSG_LEN() is used now, so + it may be possible to define CMSG_LEN() that way if it's not + provided. Some architectures might need extra padding after the + cmsghdr, however, and CMSG_LEN() would have to take account of + this. */ +#ifdef CMSG_LEN +/* If length is in range, set *result to CMSG_LEN(length) and return + true; otherwise, return false. */ +static int +get_CMSG_LEN(size_t length, size_t *result) +{ + size_t tmp; + + if (length > (SOCKLEN_T_LIMIT - CMSG_LEN(0))) + return 0; + tmp = CMSG_LEN(length); + if (tmp > SOCKLEN_T_LIMIT || tmp < length) + return 0; + *result = tmp; + return 1; +} + +#ifdef CMSG_SPACE +/* If length is in range, set *result to CMSG_SPACE(length) and return + true; otherwise, return false. */ +static int +get_CMSG_SPACE(size_t length, size_t *result) +{ + size_t tmp; + + /* Use CMSG_SPACE(1) here in order to take account of the padding + necessary before *and* after the data. */ + if (length > (SOCKLEN_T_LIMIT - CMSG_SPACE(1))) + return 0; + tmp = CMSG_SPACE(length); + if (tmp > SOCKLEN_T_LIMIT || tmp < length) + return 0; + *result = tmp; + return 1; +} +#endif + +/* Return true iff msg->msg_controllen is valid, cmsgh is a valid + pointer in msg->msg_control with at least "space" bytes after it, + and its cmsg_len member inside the buffer. */ +static int +cmsg_min_space(struct msghdr *msg, struct cmsghdr *cmsgh, size_t space) +{ + size_t cmsg_offset; + static const size_t cmsg_len_end = (offsetof(struct cmsghdr, cmsg_len) + + sizeof(cmsgh->cmsg_len)); + + if (cmsgh == NULL || msg->msg_control == NULL || msg->msg_controllen < 0) + return 0; + if (space < cmsg_len_end) + space = cmsg_len_end; + cmsg_offset = (char *)cmsgh - (char *)msg->msg_control; + return (cmsg_offset <= (size_t)-1 - space && + cmsg_offset + space <= msg->msg_controllen); +} + +/* If pointer CMSG_DATA(cmsgh) is in buffer msg->msg_control, set + *space to number of bytes following it in the buffer and return + true; otherwise, return false. Assumes cmsgh, msg->msg_control and + msg->msg_controllen are valid. */ +static int +get_cmsg_data_space(struct msghdr *msg, struct cmsghdr *cmsgh, size_t *space) +{ + size_t data_offset; + char *data_ptr; + + if ((data_ptr = (char *)CMSG_DATA(cmsgh)) == NULL) + return 0; + data_offset = data_ptr - (char *)msg->msg_control; + if (data_offset > msg->msg_controllen) + return 0; + *space = msg->msg_controllen - data_offset; + return 1; +} + +/* If cmsgh is invalid or not contained in the buffer pointed to by + msg->msg_control, return -1. If cmsgh is valid and its associated + data is entirely contained in the buffer, set *data_len to the + length of the associated data and return 0. If only part of the + associated data is contained in the buffer but cmsgh is otherwise + valid, set *data_len to the length contained in the buffer and + return 1. */ +static int +get_cmsg_data_len(struct msghdr *msg, struct cmsghdr *cmsgh, size_t *data_len) +{ + size_t space, cmsg_data_len; + + if (!cmsg_min_space(msg, cmsgh, CMSG_LEN(0)) || + cmsgh->cmsg_len < CMSG_LEN(0)) + return -1; + cmsg_data_len = cmsgh->cmsg_len - CMSG_LEN(0); + if (!get_cmsg_data_space(msg, cmsgh, &space)) + return -1; + if (space >= cmsg_data_len) { + *data_len = cmsg_data_len; + return 0; + } + *data_len = space; + return 1; +} +#endif /* CMSG_LEN */ + + /* s.accept() method */ static PyObject * @@ -2697,6 +2820,333 @@ PyDoc_STRVAR(recvfrom_into_doc, Like recv_into(buffer[, nbytes[, flags]]) but also return the sender's address info."); +/* The sendmsg() and recvmsg[_into]() methods require a working + CMSG_LEN(). See the comment near get_CMSG_LEN(). */ +#ifdef CMSG_LEN +/* + * Call recvmsg() with the supplied iovec structures, flags, and + * ancillary data buffer size (controllen). Returns the tuple return + * value for recvmsg() or recvmsg_into(), with the first item provided + * by the supplied makeval() function. makeval() will be called with + * the length read and makeval_data as arguments, and must return a + * new reference (which will be decrefed if there is a subsequent + * error). On error, closes any file descriptors received via + * SCM_RIGHTS. + */ +static PyObject * +sock_recvmsg_guts(PySocketSockObject *s, struct iovec *iov, int iovlen, + int flags, Py_ssize_t controllen, + PyObject *(*makeval)(ssize_t, void *), void *makeval_data) +{ + ssize_t bytes_received = -1; + int timeout; + sock_addr_t addrbuf; + socklen_t addrbuflen; + static const struct msghdr msg_blank; + struct msghdr msg; + PyObject *cmsg_list = NULL, *retval = NULL; + void *controlbuf = NULL; + struct cmsghdr *cmsgh; + size_t cmsgdatalen = 0; + int cmsg_status; + + /* XXX: POSIX says that msg_name and msg_namelen "shall be + ignored" when the socket is connected (Linux fills them in + anyway for AF_UNIX sockets at least). Normally msg_namelen + seems to be set to 0 if there's no address, but try to + initialize msg_name to something that won't be mistaken for a + real address if that doesn't happen. */ + if (!getsockaddrlen(s, &addrbuflen)) + return NULL; + memset(&addrbuf, 0, addrbuflen); + SAS2SA(&addrbuf)->sa_family = AF_UNSPEC; + + if (controllen < 0 || controllen > SOCKLEN_T_LIMIT) { + PyErr_SetString(PyExc_ValueError, + "invalid ancillary data buffer length"); + return NULL; + } + if (controllen > 0 && (controlbuf = PyMem_Malloc(controllen)) == NULL) + return PyErr_NoMemory(); + + /* Make the system call. */ + if (!IS_SELECTABLE(s)) { + select_error(); + goto finally; + } + + msg = msg_blank; /* Set all members to 0 or NULL */ + msg.msg_name = SAS2SA(&addrbuf); + msg.msg_namelen = addrbuflen; + msg.msg_iov = iov; + msg.msg_iovlen = iovlen; + msg.msg_control = controlbuf; + msg.msg_controllen = controllen; + + Py_BEGIN_ALLOW_THREADS; + timeout = internal_select(s, 0); + if (!timeout) + bytes_received = recvmsg(s->sock_fd, &msg, flags); + Py_END_ALLOW_THREADS; + + if (timeout == 1) { + PyErr_SetString(socket_timeout, "timed out"); + goto finally; + } + + if (bytes_received < 0) { + s->errorhandler(); + goto finally; + } + + /* Make list of (level, type, data) tuples from control messages. */ + if ((cmsg_list = PyList_New(0)) == NULL) + goto err_closefds; + /* Check for empty ancillary data as old CMSG_FIRSTHDR() + implementations didn't do so. */ + for (cmsgh = ((msg.msg_controllen > 0) ? CMSG_FIRSTHDR(&msg) : NULL); + cmsgh != NULL; cmsgh = CMSG_NXTHDR(&msg, cmsgh)) { + PyObject *bytes, *tuple; + int tmp; + + cmsg_status = get_cmsg_data_len(&msg, cmsgh, &cmsgdatalen); + if (cmsg_status != 0) { + if (PyErr_WarnEx(PyExc_RuntimeWarning, + "received malformed or improperly-truncated " + "ancillary data", 1) == -1) + goto err_closefds; + } + if (cmsg_status < 0) + break; + if (cmsgdatalen > PY_SSIZE_T_MAX) { + PyErr_SetString(socket_error, "control message too long"); + goto err_closefds; + } + + bytes = PyBytes_FromStringAndSize((char *)CMSG_DATA(cmsgh), + cmsgdatalen); + tuple = Py_BuildValue("iiN", (int)cmsgh->cmsg_level, + (int)cmsgh->cmsg_type, bytes); + if (tuple == NULL) + goto err_closefds; + tmp = PyList_Append(cmsg_list, tuple); + Py_DECREF(tuple); + if (tmp != 0) + goto err_closefds; + + if (cmsg_status != 0) + break; + } + + retval = Py_BuildValue("NOiN", + (*makeval)(bytes_received, makeval_data), + cmsg_list, + (int)msg.msg_flags, + makesockaddr(s->sock_fd, SAS2SA(&addrbuf), + ((msg.msg_namelen > addrbuflen) ? + addrbuflen : msg.msg_namelen), + s->sock_proto)); + if (retval == NULL) + goto err_closefds; + +finally: + Py_XDECREF(cmsg_list); + PyMem_Free(controlbuf); + return retval; + +err_closefds: +#ifdef SCM_RIGHTS + /* Close all descriptors coming from SCM_RIGHTS, so they don't leak. */ + for (cmsgh = ((msg.msg_controllen > 0) ? CMSG_FIRSTHDR(&msg) : NULL); + cmsgh != NULL; cmsgh = CMSG_NXTHDR(&msg, cmsgh)) { + cmsg_status = get_cmsg_data_len(&msg, cmsgh, &cmsgdatalen); + if (cmsg_status < 0) + break; + if (cmsgh->cmsg_level == SOL_SOCKET && + cmsgh->cmsg_type == SCM_RIGHTS) { + size_t numfds; + int *fdp; + + numfds = cmsgdatalen / sizeof(int); + fdp = (int *)CMSG_DATA(cmsgh); + while (numfds-- > 0) + close(*fdp++); + } + if (cmsg_status != 0) + break; + } +#endif /* SCM_RIGHTS */ + goto finally; +} + + +static PyObject * +makeval_recvmsg(ssize_t received, void *data) +{ + PyObject **buf = data; + + if (received < PyBytes_GET_SIZE(*buf)) + _PyBytes_Resize(buf, received); + Py_XINCREF(*buf); + return *buf; +} + +/* s.recvmsg(bufsize[, ancbufsize[, flags]]) method */ + +static PyObject * +sock_recvmsg(PySocketSockObject *s, PyObject *args) +{ + Py_ssize_t bufsize, ancbufsize = 0; + int flags = 0; + struct iovec iov; + PyObject *buf = NULL, *retval = NULL; + + if (!PyArg_ParseTuple(args, "n|ni:recvmsg", &bufsize, &ancbufsize, &flags)) + return NULL; + + if (bufsize < 0) { + PyErr_SetString(PyExc_ValueError, "negative buffer size in recvmsg()"); + return NULL; + } + if ((buf = PyBytes_FromStringAndSize(NULL, bufsize)) == NULL) + return NULL; + iov.iov_base = PyBytes_AS_STRING(buf); + iov.iov_len = bufsize; + + /* Note that we're passing a pointer to *our pointer* to the bytes + object here (&buf); makeval_recvmsg() may incref the object, or + deallocate it and set our pointer to NULL. */ + retval = sock_recvmsg_guts(s, &iov, 1, flags, ancbufsize, + &makeval_recvmsg, &buf); + Py_XDECREF(buf); + return retval; +} + +PyDoc_STRVAR(recvmsg_doc, +"recvmsg(bufsize[, ancbufsize[, flags]]) -> (data, ancdata, msg_flags, address)\n\ +\n\ +Receive normal data (up to bufsize bytes) and ancillary data from the\n\ +socket. The ancbufsize argument sets the size in bytes of the\n\ +internal buffer used to receive the ancillary data; it defaults to 0,\n\ +meaning that no ancillary data will be received. Appropriate buffer\n\ +sizes for ancillary data can be calculated using CMSG_SPACE() or\n\ +CMSG_LEN(), and items which do not fit into the buffer might be\n\ +truncated or discarded. The flags argument defaults to 0 and has the\n\ +same meaning as for recv().\n\ +\n\ +The return value is a 4-tuple: (data, ancdata, msg_flags, address).\n\ +The data item is a bytes object holding the non-ancillary data\n\ +received. The ancdata item is a list of zero or more tuples\n\ +(cmsg_level, cmsg_type, cmsg_data) representing the ancillary data\n\ +(control messages) received: cmsg_level and cmsg_type are integers\n\ +specifying the protocol level and protocol-specific type respectively,\n\ +and cmsg_data is a bytes object holding the associated data. The\n\ +msg_flags item is the bitwise OR of various flags indicating\n\ +conditions on the received message; see your system documentation for\n\ +details. If the receiving socket is unconnected, address is the\n\ +address of the sending socket, if available; otherwise, its value is\n\ +unspecified.\n\ +\n\ +If recvmsg() raises an exception after the system call returns, it\n\ +will first attempt to close any file descriptors received via the\n\ +SCM_RIGHTS mechanism."); + + +static PyObject * +makeval_recvmsg_into(ssize_t received, void *data) +{ + return PyInt_FromSsize_t(received); +} + +/* s.recvmsg_into(buffers[, ancbufsize[, flags]]) method */ + +static PyObject * +sock_recvmsg_into(PySocketSockObject *s, PyObject *args) +{ + Py_ssize_t ancbufsize = 0; + int flags = 0; + struct iovec *iovs = NULL; + Py_ssize_t i, nitems, nbufs = 0; + Py_buffer *bufs = NULL; + PyObject *buffers_arg, *fast, *retval = NULL; + + if (!PyArg_ParseTuple(args, "O|ni:recvmsg_into", + &buffers_arg, &ancbufsize, &flags)) + return NULL; + + if ((fast = PySequence_Fast(buffers_arg, + "recvmsg_into() argument 1 must be an " + "iterable")) == NULL) + return NULL; + nitems = PySequence_Fast_GET_SIZE(fast); + if (nitems > INT_MAX) { + PyErr_SetString(socket_error, "recvmsg_into() argument 1 is too long"); + goto finally; + } + + /* Fill in an iovec for each item, and save the Py_buffer + structs to release afterwards. */ + if (nitems > 0 && ((iovs = PyMem_New(struct iovec, nitems)) == NULL || + (bufs = PyMem_New(Py_buffer, nitems)) == NULL)) { + PyErr_NoMemory(); + goto finally; + } + for (; nbufs < nitems; nbufs++) { + if (!PyArg_Parse(PySequence_Fast_GET_ITEM(fast, nbufs), + "w*;recvmsg_into() argument 1 must be an iterable " + "of single-segment read-write buffers", + &bufs[nbufs])) + goto finally; + iovs[nbufs].iov_base = bufs[nbufs].buf; + iovs[nbufs].iov_len = bufs[nbufs].len; + } + + retval = sock_recvmsg_guts(s, iovs, nitems, flags, ancbufsize, + &makeval_recvmsg_into, NULL); +finally: + for (i = 0; i < nbufs; i++) + PyBuffer_Release(&bufs[i]); + PyMem_Free(bufs); + PyMem_Free(iovs); + Py_DECREF(fast); + return retval; +} + +PyDoc_STRVAR(recvmsg_into_doc, +"recvmsg_into(buffers[, ancbufsize[, flags]]) -> (nbytes, ancdata, msg_flags, address)\n\ +\n\ +Receive normal data and ancillary data from the socket, scattering the\n\ +non-ancillary data into a series of buffers. The buffers argument\n\ +must be an iterable of objects that export writable buffers\n\ +(e.g. bytearray objects); these will be filled with successive chunks\n\ +of the non-ancillary data until it has all been written or there are\n\ +no more buffers. The ancbufsize argument sets the size in bytes of\n\ +the internal buffer used to receive the ancillary data; it defaults to\n\ +0, meaning that no ancillary data will be received. Appropriate\n\ +buffer sizes for ancillary data can be calculated using CMSG_SPACE()\n\ +or CMSG_LEN(), and items which do not fit into the buffer might be\n\ +truncated or discarded. The flags argument defaults to 0 and has the\n\ +same meaning as for recv().\n\ +\n\ +The return value is a 4-tuple: (nbytes, ancdata, msg_flags, address).\n\ +The nbytes item is the total number of bytes of non-ancillary data\n\ +written into the buffers. The ancdata item is a list of zero or more\n\ +tuples (cmsg_level, cmsg_type, cmsg_data) representing the ancillary\n\ +data (control messages) received: cmsg_level and cmsg_type are\n\ +integers specifying the protocol level and protocol-specific type\n\ +respectively, and cmsg_data is a bytes object holding the associated\n\ +data. The msg_flags item is the bitwise OR of various flags\n\ +indicating conditions on the received message; see your system\n\ +documentation for details. If the receiving socket is unconnected,\n\ +address is the address of the sending socket, if available; otherwise,\n\ +its value is unspecified.\n\ +\n\ +If recvmsg_into() raises an exception after the system call returns,\n\ +it will first attempt to close any file descriptors received via the\n\ +SCM_RIGHTS mechanism."); +#endif /* CMSG_LEN */ + + /* s.send(data [,flags]) method */ static PyObject * @@ -2883,6 +3333,236 @@ Like send(data, flags) but allows specif For IP sockets, the address is a pair (hostaddr, port)."); +/* The sendmsg() and recvmsg[_into]() methods require a working + CMSG_LEN(). See the comment near get_CMSG_LEN(). */ +#ifdef CMSG_LEN +/* s.sendmsg(buffers[, ancdata[, flags[, address]]]) method */ + +static PyObject * +sock_sendmsg(PySocketSockObject *s, PyObject *args) +{ + Py_ssize_t i, ndataparts, ndatabufs = 0, ncmsgs, ncmsgbufs = 0; + Py_buffer *databufs = NULL; + struct iovec *iovs = NULL; + sock_addr_t addrbuf; + static const struct msghdr msg_blank; + struct msghdr msg; + struct cmsginfo { + int level; + int type; + Py_buffer data; + } *cmsgs = NULL; + void *controlbuf = NULL; + size_t controllen, controllen_last; + ssize_t bytes_sent = -1; + int addrlen, timeout, flags = 0; + PyObject *data_arg, *cmsg_arg = NULL, *addr_arg = NULL, *data_fast = NULL, + *cmsg_fast = NULL, *retval = NULL; + + if (!PyArg_ParseTuple(args, "O|OiO:sendmsg", + &data_arg, &cmsg_arg, &flags, &addr_arg)) + return NULL; + + msg = msg_blank; /* Set all members to 0 or NULL */ + + /* Parse destination address. */ + if (addr_arg != NULL && addr_arg != Py_None) { + if (!getsockaddrarg(s, addr_arg, SAS2SA(&addrbuf), &addrlen)) + goto finally; + msg.msg_name = &addrbuf; + msg.msg_namelen = addrlen; + } + + /* Fill in an iovec for each message part, and save the Py_buffer + structs to release afterwards. */ + if ((data_fast = PySequence_Fast(data_arg, + "sendmsg() argument 1 must be an " + "iterable")) == NULL) + goto finally; + ndataparts = PySequence_Fast_GET_SIZE(data_fast); + if (ndataparts > INT_MAX) { + PyErr_SetString(socket_error, "sendmsg() argument 1 is too long"); + goto finally; + } + msg.msg_iovlen = ndataparts; + if (ndataparts > 0 && + ((msg.msg_iov = iovs = PyMem_New(struct iovec, ndataparts)) == NULL || + (databufs = PyMem_New(Py_buffer, ndataparts)) == NULL)) { + PyErr_NoMemory(); + goto finally; + } + for (; ndatabufs < ndataparts; ndatabufs++) { + if (!PyArg_Parse(PySequence_Fast_GET_ITEM(data_fast, ndatabufs), + "s*;sendmsg() argument 1 must be an iterable of " + "buffer-compatible objects", + &databufs[ndatabufs])) + goto finally; + iovs[ndatabufs].iov_base = databufs[ndatabufs].buf; + iovs[ndatabufs].iov_len = databufs[ndatabufs].len; + } + + if (cmsg_arg == NULL) + ncmsgs = 0; + else { + if ((cmsg_fast = PySequence_Fast(cmsg_arg, + "sendmsg() argument 2 must be an " + "iterable")) == NULL) + goto finally; + ncmsgs = PySequence_Fast_GET_SIZE(cmsg_fast); + } + +#ifndef CMSG_SPACE + if (ncmsgs > 1) { + PyErr_SetString(socket_error, + "sending multiple control messages is not supported " + "on this system"); + goto finally; + } +#endif + /* Save level, type and Py_buffer for each control message, + and calculate total size. */ + if (ncmsgs > 0 && (cmsgs = PyMem_New(struct cmsginfo, ncmsgs)) == NULL) { + PyErr_NoMemory(); + goto finally; + } + controllen = controllen_last = 0; + while (ncmsgbufs < ncmsgs) { + size_t bufsize, space; + + if (!PyArg_Parse(PySequence_Fast_GET_ITEM(cmsg_fast, ncmsgbufs), + "(iis*):[sendmsg() ancillary data items]", + &cmsgs[ncmsgbufs].level, + &cmsgs[ncmsgbufs].type, + &cmsgs[ncmsgbufs].data)) + goto finally; + bufsize = cmsgs[ncmsgbufs++].data.len; + +#ifdef CMSG_SPACE + if (!get_CMSG_SPACE(bufsize, &space)) { +#else + if (!get_CMSG_LEN(bufsize, &space)) { +#endif + PyErr_SetString(socket_error, "ancillary data item too large"); + goto finally; + } + controllen += space; + if (controllen > SOCKLEN_T_LIMIT || controllen < controllen_last) { + PyErr_SetString(socket_error, "too much ancillary data"); + goto finally; + } + controllen_last = controllen; + } + + /* Construct ancillary data block from control message info. */ + if (ncmsgbufs > 0) { + struct cmsghdr *cmsgh = NULL; + + if ((msg.msg_control = controlbuf = + PyMem_Malloc(controllen)) == NULL) { + PyErr_NoMemory(); + goto finally; + } + msg.msg_controllen = controllen; + + /* Need to zero out the buffer as a workaround for glibc's + CMSG_NXTHDR() implementation. After getting the pointer to + the next header, it checks its (uninitialized) cmsg_len + member to see if the "message" fits in the buffer, and + returns NULL if it doesn't. Zero-filling the buffer + ensures that that doesn't happen. */ + memset(controlbuf, 0, controllen); + + for (i = 0; i < ncmsgbufs; i++) { + size_t msg_len, data_len = cmsgs[i].data.len; + int enough_space = 0; + + cmsgh = (i == 0) ? CMSG_FIRSTHDR(&msg) : CMSG_NXTHDR(&msg, cmsgh); + if (cmsgh == NULL) { + PyErr_Format(PyExc_RuntimeError, + "unexpected NULL result from %s()", + (i == 0) ? "CMSG_FIRSTHDR" : "CMSG_NXTHDR"); + goto finally; + } + if (!get_CMSG_LEN(data_len, &msg_len)) { + PyErr_SetString(PyExc_RuntimeError, + "item size out of range for CMSG_LEN()"); + goto finally; + } + if (cmsg_min_space(&msg, cmsgh, msg_len)) { + size_t space; + + cmsgh->cmsg_len = msg_len; + if (get_cmsg_data_space(&msg, cmsgh, &space)) + enough_space = (space >= data_len); + } + if (!enough_space) { + PyErr_SetString(PyExc_RuntimeError, + "ancillary data does not fit in calculated " + "space"); + goto finally; + } + cmsgh->cmsg_level = cmsgs[i].level; + cmsgh->cmsg_type = cmsgs[i].type; + memcpy(CMSG_DATA(cmsgh), cmsgs[i].data.buf, data_len); + } + } + + /* Make the system call. */ + if (!IS_SELECTABLE(s)) { + select_error(); + goto finally; + } + + Py_BEGIN_ALLOW_THREADS; + timeout = internal_select(s, 1); + if (!timeout) + bytes_sent = sendmsg(s->sock_fd, &msg, flags); + Py_END_ALLOW_THREADS; + + if (timeout == 1) { + PyErr_SetString(socket_timeout, "timed out"); + goto finally; + } + + if (bytes_sent < 0) { + s->errorhandler(); + goto finally; + } + retval = PyInt_FromSsize_t(bytes_sent); + +finally: + PyMem_Free(controlbuf); + for (i = 0; i < ncmsgbufs; i++) + PyBuffer_Release(&cmsgs[i].data); + PyMem_Free(cmsgs); + Py_XDECREF(cmsg_fast); + for (i = 0; i < ndatabufs; i++) + PyBuffer_Release(&databufs[i]); + PyMem_Free(databufs); + PyMem_Free(iovs); + Py_XDECREF(data_fast); + return retval; +} + +PyDoc_STRVAR(sendmsg_doc, +"sendmsg(buffers[, ancdata[, flags[, address]]]) -> count\n\ +\n\ +Send normal and ancillary data to the socket, gathering the\n\ +non-ancillary data from a series of buffers and concatenating it into\n\ +a single message. The buffers argument specifies the non-ancillary\n\ +data as an iterable of buffer-compatible objects (e.g. bytes objects).\n\ +The ancdata argument specifies the ancillary data (control messages)\n\ +as an iterable of zero or more tuples (cmsg_level, cmsg_type,\n\ +cmsg_data), where cmsg_level and cmsg_type are integers specifying the\n\ +protocol level and protocol-specific type respectively, and cmsg_data\n\ +is a buffer-compatible object holding the associated data. The flags\n\ +argument defaults to 0 and has the same meaning as for send(). If\n\ +address is supplied and not None, it sets a destination address for\n\ +the message. The return value is the number of bytes of non-ancillary\n\ +data sent."); +#endif /* CMSG_LEN */ + + /* s.shutdown(how) method */ static PyObject * @@ -3019,6 +3699,14 @@ static PyMethodDef sock_methods[] = { {"sleeptaskw", (PyCFunction)sock_sleeptaskw, METH_O, sleeptaskw_doc}, #endif +#ifdef CMSG_LEN + {"recvmsg", (PyCFunction)sock_recvmsg, METH_VARARGS, + recvmsg_doc}, + {"recvmsg_into", (PyCFunction)sock_recvmsg_into, METH_VARARGS, + recvmsg_into_doc,}, + {"sendmsg", (PyCFunction)sock_sendmsg, METH_VARARGS, + sendmsg_doc}, +#endif {NULL, NULL} /* sentinel */ }; @@ -4287,6 +4975,68 @@ A value of None indicates that new socke When the socket module is first imported, the default is None."); +#ifdef CMSG_LEN +/* Python interface to CMSG_LEN(length). */ + +static PyObject * +socket_CMSG_LEN(PyObject *self, PyObject *args) +{ + Py_ssize_t length; + size_t result; + + if (!PyArg_ParseTuple(args, "n:CMSG_LEN", &length)) + return NULL; + if (length < 0 || !get_CMSG_LEN(length, &result)) { + PyErr_Format(PyExc_OverflowError, "CMSG_LEN() argument out of range"); + return NULL; + } + return PyInt_FromSize_t(result); +} + +PyDoc_STRVAR(CMSG_LEN_doc, +"CMSG_LEN(length) -> control message length\n\ +\n\ +Return the total length, without trailing padding, of an ancillary\n\ +data item with associated data of the given length. This value can\n\ +often be used as the buffer size for recvmsg() to receive a single\n\ +item of ancillary data, but RFC 3542 requires portable applications to\n\ +use CMSG_SPACE() and thus include space for padding, even when the\n\ +item will be the last in the buffer. Raises OverflowError if length\n\ +is outside the permissible range of values."); + + +#ifdef CMSG_SPACE +/* Python interface to CMSG_SPACE(length). */ + +static PyObject * +socket_CMSG_SPACE(PyObject *self, PyObject *args) +{ + Py_ssize_t length; + size_t result; + + if (!PyArg_ParseTuple(args, "n:CMSG_SPACE", &length)) + return NULL; + if (length < 0 || !get_CMSG_SPACE(length, &result)) { + PyErr_SetString(PyExc_OverflowError, + "CMSG_SPACE() argument out of range"); + return NULL; + } + return PyInt_FromSize_t(result); +} + +PyDoc_STRVAR(CMSG_SPACE_doc, +"CMSG_SPACE(length) -> buffer size\n\ +\n\ +Return the buffer size needed for recvmsg() to receive an ancillary\n\ +data item with associated data of the given length, along with any\n\ +trailing padding. The buffer space needed to receive multiple items\n\ +is the sum of the CMSG_SPACE() values for their associated data\n\ +lengths. Raises OverflowError if length is outside the permissible\n\ +range of values."); +#endif /* CMSG_SPACE */ +#endif /* CMSG_LEN */ + + /* List of functions exported by this module. */ static PyMethodDef socket_methods[] = { @@ -4338,6 +5088,14 @@ static PyMethodDef socket_methods[] = { METH_NOARGS, getdefaulttimeout_doc}, {"setdefaulttimeout", socket_setdefaulttimeout, METH_O, setdefaulttimeout_doc}, +#ifdef CMSG_LEN + {"CMSG_LEN", socket_CMSG_LEN, + METH_VARARGS, CMSG_LEN_doc}, +#ifdef CMSG_SPACE + {"CMSG_SPACE", socket_CMSG_SPACE, + METH_VARARGS, CMSG_SPACE_doc}, +#endif +#endif {NULL, NULL} /* Sentinel */ }; @@ -4836,6 +5594,15 @@ init_socket(void) #ifdef SO_SETFIB PyModule_AddIntConstant(m, "SO_SETFIB", SO_SETFIB); #endif +#ifdef SO_PASSCRED + PyModule_AddIntConstant(m, "SO_PASSCRED", SO_PASSCRED); +#endif +#ifdef SO_PEERCRED + PyModule_AddIntConstant(m, "SO_PEERCRED", SO_PEERCRED); +#endif +#ifdef LOCAL_PEERCRED + PyModule_AddIntConstant(m, "LOCAL_PEERCRED", LOCAL_PEERCRED); +#endif /* Maximum number of connections for "listen" */ #ifdef SOMAXCONN @@ -4844,6 +5611,17 @@ init_socket(void) PyModule_AddIntConstant(m, "SOMAXCONN", 5); /* Common value */ #endif + /* Ancilliary message types */ +#ifdef SCM_RIGHTS + PyModule_AddIntConstant(m, "SCM_RIGHTS", SCM_RIGHTS); +#endif +#ifdef SCM_CREDENTIALS + PyModule_AddIntConstant(m, "SCM_CREDENTIALS", SCM_CREDENTIALS); +#endif +#ifdef SCM_CREDS + PyModule_AddIntConstant(m, "SCM_CREDS", SCM_CREDS); +#endif + /* Flags for send, recv */ #ifdef MSG_OOB PyModule_AddIntConstant(m, "MSG_OOB", MSG_OOB); @@ -4875,6 +5653,33 @@ init_socket(void) #ifdef MSG_ETAG PyModule_AddIntConstant(m, "MSG_ETAG", MSG_ETAG); #endif +#ifdef MSG_NOSIGNAL + PyModule_AddIntConstant(m, "MSG_NOSIGNAL", MSG_NOSIGNAL); +#endif +#ifdef MSG_NOTIFICATION + PyModule_AddIntConstant(m, "MSG_NOTIFICATION", MSG_NOTIFICATION); +#endif +#ifdef MSG_CMSG_CLOEXEC + PyModule_AddIntConstant(m, "MSG_CMSG_CLOEXEC", MSG_CMSG_CLOEXEC); +#endif +#ifdef MSG_ERRQUEUE + PyModule_AddIntConstant(m, "MSG_ERRQUEUE", MSG_ERRQUEUE); +#endif +#ifdef MSG_CONFIRM + PyModule_AddIntConstant(m, "MSG_CONFIRM", MSG_CONFIRM); +#endif +#ifdef MSG_MORE + PyModule_AddIntConstant(m, "MSG_MORE", MSG_MORE); +#endif +#ifdef MSG_EOF + PyModule_AddIntConstant(m, "MSG_EOF", MSG_EOF); +#endif +#ifdef MSG_BCAST + PyModule_AddIntConstant(m, "MSG_BCAST", MSG_BCAST); +#endif +#ifdef MSG_MCAST + PyModule_AddIntConstant(m, "MSG_MCAST", MSG_MCAST); +#endif /* Protocol level and numbers, usable for [gs]etsockopt */ #ifdef SOL_SOCKET @@ -5014,6 +5819,9 @@ init_socket(void) #ifdef IPPROTO_VRRP PyModule_AddIntConstant(m, "IPPROTO_VRRP", IPPROTO_VRRP); #endif +#ifdef IPPROTO_SCTP + PyModule_AddIntConstant(m, "IPPROTO_SCTP", IPPROTO_SCTP); +#endif #ifdef IPPROTO_BIP PyModule_AddIntConstant(m, "IPPROTO_BIP", IPPROTO_BIP); #endif