Add sendmsg(), recvmsg() and recvmsg_into() methods to socket object. Based on a patch by Heiko Wundram. diff --git a/Doc/library/socket.rst b/Doc/library/socket.rst --- a/Doc/library/socket.rst +++ b/Doc/library/socket.rst @@ -158,6 +158,7 @@ SOMAXCONN MSG_* SOL_* + SCM_* IPPROTO_* IPPORT_* INADDR_* @@ -445,6 +446,44 @@ Availability: Unix (maybe not all platforms). +.. function:: CMSG_LEN(length) + + Return the value which would be stored in the ``cmsg_len`` field of + a control message with associated data of the given *length*. This + value can often be used with :meth:`~socket.recvmsg` as the + ancillary data buffer size to receive a single control message + without trailing padding, but :rfc:`3542` requires portable + applications to use :func:`CMSG_SPACE` and include space for + padding, even when the message is the last in the buffer. Raises + :exc:`OverflowError` if *length* is outside the permissible range + of values. + + Availability: most Unix systems (XXX: other platforms?). + + .. versionadded:: XXX + + +.. function:: CMSG_SPACE(length) + + Return the number of bytes needed to hold a control message with + associated data of the given *length*, along with any trailing + padding needed. :rfc:`3542` specifies the ancillary data buffer + size needed to receive a set of control messages as the sum of the + :func:`CMSG_SPACE` values for the associated data items. Raises + :exc:`OverflowError` if *length* is outside the permissible range + of values. + + Note that some systems might support ancillary data without + providing this function. Also note that setting the buffer size + using the results of this function may not precisely limit the + amount of ancillary data that can be received, since additional + data may be able to fit into the padding area. + + Availability: most Unix systems (XXX: other platforms?). + + .. versionadded:: XXX + + .. function:: getdefaulttimeout() Return the default timeout in floating seconds for new socket objects. A value @@ -608,6 +647,103 @@ to zero. (The format of *address* depends on the address family --- see above.) +.. method:: socket.recvmsg(bufsize[, ancbufsize[, flags]]) + + Receive normal data (up to *bufsize* bytes) and ancillary data from + the socket. The *ancbufsize* argument sets the size in bytes of + the internal buffer used to receive the ancillary data; it defaults + to 0, meaning that no ancillary data will be received. Appropriate + buffer sizes for ancillary data can be calculated using + :func:`CMSG_LEN` or :func:`CMSG_SPACE`, and items which do not fit + into the buffer might be truncated or discarded. The *flags* + argument defaults to 0 and has the same meaning as for + :meth:`recv`. + + The return value is a 4-tuple: ``(data, ancdata, msg_flags, + address)``. The *data* item is a :class:`bytes` object holding the + non-ancillary data received. The *ancdata* item is a list of zero + or more tuples ``(cmsg_level, cmsg_type, cmsg_data)`` representing + the ancillary data (control messages) received: *cmsg_level* and + *cmsg_type* are integers specifying the protocol level and + protocol-specific type respectively, and *cmsg_data* is a + :class:`bytes` object holding the associated data. The *msg_flags* + item is the bitwise OR of various flags indicating conditions on + the received message; see your system documentation for details. + If the receiving socket is unconnected, *address* is the address of + the sending socket, if available; otherwise, its value is + unspecified. + + On some systems, :meth:`sendmsg` and :meth:`recvmsg` can be used to + pass file descriptors between processes over an :const:`AF_UNIX` + socket. When this facility is used (it is often restricted to + :const:`SOCK_STREAM` sockets), :meth:`recvmsg` will return, in its + ancillary data, items of the form ``(socket.SOL_SOCKET, + socket.SCM_RIGHTS, fds)``, where *fds* is a :class:`bytes` object + representing the new file descriptors as a binary array of the + native C :ctype:`int` type. If :meth:`recvmsg` encounters an error + after the system call returns, it will attempt to close any file + descriptors received via this mechanism before raising an + exception. + + On systems which support the :const:`SCM_RIGHTS` mechanism, the + following function will receive up to *maxfds* file descriptors, + returning the message data and a list containing the descriptors + (while ignoring unexpected conditions such as unrelated control + messages being received). See also :meth:`sendmsg`. :: + + import socket, array + + def recv_fds(sock, msglen, maxfds): + fds = array.array("i") # Array of ints + msg, ancdata, flags, addr = sock.recvmsg(msglen, socket.CMSG_LEN(maxfds * fds.itemsize)) + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if (cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS): + # Append data, ignoring any truncated integers at the end. + fds.fromstring(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + return msg, list(fds) + + Availability: most Unix systems (XXX: other platforms?). + + .. versionadded:: XXX + + +.. method:: socket.recvmsg_into(buffers[, ancbufsize[, flags]]) + + Receive normal data and ancillary data from the socket, behaving as + :meth:`recvmsg` would, but scatter the non-ancillary data into a + series of buffers instead of returning a new bytes object. The + *buffers* argument must be an iterable of objects that export + writable buffers (e.g. :class:`bytearray` objects); these will be + filled with successive chunks of the non-ancillary data until it + has all been written or there are no more buffers. The operating + system may set a limit (:func:`~os.sysconf` value ``SC_IOV_MAX``) + on the number of buffers that can be used. The *ancbufsize* and + *flags* arguments have the same meaning as for :meth:`recvmsg`. + + The return value is a 4-tuple: ``(nbytes, ancdata, msg_flags, + address)``, where *nbytes* is the total number of bytes of + non-ancillary data written into the buffers, and *ancdata*, + *msg_flags* and *address* are the same as for :meth:`recvmsg`. + + Example:: + + >>> import socket + >>> s1, s2 = socket.socketpair() + >>> b1 = bytearray(b'----') + >>> b2 = bytearray(b'0123456789') + >>> b3 = bytearray(b'--------------') + >>> s1.send(b'Mary had a little lamb') + 22 + >>> s2.recvmsg_into([b1, memoryview(b2)[2:9], b3]) + (22, [], 0, None) + >>> [b1, b2, b3] + [bytearray(b'Mary'), bytearray(b'01 had a 9'), bytearray(b'little lamb---')] + + Availability: most Unix systems (XXX: other platforms?). + + .. versionadded:: XXX + + .. method:: socket.recvfrom_into(buffer[, nbytes[, flags]]) Receive data from the socket, writing it into *buffer* instead of creating a @@ -655,6 +791,41 @@ above.) +.. method:: socket.sendmsg(buffers[, ancdata[, flags[, address]]]) + + Send normal and ancillary data to the socket, gathering the + non-ancillary data from a series of buffers and concatenating it + into a single message. The *buffers* argument specifies the + non-ancillary data as an iterable of buffer-compatible objects + (e.g. :class:`bytes` objects); the operating system may set a limit + (:func:`~os.sysconf` value ``SC_IOV_MAX``) on the number of buffers + that can be used. The *ancdata* argument specifies the ancillary + data (control messages) as an iterable of zero or more tuples + ``(cmsg_level, cmsg_type, cmsg_data)``, where *cmsg_level* and + *cmsg_type* are integers specifying the protocol level and + protocol-specific type respectively, and *cmsg_data* is a + buffer-compatible object holding the associated data. Note that + some systems (in particular, systems without :func:`CMSG_SPACE`) + might support sending only one control message per call. The + *flags* argument defaults to 0 and has the same meaning as for + :meth:`send`. If *address* is supplied and not ``None``, it sets a + destination address for the message. The return value is the + number of bytes of non-ancillary data sent. + + The following function sends the list of file descriptors *fds* + over an :const:`AF_UNIX` socket, on systems which support the + :const:`SCM_RIGHTS` mechanism. See also :meth:`recvmsg`. :: + + import socket, array + + def send_fds(sock, msg, fds): + return sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))]) + + Availability: most Unix systems (XXX: other platforms?). + + .. versionadded:: XXX + + .. method:: socket.setblocking(flag) Set blocking or non-blocking mode of the socket: if *flag* is 0, the socket is diff --git a/Lib/ssl.py b/Lib/ssl.py --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -223,6 +223,14 @@ else: return socket.sendto(self, data, addr, flags) + def sendmsg(self, *args, **kwargs): + self._checkClosed() + if self._sslobj: + raise ValueError("sendmsg not allowed on instances of %s" % + self.__class__) + else: + return socket.sendmsg(self, *args, **kwargs) + def sendall(self, data, flags=0): self._checkClosed() if self._sslobj: @@ -292,6 +300,22 @@ else: return socket.recvfrom_into(self, buffer, nbytes, flags) + def recvmsg(self, *args, **kwargs): + self._checkClosed() + if self._sslobj: + raise ValueError("recvmsg not allowed on instances of %s" % + self.__class__) + else: + return socket.recvmsg(self, *args, **kwargs) + + def recvmsg_into(self, *args, **kwargs): + self._checkClosed() + if self._sslobj: + raise ValueError("recvmsg_into not allowed on instances of %s" % + self.__class__) + else: + return socket.recvmsg_into(self, *args, **kwargs) + def pending(self): self._checkClosed() if self._sslobj: diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -7,6 +7,8 @@ import io import socket import select +import tempfile +import _testcapi import _thread as thread import threading import time @@ -21,6 +23,9 @@ HOST = support.HOST MSG = b'Michael Gilfix was here\n' +# Size in bytes of the int type +SIZEOF_INT = array.array("i").itemsize + class SocketTCPTest(unittest.TestCase): def setUp(self): @@ -42,6 +47,25 @@ self.serv.close() self.serv = None +class ThreadSafeCleanupTestCase(unittest.TestCase): + """Subclass of unittest.TestCase with thread-safe cleanup methods. + + This subclass protects the addCleanup() and doCleanups() methods + with a recursive lock. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._cleanup_lock = threading.RLock() + + def addCleanup(self, *args, **kwargs): + with self._cleanup_lock: + return super().addCleanup(*args, **kwargs) + + def doCleanups(self, *args, **kwargs): + with self._cleanup_lock: + return super().doCleanups(*args, **kwargs) + class ThreadableTest: """Threadable Test class @@ -217,6 +241,175 @@ self.cli = None ThreadableTest.clientTearDown(self) +class SocketTestBase(unittest.TestCase): + + def setUp(self): + self.serv = self.newSocket() + self.bindServer() + + def bindServer(self): + self.bindSock(self.serv) + self.serv_addr = self.serv.getsockname() + + def tearDown(self): + self.serv.close() + self.serv = None + +class SocketListeningTestBase(SocketTestBase): + + def setUp(self): + super().setUp() + self.serv.listen(1) + +class ThreadedSocketTestBase(SocketTestBase, ThreadableTest): + + def __init__(self, *args, **kwargs): + SocketTestBase.__init__(self, *args, **kwargs) + ThreadableTest.__init__(self) + + def clientSetUp(self): + self.cli = self.newClientSocket() + self.bindClient() + + def newClientSocket(self): + return self.newSocket() + + def bindClient(self): + self.bindSock(self.cli) + self.cli_addr = self.cli.getsockname() + + def clientTearDown(self): + self.cli.close() + self.cli = None + ThreadableTest.clientTearDown(self) + +class ConnectedSocketTestBase(SocketListeningTestBase, ThreadedSocketTestBase): + + def setUp(self): + super().setUp() + # Indicate explicitly we're ready for the client thread to + # proceed and then perform the blocking call to accept + self.serverExplicitReady() + conn, addr = self.serv.accept() + self.cli_conn = conn + + def tearDown(self): + self.cli_conn.close() + self.cli_conn = None + super().tearDown() + + def clientSetUp(self): + super().clientSetUp() + self.cli.connect(self.serv_addr) + self.serv_conn = self.cli + + def clientTearDown(self): + self.serv_conn.close() + self.serv_conn = None + super().clientTearDown() + +class UnixSocketTestBase(ThreadSafeCleanupTestCase, SocketTestBase): + + def setUp(self): + self.dir_path = tempfile.mkdtemp() + self.addCleanup(os.rmdir, self.dir_path) + super().setUp() + + def bindSock(self, sock): + path = tempfile.mktemp(dir=self.dir_path) + sock.bind(path) + self.addCleanup(os.unlink, path) + +class UnixStreamBase(UnixSocketTestBase): + + def newSocket(self): + return socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + +class UnixDgramBase(UnixSocketTestBase): + + def newSocket(self): + return socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + +class InetTestBase(SocketTestBase): + + host = HOST + + def setUp(self): + super().setUp() + self.port = self.serv_addr[1] + + def bindSock(self, sock): + support.bind_port(sock, host=self.host) + +class TCPTestBase(InetTestBase): + + def newSocket(self): + return socket.socket(socket.AF_INET, socket.SOCK_STREAM) + +class UDPTestBase(InetTestBase): + + def newSocket(self): + return socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + +class SCTPStreamBase(InetTestBase): + + def newSocket(self): + return socket.socket(socket.AF_INET, socket.SOCK_STREAM, + socket.IPPROTO_SCTP) + +class Inet6TestBase(InetTestBase): + + # Don't use "localhost" here - it may not have an IPv6 address + # assigned to it by default (e.g. in /etc/hosts), and if someone + # has assigned it an IPv4-mapped address, then it's unlikely to + # work with the full IPv6 API. + host = "::1" + +class UDP6TestBase(Inet6TestBase): + + def newSocket(self): + return socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + +def skipWithClientIf(condition, reason): + def client_pass(*args, **kwargs): + pass + def skipdec(obj): + retval = unittest.skip(reason)(obj) + if not isinstance(obj, type): + retval.client_skip = lambda f: client_pass + return retval + def noskipdec(obj): + if not (isinstance(obj, type) or hasattr(obj, "client_skip")): + obj.client_skip = lambda f: f + return obj + return skipdec if condition else noskipdec + +def requireAttrs(obj, *attrs): + missing = [name for name in attrs if not hasattr(obj, name)] + return skipWithClientIf( + missing, "don't have " + ", ".join(name for name in missing)) + +def requireSocket(*args): + err = None + missing = [obj for obj in args if + isinstance(obj, str) and not hasattr(socket, obj)] + if missing: + err = "don't have " + ", ".join(name for name in missing) + else: + callargs = [getattr(socket, obj) if isinstance(obj, str) else obj + for obj in args] + try: + s = socket.socket(*callargs) + except socket.error as e: + # XXX: check errno? + err = str(e) + else: + s.close() + return skipWithClientIf( + err is not None, + "can't create socket({0}): {1}".format( + ", ".join(str(o) for o in args), err)) + ####################################################################### ## Begin Tests @@ -677,6 +870,1101 @@ def _testRecvFromNegative(self): self.cli.sendto(MSG, 0, (HOST, self.port)) +class CmsgMacroTests(unittest.TestCase): + + # Match the definition in socketmodule.c + socklen_t_limit = min(0x7fffffff, _testcapi.INT_MAX) + + @requireAttrs(socket, "CMSG_LEN") + def testCMSG_LEN(self): + # Test CMSG_LEN() with various valid and invalid values, + # checking the assumptions used by recvmsg() and sendmsg(). + toobig = self.socklen_t_limit - socket.CMSG_LEN(0) + 1 + values = list(range(257)) + list(range(toobig - 257, toobig)) + + # struct cmsghdr has at least three members, two of which are ints + self.assertGreater(socket.CMSG_LEN(0), array.array("i").itemsize * 2) + for n in values: + ret = socket.CMSG_LEN(n) + # This is how recvmsg() calculates the data size + self.assertEqual(ret - socket.CMSG_LEN(0), n) + self.assertLessEqual(ret, self.socklen_t_limit) + + self.assertRaises(OverflowError, socket.CMSG_LEN, -1) + # sendmsg() shares code with these functions, and requires + # that it reject values over the limit. + self.assertRaises(OverflowError, socket.CMSG_LEN, toobig) + self.assertRaises(OverflowError, socket.CMSG_LEN, sys.maxsize) + + @requireAttrs(socket, "CMSG_SPACE") + def testCMSG_SPACE(self): + # Test CMSG_SPACE() with various valid and invalid values, + # checking the assumptions used by sendmsg(). + toobig = self.socklen_t_limit - socket.CMSG_SPACE(1) + 1 + values = list(range(257)) + list(range(toobig - 257, toobig)) + + last = socket.CMSG_SPACE(0) + # struct cmsghdr has at least three members, two of which are ints + self.assertGreater(last, array.array("i").itemsize * 2) + for n in values: + ret = socket.CMSG_SPACE(n) + self.assertGreaterEqual(ret, last) + self.assertGreaterEqual(ret, socket.CMSG_LEN(n)) + self.assertGreaterEqual(ret, n + socket.CMSG_LEN(0)) + self.assertLessEqual(ret, self.socklen_t_limit) + last = ret + + self.assertRaises(OverflowError, socket.CMSG_SPACE, -1) + # sendmsg() shares code with these functions, and requires + # that it reject values over the limit. + self.assertRaises(OverflowError, socket.CMSG_SPACE, toobig) + self.assertRaises(OverflowError, socket.CMSG_SPACE, sys.maxsize) + +# XXX: like the other datagram (UDP) tests in this module, the code +# here assumes that datagram delivery on the local machine will be +# reliable. +class SendrecvmsgBase(ThreadSafeCleanupTestCase): + + # Time in seconds to wait before considering a test failed, or + # None for no timeout. Not all the tests actually set a timeout. + fail_timeout = 3.0 + + def setUp(self): + self.misc_event = threading.Event() + super().setUp() + + def sendToServer(self, msg): + return self.cli_sock.send(msg) + + sendmsg_to_server_defaults = () + + def sendmsgToServer(self, *args): + return self.cli_sock.sendmsg( + *(args + self.sendmsg_to_server_defaults[len(args):])) + + def registerRecvmsgResult(self, result): + pass + + def checkRecvmsgAddress(self, addr1, addr2): + self.assertEqual(addr1, addr2) + + # Flags that are normally unset in msg_flags + msg_flags_common_unset = 0 + for name in ("MSG_CTRUNC", "MSG_OOB"): + msg_flags_common_unset |= getattr(socket, name, 0) + + # Flags that are normally set + msg_flags_common_set = 0 + + # Flags set when a complete record has been received (e.g. MSG_EOR + # for SCTP) + msg_flags_eor_indicator = 0 + + # Flags set when a complete record has not been received (e.g. MSG_TRUNC + # for datagram sockets) + msg_flags_non_eor_indicator = 0 + + class TestError(Exception): + pass + + def checkFlags(self, flags, eor=None, checkset=0, checkunset=0, ignore=0): + defaultset = self.msg_flags_common_set + defaultunset = self.msg_flags_common_unset + if eor: + defaultset |= self.msg_flags_eor_indicator + defaultunset |= self.msg_flags_non_eor_indicator + elif eor is not None: + defaultset |= self.msg_flags_non_eor_indicator + defaultunset |= self.msg_flags_eor_indicator + defaultset &= ~checkunset + defaultunset &= ~checkset + checkset |= defaultset + checkunset |= defaultunset + inboth = checkset & checkunset & ~ignore + if inboth: + raise self.TestError("contradictory set, unset requirements for " + "flags {0:#x}".format(inboth)) + mask = (checkset | checkunset) & ~ignore + self.assertEqual(flags & mask, checkset & mask) + +class SendrecvmsgDgramFlagsBase(SendrecvmsgBase): + + @property + def msg_flags_non_eor_indicator(self): + return super().msg_flags_non_eor_indicator | socket.MSG_TRUNC + +class SendrecvmsgSCTPFlagsBase(SendrecvmsgBase): + + @property + def msg_flags_eor_indicator(self): + return super().msg_flags_eor_indicator | socket.MSG_EOR + +class SendrecvmsgConnectionlessBase(SendrecvmsgBase): + + @property + def serv_sock(self): + return self.serv + + @property + def cli_sock(self): + return self.cli + + @property + def sendmsg_to_server_defaults(self): + return ([], [], 0, self.serv_addr) + + def sendToServer(self, msg): + return self.cli_sock.sendto(msg, self.serv_addr) + +class SendrecvmsgDgramBase(SendrecvmsgDgramFlagsBase, + SendrecvmsgConnectionlessBase): + pass + +class SendrecvmsgConnectedBase(SendrecvmsgBase): + + @property + def serv_sock(self): + return self.cli_conn + + @property + def cli_sock(self): + return self.serv_conn + + def checkRecvmsgAddress(self, addr1, addr2): + # Address is currently "unspecified" for a connected socket, + # so we don't examine it + pass + +class SendrecvmsgSCTPStreamBase(SendrecvmsgSCTPFlagsBase, + SendrecvmsgConnectedBase): + pass + +class SendrecvmsgServerTimeoutBase(SendrecvmsgBase): + + def setUp(self): + super().setUp() + self.serv_sock.settimeout(self.fail_timeout) + +class SendmsgTests(SendrecvmsgServerTimeoutBase): + + def testSendmsg(self): + # Send a simple message with sendmsg() + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsg(self): + self.assertEqual(self.sendmsgToServer([MSG]), len(MSG)) + + def testSendmsgDataGenerator(self): + # Send from buffer obtained from a generator (not a sequence) + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsgDataGenerator(self): + self.assertEqual(self.sendmsgToServer((o for o in [MSG])), + len(MSG)) + + def testSendmsgAncillaryGenerator(self): + # Gather (empty) ancillary data from a generator + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsgAncillaryGenerator(self): + self.assertEqual(self.sendmsgToServer([MSG], (o for o in [])), + len(MSG)) + + def testSendmsgArray(self): + # Send data from an array instead of the usual bytes object + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsgArray(self): + self.assertEqual(self.sendmsgToServer([array.array("b", MSG)]), + len(MSG)) + + def testSendmsgGather(self): + # Send message data from more than one buffer (gather write) + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsgGather(self): + self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG)) + + def testSendmsgBadArgs(self): + # Test for sendmsg() rejecting invalid arguments + self.assertEqual(self.serv_sock.recv(1000), b"done") + + def _testSendmsgBadArgs(self): + self.assertRaises(TypeError, self.cli_sock.sendmsg) + self.assertRaises(TypeError, self.sendmsgToServer, + b"not in an iterable") + self.assertRaises(TypeError, self.sendmsgToServer, + object()) + self.assertRaises(TypeError, self.sendmsgToServer, + [object()]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG, object()]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], object()) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [], object()) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [], 0, object()) + self.sendToServer(b"done") + + def testSendmsgBadCmsg(self): + # Check that invalid ancillary data items are rejected + self.assertEqual(self.serv_sock.recv(1000), b"done") + + def _testSendmsgBadCmsg(self): + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [object()]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(object(), 0, b"data")]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(0, object(), b"data")]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(0, 0, object())]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(0, 0)]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(0, 0, b"data", 42)]) + self.sendToServer(b"done") + + @requireAttrs(socket, "CMSG_SPACE") + def testSendmsgBadMultiCmsg(self): + # Check that invalid ancillary data items are rejected when + # more than one item is present + self.assertEqual(self.serv_sock.recv(1000), b"done") + + @testSendmsgBadMultiCmsg.client_skip + def _testSendmsgBadMultiCmsg(self): + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [0, 0, b""]) + self.assertRaises(TypeError, self.sendmsgToServer, + [MSG], [(0, 0, b""), object()]) + self.sendToServer(b"done") + + def testSendmsgExcessCmsgReject(self): + # Check that sendmsg() rejects excess ancillary data items + # when the number that can be sent is limited. + self.assertEqual(self.serv_sock.recv(1000), b"done") + + def _testSendmsgExcessCmsgReject(self): + if not hasattr(socket, "CMSG_SPACE"): + # Can only send one item + with self.assertRaises(socket.error) as cm: + self.sendmsgToServer([MSG], [(0, 0, b""), (0, 0, b"")]) + self.assertIsNone(cm.exception.errno) + self.sendToServer(b"done") + + def testSendmsgAfterClose(self): + # Check that sendmsg() on a closed socket fails + pass + + def _testSendmsgAfterClose(self): + self.cli_sock.close() + self.assertRaises(socket.error, self.sendmsgToServer, [MSG]) + +class SendmsgStreamTests(SendmsgTests): + + def testSendmsgExplicitNoneAddr(self): + # Check that peer address can be specified as None + self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) + + def _testSendmsgExplicitNoneAddr(self): + self.assertEqual(self.sendmsgToServer([MSG], [], 0, None), len(MSG)) + + def testSendmsgTimeout(self): + # Check that timeout works with sendmsg() + self.assertEqual(self.serv_sock.recv(512), b"a"*512) + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + + def _testSendmsgTimeout(self): + try: + self.cli_sock.settimeout(0.03) + with self.assertRaises(socket.timeout): + while True: + self.sendmsgToServer([b"a"*512]) + finally: + self.misc_event.set() + + # XXX: would be nice to have more tests for sendmsg flags + # argument. MSG_DONTWAIT is easy enough to test, but it's not + # defined by POSIX (Linux, BSD and OS X do have it, apparently). + + @requireAttrs(socket, "MSG_DONTWAIT") + def testSendmsgDontWait(self): + # Check that MSG_DONTWAIT in flags causes non-blocking behaviour + self.assertEqual(self.serv_sock.recv(512), b"a"*512) + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + + @testSendmsgDontWait.client_skip + def _testSendmsgDontWait(self): + try: + with self.assertRaises(socket.error) as cm: + while True: + self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT) + self.assertIn(cm.exception.errno, + (errno.EAGAIN, errno.EWOULDBLOCK)) + finally: + self.misc_event.set() + +class SendmsgConnectionlessTests(SendmsgTests): + + def testSendmsgNoDestAddr(self): + # Check that sendmsg() fails without destination address + pass + + def _testSendmsgNoDestAddr(self): + self.assertRaises(socket.error, self.cli_sock.sendmsg, + [MSG]) + self.assertRaises(socket.error, self.cli_sock.sendmsg, + [MSG], [], 0, None) + +class RecvmsgMixin(SendrecvmsgBase): + + def doRecvmsg(self, sock, bufsize, *args): + result = sock.recvmsg(bufsize, *args) + self.registerRecvmsgResult(result) + return result + +class RecvmsgIntoMixin(SendrecvmsgBase): + + def doRecvmsg(self, sock, bufsize, *args): + # Simulate recvmsg() using recvmsg_into(), so we can use tests + # for the former on the latter as well. + buf = bytearray(bufsize) + result = sock.recvmsg_into([buf], *args) + self.registerRecvmsgResult(result) + self.assertGreaterEqual(result[0], 0) + self.assertLessEqual(result[0], bufsize) + return (bytes(buf[:result[0]]),) + result[1:] + +class RecvmsgGenericTests(SendrecvmsgBase): + + def testRecvmsg(self): + # Receive a simple message with recvmsg[_into]() + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG)) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsg(self): + self.sendToServer(MSG) + + def testRecvmsgExplicitDefaults(self): + # Test recvmsg[_into]() with default arguments provided explicitly + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 0, 0) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgExplicitDefaults(self): + self.sendToServer(MSG) + + def testRecvmsgShorter(self): + # Receive a message smaller than buffer + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG) + 42) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgShorter(self): + self.sendToServer(MSG) + + def testRecvmsgTrunc(self): + # Receive part of message, check for truncation indicators + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG) - 3) + self.assertEqual(msg, MSG[:-3]) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=False) + + def _testRecvmsgTrunc(self): + self.sendToServer(MSG) + + def testRecvmsgShortAncillaryBuf(self): + # Test ancillary data buffer too small to hold any ancillary data + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 1) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgShortAncillaryBuf(self): + self.sendToServer(MSG) + + def testRecvmsgLongAncillaryBuf(self): + # Test large ancillary data buffer + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 10240) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgLongAncillaryBuf(self): + self.sendToServer(MSG) + + def testRecvmsgAfterClose(self): + # Check that recvmsg[_into]() on a closed socket fails + self.serv_sock.close() + self.assertRaises(socket.error, self.doRecvmsg, self.serv_sock, 1024) + + def _testRecvmsgAfterClose(self): + pass + + def testRecvmsgTimeout(self): + # Check that timeout works + try: + self.serv_sock.settimeout(0.03) + self.assertRaises(socket.timeout, + self.doRecvmsg, self.serv_sock, len(MSG)) + finally: + self.misc_event.set() + + def _testRecvmsgTimeout(self): + self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout)) + + @requireAttrs(socket, "MSG_PEEK") + def testRecvmsgPeek(self): + # Check that MSG_PEEK in flags enables examination of pending + # data without consuming it + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG) - 3, 0, + socket.MSG_PEEK) + self.assertEqual(msg, MSG[:-3]) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + # Ignoring MSG_TRUNC here (so this test is the same for stream + # and datagram sockets). Some wording in POSIX seems to + # suggest that it needn't be set when peeking, but that may + # just be a slip. + self.checkFlags(flags, eor=False, + ignore=getattr(socket, "MSG_TRUNC", 0)) + + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 0, + socket.MSG_PEEK) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG)) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + @testRecvmsgPeek.client_skip + def _testRecvmsgPeek(self): + self.sendToServer(MSG) + + @requireAttrs(socket.socket, "sendmsg") + def testRecvmsgFromSendmsg(self): + # Test receiving with recvmsg[_into]() a simple message sent + # using sendmsg() + self.serv_sock.settimeout(self.fail_timeout) + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG)) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + @testRecvmsgFromSendmsg.client_skip + def _testRecvmsgFromSendmsg(self): + self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG)) + +class RecvmsgGenericStreamTests(RecvmsgGenericTests): + + def testRecvmsgEOF(self): + # Receive end-of-stream indicator (b"", peer socket closed) + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, 1024) + self.assertEqual(msg, b"") + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=None) # Might not have end-of-record marker + + def _testRecvmsgEOF(self): + self.cli_sock.close() + + def testRecvmsgOverflow(self): + # Receive a message in more than one chunk + seg1, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG) - 3) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=False) + + seg2, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, 1024) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + msg = seg1 + seg2 + self.assertEqual(msg, MSG) + + def _testRecvmsgOverflow(self): + self.sendToServer(MSG) + +class RecvmsgTests(RecvmsgMixin, RecvmsgGenericTests): + + def testRecvmsgBadArgs(self): + # Test for recvmsg() rejecting invalid arguments + self.assertRaises(TypeError, self.serv_sock.recvmsg) + self.assertRaises(ValueError, self.serv_sock.recvmsg, + -1, 0, 0) + self.assertRaises(ValueError, self.serv_sock.recvmsg, + len(MSG), -1, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg, + [bytearray(10)], 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg, + object(), 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg, + len(MSG), object(), 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg, + len(MSG), 0, object()) + + msg, ancdata, flags, addr = self.serv_sock.recvmsg(len(MSG), 0, 0) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgBadArgs(self): + self.sendToServer(MSG) + +class RecvmsgIntoTests(RecvmsgIntoMixin, RecvmsgGenericTests): + + def testRecvmsgIntoBadArgs(self): + # Test for recvmsg_into() rejecting invalid arguments + buf = bytearray(len(MSG)) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + len(MSG), 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + buf, 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + [object()], 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + [b"I'm not writable"], 0, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + [buf, object()], 0, 0) + self.assertRaises(ValueError, self.serv_sock.recvmsg_into, + [buf], -1, 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + [buf], object(), 0) + self.assertRaises(TypeError, self.serv_sock.recvmsg_into, + [buf], 0, object()) + + nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into([buf], 0, 0) + self.assertEqual(nbytes, len(MSG)) + self.assertEqual(buf, bytearray(MSG)) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgIntoBadArgs(self): + self.sendToServer(MSG) + + def testRecvmsgIntoGenerator(self): + # Receive into buffer obtained from a generator (not a sequence) + buf = bytearray(len(MSG)) + nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into( + (o for o in [buf])) + self.assertEqual(nbytes, len(MSG)) + self.assertEqual(buf, bytearray(MSG)) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgIntoGenerator(self): + self.sendToServer(MSG) + + def testRecvmsgIntoArray(self): + # Receive into an array rather than the usual bytearray + buf = array.array("b", bytes(len(MSG))) + nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into([buf]) + self.assertEqual(nbytes, len(MSG)) + self.assertEqual(buf.tostring(), MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgIntoArray(self): + self.sendToServer(MSG) + + def testRecvmsgIntoScatter(self): + # Receive into multiple buffers (scatter write) + b1 = bytearray(b"----") + b2 = bytearray(b"0123456789") + b3 = bytearray(b"--------------") + nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into( + [b1, memoryview(b2)[2:9], b3]) + self.assertEqual(nbytes, len(b"Mary had a little lamb")) + self.assertEqual(b1, bytearray(b"Mary")) + self.assertEqual(b2, bytearray(b"01 had a 9")) + self.assertEqual(b3, bytearray(b"little lamb---")) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True) + + def _testRecvmsgIntoScatter(self): + self.sendToServer(b"Mary had a little lamb") + +class SCMRightsTest(SendrecvmsgServerTimeoutBase): + + # Invalid file descriptor value that's unlikely to evaluate to a + # real FD even if one of its bytes is replaced with a different + # value (which shouldn't actually happen) + badfd = -0x5555 + + def newFDs(self, n): + fds = [] + for i in range(n): + fd, path = tempfile.mkstemp() + self.addCleanup(os.unlink, path) + self.addCleanup(os.close, fd) + os.write(fd, str(i).encode()) + fds.append(fd) + return fds + + def checkFDs(self, fds): + for n, fd in enumerate(fds): + os.lseek(fd, 0, os.SEEK_SET) + self.assertEqual(os.read(fd, 1024), str(n).encode()) + + def registerRecvmsgResult(self, result): + self.addCleanup(self.closeRecvmsgFDs, result) + + def closeRecvmsgFDs(self, recvmsg_result): + for cmsg_level, cmsg_type, cmsg_data in recvmsg_result[1]: + if (cmsg_level == socket.SOL_SOCKET and + cmsg_type == socket.SCM_RIGHTS): + fds = array.array("i") + fds.fromstring(cmsg_data[: + len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + for fd in fds: + os.close(fd) + + def createAndSendFDs(self, count): + self.assertEqual( + self.sendmsgToServer([MSG], + [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", self.newFDs(count)))]), + len(MSG)) + + def checkRecvmsgFDs(self, numfds, result): + msg, ancdata, flags, addr = result + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC) + self.assertEqual(len(ancdata), 1) + self.assertIsInstance(ancdata[0], tuple) + cmsg_type, cmsg_level, cmsg_data = ancdata[0] + self.assertEqual(cmsg_type, socket.SOL_SOCKET) + self.assertEqual(cmsg_level, socket.SCM_RIGHTS) + self.assertIsInstance(cmsg_data, bytes) + self.assertEqual(len(cmsg_data) % SIZEOF_INT, 0) + fds = array.array("i") + fds.fromstring(cmsg_data) + self.assertEqual(len(fds), numfds) + self.checkFDs(fds) + + def testFDPassSimple(self): + # Pass a single FD (array read from bytes object) + self.checkRecvmsgFDs(1, self.doRecvmsg(self.serv_sock, + len(MSG), 10240)) + + def _testFDPassSimple(self): + self.assertEqual( + self.sendmsgToServer( + [MSG], + [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", self.newFDs(1)).tostring())]), + len(MSG)) + + def testMultipleFDPass(self): + # Pass multiple FDs in a single array + self.checkRecvmsgFDs(4, self.doRecvmsg(self.serv_sock, + len(MSG), 10240)) + + def _testMultipleFDPass(self): + self.createAndSendFDs(4) + + def testFDPassCMSG_LEN(self): + # Test using CMSG_LEN() to calculate ancillary buffer size + self.checkRecvmsgFDs( + 4, self.doRecvmsg(self.serv_sock, len(MSG), + socket.CMSG_LEN(4 * SIZEOF_INT))) + + def _testFDPassCMSG_LEN(self): + self.createAndSendFDs(4) + + @requireAttrs(socket, "CMSG_SPACE") + def testFDPassCMSG_SPACE(self): + # Test using CMSG_SPACE() to calculate ancillary buffer size + self.checkRecvmsgFDs( + 4, self.doRecvmsg(self.serv_sock, len(MSG), + socket.CMSG_SPACE(4 * SIZEOF_INT))) + + @testFDPassCMSG_SPACE.client_skip + def _testFDPassCMSG_SPACE(self): + self.createAndSendFDs(4) + + @requireAttrs(socket, "CMSG_SPACE") + def testFDPassSeparate(self): + # Pass two FDs in two separate arrays + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 10240) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC) + self.assertLessEqual(len(ancdata), 2) + fds = array.array("i") + # Arrays may have been combined in a single control message + for cmsg_type, cmsg_level, cmsg_data in ancdata: + self.assertEqual(cmsg_type, socket.SOL_SOCKET) + self.assertEqual(cmsg_level, socket.SCM_RIGHTS) + self.assertEqual(len(cmsg_data) % SIZEOF_INT, 0) + fds.fromstring(cmsg_data) + self.assertEqual(len(fds), 2) + self.checkFDs(fds) + + @testFDPassSeparate.client_skip + def _testFDPassSeparate(self): + fd0, fd1 = self.newFDs(2) + self.assertEqual( + self.sendmsgToServer([MSG], [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [fd0])), + (socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [fd1]))]), + len(MSG)) + + def sendAncillaryIfPossible(self, msg, ancdata): + # Try to send msg and ancdata to server, but if the system + # call fails, just send msg with no ancillary data. + try: + nbytes = self.sendmsgToServer([msg], ancdata) + except socket.error as e: + # Check that it was the system call that failed + self.assertIsInstance(e.errno, int) + nbytes = self.sendmsgToServer([msg]) + self.assertEqual(nbytes, len(msg)) + + def testFDPassEmpty(self): + # Try to pass an empty FD array + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 10240) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC) + self.assertIn(ancdata, ([], + [(socket.SOL_SOCKET, socket.SCM_RIGHTS, b"")])) + + def _testFDPassEmpty(self): + self.sendAncillaryIfPossible(MSG, [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + b"")]) + + def testFDPassPartialInt(self): + # Try to pass a truncated FD array + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 10240) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC) + self.assertLessEqual(len(ancdata), 1) + for cmsg_type, cmsg_level, cmsg_data in ancdata: + self.assertEqual(cmsg_type, socket.SOL_SOCKET) + self.assertEqual(cmsg_level, socket.SCM_RIGHTS) + self.assertLess(len(cmsg_data), SIZEOF_INT) + + def _testFDPassPartialInt(self): + self.sendAncillaryIfPossible( + MSG, + [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [self.badfd]).tostring()[:-1])]) + + @requireAttrs(socket, "CMSG_SPACE") + def testFDPassPartialIntInMiddle(self): + # Try to pass two FD arrays, the first of which is truncated + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), 10240) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC) + self.assertLessEqual(len(ancdata), 2) + fds = array.array("i") + # Arrays may have been combined in a single control message + for cmsg_type, cmsg_level, cmsg_data in ancdata: + self.assertEqual(cmsg_type, socket.SOL_SOCKET) + self.assertEqual(cmsg_level, socket.SCM_RIGHTS) + fds.fromstring(cmsg_data[: + len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + self.assertLessEqual(len(fds), 2) + self.checkFDs(fds) + + @testFDPassPartialIntInMiddle.client_skip + def _testFDPassPartialIntInMiddle(self): + fd0, fd1 = self.newFDs(2) + self.sendAncillaryIfPossible( + MSG, + [(socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [fd0, self.badfd]).tostring()[:-1]), + (socket.SOL_SOCKET, + socket.SCM_RIGHTS, + array.array("i", [fd1]))]) + + def checkTruncatedHeader(self, ancbuf): + # Check that no ancillary data items are returned when data is + # truncated inside the cmsghdr structure + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), ancbuf) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.assertEqual(ancdata, []) + self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC) + + def testCmsgTrunc0(self): + self.checkTruncatedHeader(0) + + def _testCmsgTrunc0(self): + self.createAndSendFDs(1) + + def testCmsgTrunc1(self): + self.checkTruncatedHeader(1) + + def _testCmsgTrunc1(self): + self.createAndSendFDs(1) + + def testCmsgTrunc2Int(self): + # The cmsghdr structure has at least three members, two of + # which are ints, so we shouldn't see any ancillary data at + # all here. + self.checkTruncatedHeader(SIZEOF_INT * 2) + + def _testCmsgTrunc2Int(self): + self.createAndSendFDs(1) + + # The following tests try to truncate the control message in the + # middle of the FD array. The cmsghdr structure has three + # members, of which two have type int, so at least three int-sized + # slots will be lost to the header. + + def checkTruncatedArray(self, ancbuf, fds_sent): + msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, + len(MSG), ancbuf) + self.assertEqual(msg, MSG) + self.checkRecvmsgAddress(addr, self.cli_addr) + self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC) + self.assertLessEqual(len(ancdata), 1) + # The OS may simply not return any message if the array doesn't fit + if len(ancdata) == 0: + return + + # If it does, check that the array is indeed truncated + cmsg_type, cmsg_level, cmsg_data = ancdata[0] + self.assertEqual(cmsg_type, socket.SOL_SOCKET) + self.assertEqual(cmsg_level, socket.SCM_RIGHTS) + self.assertIsInstance(cmsg_data, bytes) + fds = array.array("i") + fds.fromstring(cmsg_data[: + len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + self.assertLess(len(fds), fds_sent) + self.checkFDs(fds) + + # These should truncate in the middle of the array provided the + # header occupies no more than 4 ints. + + def testCmsgTrunc4IntPlus1(self): + self.checkTruncatedArray(ancbuf=(SIZEOF_INT * 4) + 1, fds_sent=3) + + def _testCmsgTrunc4IntPlus1(self): + self.createAndSendFDs(3) + + def testCmsgTrunc5Int(self): + self.checkTruncatedArray(ancbuf=SIZEOF_INT * 5, fds_sent=3) + + def _testCmsgTrunc5Int(self): + self.createAndSendFDs(3) + + # These should truncate in the middle of the array provided the + # header occupies no more than 8 ints. + + def testCmsgTrunc8IntPlus1(self): + self.checkTruncatedArray(ancbuf=(SIZEOF_INT * 8) + 1, fds_sent=7) + + def _testCmsgTrunc8IntPlus1(self): + self.createAndSendFDs(7) + + def testCmsgTrunc9Int(self): + self.checkTruncatedArray(ancbuf=SIZEOF_INT * 9, fds_sent=7) + + def _testCmsgTrunc9Int(self): + self.createAndSendFDs(7) + + def testCmsgTruncLen0(self): + self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0), fds_sent=1) + + def _testCmsgTruncLen0(self): + self.createAndSendFDs(1) + + def testCmsgTruncLen0Plus1(self): + self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0) + 1, fds_sent=2) + + def _testCmsgTruncLen0Plus1(self): + self.createAndSendFDs(2) + + def testCmsgTruncLen1(self): + self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(SIZEOF_INT), + fds_sent=2) + + def _testCmsgTruncLen1(self): + self.createAndSendFDs(2) + +@requireAttrs(socket.socket, "sendmsg") +class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgDgramBase, + UDPTestBase, ThreadedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +class RecvmsgUDPTest(RecvmsgTests, RecvmsgGenericTests, SendrecvmsgDgramBase, + UDPTestBase, ThreadedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +class RecvmsgIntoUDPTest(RecvmsgIntoTests, RecvmsgGenericTests, + SendrecvmsgDgramBase, UDPTestBase, + ThreadedSocketTestBase): + pass + +# XXX TODO: test receiving RFC 3542 ancillary data items for IPv6 UDP +# (perhaps sending the simple ones too). + +@requireAttrs(socket.socket, "sendmsg") +@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support") +@requireSocket("AF_INET6", "SOCK_DGRAM") +class SendmsgUDP6Test(SendmsgConnectionlessTests, SendrecvmsgDgramBase, + UDP6TestBase, ThreadedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support") +@requireSocket("AF_INET6", "SOCK_DGRAM") +class RecvmsgUDP6Test(RecvmsgTests, RecvmsgGenericTests, SendrecvmsgDgramBase, + UDP6TestBase, ThreadedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support") +@requireSocket("AF_INET6", "SOCK_DGRAM") +class RecvmsgIntoUDP6Test(RecvmsgIntoTests, RecvmsgGenericTests, + SendrecvmsgDgramBase, UDP6TestBase, + ThreadedSocketTestBase): + pass + +@requireAttrs(socket.socket, "sendmsg") +class SendmsgTCPTest(SendmsgStreamTests, SendrecvmsgConnectedBase, + TCPTestBase, ConnectedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +class RecvmsgTCPTest(RecvmsgTests, RecvmsgGenericStreamTests, + SendrecvmsgConnectedBase, TCPTestBase, + ConnectedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +class RecvmsgIntoTCPTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, + SendrecvmsgConnectedBase, TCPTestBase, + ConnectedSocketTestBase): + pass + +@requireAttrs(socket.socket, "sendmsg") +@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgConnectedBase, + SCTPStreamBase, ConnectedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests, + SendrecvmsgConnectedBase, SCTPStreamBase, + ConnectedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") +class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, + SendrecvmsgConnectedBase, SCTPStreamBase, + ConnectedSocketTestBase): + pass + +@requireAttrs(socket.socket, "sendmsg") +@requireAttrs(socket, "AF_UNIX") +class SendmsgUnixDgramTest(SendmsgConnectionlessTests, SendrecvmsgDgramBase, + UnixDgramBase, ThreadedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@requireAttrs(socket, "AF_UNIX") +class RecvmsgUnixDgramTest(RecvmsgTests, RecvmsgGenericTests, + SendrecvmsgDgramBase, UnixDgramBase, + ThreadedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@requireAttrs(socket, "AF_UNIX") +class RecvmsgIntoUnixDgramTest(RecvmsgIntoTests, RecvmsgGenericTests, + SendrecvmsgDgramBase, UnixDgramBase, + ThreadedSocketTestBase): + pass + +@requireAttrs(socket.socket, "sendmsg") +@requireAttrs(socket, "AF_UNIX") +class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgConnectedBase, + UnixStreamBase, ConnectedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg") +@requireAttrs(socket, "AF_UNIX") +class RecvmsgUnixStreamTest(RecvmsgTests, RecvmsgGenericStreamTests, + SendrecvmsgConnectedBase, UnixStreamBase, + ConnectedSocketTestBase): + pass + +@requireAttrs(socket.socket, "recvmsg_into") +@requireAttrs(socket, "AF_UNIX") +class RecvmsgIntoUnixStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, + SendrecvmsgConnectedBase, UnixStreamBase, + ConnectedSocketTestBase): + pass + +@requireAttrs(socket.socket, "sendmsg", "recvmsg") +@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS") +class RecvmsgSCMRightsStreamTest(RecvmsgMixin, SCMRightsTest, + SendrecvmsgConnectedBase, UnixStreamBase, + ConnectedSocketTestBase): + pass + +@requireAttrs(socket.socket, "sendmsg", "recvmsg_into") +@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS") +class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest, + SendrecvmsgConnectedBase, UnixStreamBase, + ConnectedSocketTestBase): + pass + class TCPCloserTest(ThreadedTCPSocketTest): def testClose(self): @@ -1430,6 +2718,29 @@ NetworkConnectionAttributesTest, NetworkConnectionBehaviourTest, ]) + tests.extend([ + CmsgMacroTests, + SendmsgUDPTest, + RecvmsgUDPTest, + RecvmsgIntoUDPTest, + SendmsgUDP6Test, + RecvmsgUDP6Test, + RecvmsgIntoUDP6Test, + SendmsgTCPTest, + RecvmsgTCPTest, + RecvmsgIntoTCPTest, + SendmsgSCTPStreamTest, + RecvmsgSCTPStreamTest, + RecvmsgIntoSCTPStreamTest, + SendmsgUnixDgramTest, + RecvmsgUnixDgramTest, + RecvmsgIntoUnixDgramTest, + SendmsgUnixStreamTest, + RecvmsgUnixStreamTest, + RecvmsgIntoUnixStreamTest, + RecvmsgSCMRightsStreamTest, + RecvmsgIntoSCMRightsStreamTest, + ]) if hasattr(socket, "socketpair"): tests.append(BasicSocketPairTest) if sys.platform == 'linux2': diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -1078,17 +1078,30 @@ count, addr = s.recvfrom_into(b) return b[:count] + def _recvmsg(*args, **kwargs): + return s.recvmsg(*args, **kwargs)[0] + + def _recvmsg_into(bufsize, *args, **kwargs): + b = bytearray(bufsize) + return bytes(b[:s.recvmsg_into([b], *args, **kwargs)[0]]) + + def _sendmsg(msg, *args, **kwargs): + return s.sendmsg([msg]) + # (name, method, whether to expect success, *args) send_methods = [ ('send', s.send, True, []), ('sendto', s.sendto, False, ["some.address"]), + ('sendmsg', _sendmsg, False, []), ('sendall', s.sendall, True, []), ] recv_methods = [ ('recv', s.recv, True, []), ('recvfrom', s.recvfrom, False, ["some.address"]), + ('recvmsg', _recvmsg, False, [100]), ('recv_into', _recv_into, True, []), ('recvfrom_into', _recvfrom_into, False, []), + ('recvmsg_into', _recvmsg_into, False, [100]), ] data_prefix = "PREFIX_" diff --git a/Modules/socketmodule.c b/Modules/socketmodule.c --- a/Modules/socketmodule.c +++ b/Modules/socketmodule.c @@ -249,6 +249,7 @@ #ifdef HAVE_SYS_TYPES_H #include #endif +#include /* Generic socket object definitions and includes */ #define PySocket_BUILDING_SOCKET @@ -456,6 +457,17 @@ #include #endif +/* Largest value to try to store in a socklen_t (used when handling + ancillary data). POSIX requires socklen_t to hold at least + (2**31)-1 and recommends against storing larger values, but + socklen_t was originally int in the BSD interface, so to be on the + safe side we use the smaller of (2**31)-1 and INT_MAX. */ +#if INT_MAX > 0x7fffffff +#define SOCKLEN_T_LIMIT 0x7fffffff +#else +#define SOCKLEN_T_LIMIT INT_MAX +#endif + #ifdef Py_SOCKET_FD_CAN_BE_GE_FD_SETSIZE /* Platform can select file descriptors beyond FD_SETSIZE */ #define IS_SELECTABLE(s) 1 @@ -1571,6 +1583,100 @@ } +#ifdef CMSG_LEN +/* If length is in range, set *result to CMSG_LEN(length) and return + true; otherwise, return false. */ +static int +get_CMSG_LEN(size_t length, size_t *result) +{ + size_t tmp; + + if (length > (SOCKLEN_T_LIMIT - CMSG_LEN(0))) + return 0; + tmp = CMSG_LEN(length); + if (tmp > SOCKLEN_T_LIMIT || tmp < length) + return 0; + *result = tmp; + return 1; +} + +#ifdef CMSG_SPACE +/* If length is in range, set *result to CMSG_SPACE(length) and return + true; otherwise, return false. */ +static int +get_CMSG_SPACE(size_t length, size_t *result) +{ + size_t tmp; + + /* Use CMSG_SPACE(1) here in order to take account of the + padding necessary before *and* after the data. */ + if (length > (SOCKLEN_T_LIMIT - CMSG_SPACE(1))) + return 0; + tmp = CMSG_SPACE(length); + if (tmp > SOCKLEN_T_LIMIT || tmp < length) + return 0; + *result = tmp; + return 1; +} +#endif + +/* Return true iff cmsgh's cmsg_len member is wholly contained in the + buffer msg->msg_control. */ +static int +cmsg_len_in_buffer(struct msghdr *msg, struct cmsghdr *cmsgh) +{ + size_t cmsg_offset; + static const size_t cmsg_len_end = + offsetof(struct cmsghdr, cmsg_len) + sizeof(cmsgh->cmsg_len); + + if (cmsgh == NULL || msg->msg_controllen < 0) + return 0; + cmsg_offset = (char *)cmsgh - (char *)msg->msg_control; + return (cmsg_offset <= (size_t)-1 - cmsg_len_end && + cmsg_offset + cmsg_len_end <= msg->msg_controllen); +} + +/* Return true iff cmsgh (which must have an initialized cmsg_len + member) points to a valid control message with space for data_len + bytes in CMSG_DATA(cmsgh). */ +static int +cmsg_valid_with_space(struct msghdr *msg, struct cmsghdr *cmsgh, + size_t data_len) +{ + size_t cmsg_offset, data_offset; + char *data_ptr; + + cmsg_offset = (char *)cmsgh - (char *)msg->msg_control; + if (cmsgh->cmsg_len < CMSG_LEN(0) || + cmsg_offset > (size_t)-1 - cmsgh->cmsg_len || + cmsg_offset + cmsgh->cmsg_len > msg->msg_controllen) + return 0; + if ((data_ptr = (char *)CMSG_DATA(cmsgh)) == NULL) + return 0; + data_offset = data_ptr - (char *)msg->msg_control; + return (data_offset <= (size_t)-1 - data_len && + data_offset + data_len <= msg->msg_controllen); +} + +/* If cmsgh points to a valid control message in msg->msg_control, + store the length of its associated data in *data_len and return + true; otherwise return false. */ +static int +get_cmsg_data_len(struct msghdr *msg, struct cmsghdr *cmsgh, size_t *data_len) +{ + size_t result; + + if (!cmsg_len_in_buffer(msg, cmsgh) || cmsgh->cmsg_len < CMSG_LEN(0)) + return 0; + result = cmsgh->cmsg_len - CMSG_LEN(0); + if (!cmsg_valid_with_space(msg, cmsgh, result)) + return 0; + *data_len = result; + return 1; +} +#endif /* CMSG_LEN */ + + /* s._accept() -> (fd, address) */ static PyObject * @@ -2498,6 +2604,320 @@ Like recv_into(buffer[, nbytes[, flags]]) but also return the sender's address info."); +#ifdef CMSG_LEN +/* + * Call recvmsg() with the supplied iovec structures, flags, and + * ancillary data buffer size (controllen). Returns the tuple return + * value for recvmsg() or recvmsg_into(), with the first item provided + * by the supplied makeval() function. makeval() will be called with + * the length read and makeval_data as arguments, and must return a + * new reference (which will be decrefed if there is a subsequent + * error). On error, closes any file descriptors received via + * SCM_RIGHTS. + */ +static PyObject * +sock_recvmsg_guts(PySocketSockObject *s, struct iovec *iov, int iovlen, + int flags, Py_ssize_t controllen, + PyObject *(*makeval)(ssize_t, void *), void *makeval_data) +{ + ssize_t bytes_received; + int timeout; + sock_addr_t addrbuf; + static const struct msghdr msg_blank; + struct msghdr msg; + PyObject *cmsg_list = NULL, *retval = NULL; + struct cmsghdr *cmsgh; + size_t cmsgdatalen; + + msg = msg_blank; /* Set all members to 0 or NULL */ + msg.msg_iov = iov; + msg.msg_iovlen = iovlen; + + if (controllen < 0 || controllen > SOCKLEN_T_LIMIT) { + PyErr_SetString(PyExc_ValueError, + "invalid ancillary data buffer length"); + return NULL; + } + + msg.msg_controllen = controllen; + if (msg.msg_controllen > 0 && + (msg.msg_control = PyMem_Malloc(msg.msg_controllen)) == NULL) + return PyErr_NoMemory(); + + /* XXX: POSIX says that msg_name and msg_namelen "shall be + ignored" when the socket is connected (Linux fills them in + anyway for AF_UNIX sockets at least). Normally msg_namelen + seems to be set to 0 if there's no address, but try to + initialize msg_name to something that won't be mistaken for + a real address if that doesn't happen. */ + memset(&addrbuf, 0, sizeof(addrbuf)); + SAS2SA(&addrbuf)->sa_family = AF_UNSPEC; + msg.msg_name = &addrbuf; + msg.msg_namelen = sizeof(addrbuf); + + /* Make the system call. */ + if (!IS_SELECTABLE(s)) { + select_error(); + goto finally; + } + Py_BEGIN_ALLOW_THREADS; + timeout = internal_select(s, 0); + if (!timeout) + bytes_received = recvmsg(s->sock_fd, &msg, flags); + Py_END_ALLOW_THREADS; + + if (timeout) { + PyErr_SetString(socket_timeout, "timed out"); + goto finally; + } + else if (bytes_received < 0) { + s->errorhandler(); + goto finally; + } + + /* Make list of (level, type, data) tuples from control messages. */ + if ((cmsg_list = PyList_New(0)) == NULL) + goto err_closefds; + for (cmsgh = CMSG_FIRSTHDR(&msg); cmsgh != NULL; + cmsgh = CMSG_NXTHDR(&msg, cmsgh)) { + PyObject *bytes, *tuple; + int tmp; + + if (!get_cmsg_data_len(&msg, cmsgh, &cmsgdatalen)) { + PyErr_SetString(PyExc_RuntimeError, + "received malformed or improperly-" + "truncated ancillary data"); + goto err_closefds; + } + if (cmsgdatalen > PY_SSIZE_T_MAX) { + PyErr_SetString(socket_error, + "control message too long"); + goto err_closefds; + } + + bytes = PyBytes_FromStringAndSize((char *)CMSG_DATA(cmsgh), + cmsgdatalen); + tuple = Py_BuildValue("iiN", (int)cmsgh->cmsg_level, + (int)cmsgh->cmsg_type, bytes); + if (tuple == NULL) + goto err_closefds; + tmp = PyList_Append(cmsg_list, tuple); + Py_DECREF(tuple); + if (tmp != 0) + goto err_closefds; + } + + retval = Py_BuildValue("NOiN", + (*makeval)(bytes_received, makeval_data), + cmsg_list, + (int)msg.msg_flags, + makesockaddr( + s->sock_fd, SAS2SA(&addrbuf), + ((msg.msg_namelen > sizeof(addrbuf)) ? + sizeof(addrbuf) : msg.msg_namelen), + s->sock_proto)); + if (retval == NULL) + goto err_closefds; + +finally: + Py_XDECREF(cmsg_list); + PyMem_Free(msg.msg_control); + return retval; + +err_closefds: +#ifdef SCM_RIGHTS + /* Close all descriptors coming from SCM_RIGHTS, so they don't leak. */ + for (cmsgh = CMSG_FIRSTHDR(&msg); cmsgh != NULL; + cmsgh = CMSG_NXTHDR(&msg, cmsgh)) { + if (!get_cmsg_data_len(&msg, cmsgh, &cmsgdatalen)) + break; + if (cmsgh->cmsg_level == SOL_SOCKET && + cmsgh->cmsg_type == SCM_RIGHTS) { + size_t numfds; + int *fdp; + + numfds = cmsgdatalen / sizeof(int); + fdp = (int *)CMSG_DATA(cmsgh); + while (numfds-- > 0) + close(*fdp++); + } + } +#endif /* SCM_RIGHTS */ + goto finally; +} + + +static PyObject * +makeval_recvmsg(ssize_t received, void *data) +{ + PyObject **buf = data; + + if (received < PyBytes_GET_SIZE(*buf)) + _PyBytes_Resize(buf, received); + Py_XINCREF(*buf); + return *buf; +} + +/* s.recvmsg(bufsize[, ancbufsize[, flags]]) method */ + +static PyObject * +sock_recvmsg(PySocketSockObject *s, PyObject *args) +{ + Py_ssize_t bufsize, ancbufsize = 0; + int flags = 0; + struct iovec iov; + PyObject *buf = NULL, *retval = NULL; + + if (!PyArg_ParseTuple(args, "n|ni:recvmsg", + &bufsize, &ancbufsize, &flags)) + return NULL; + + if (bufsize < 0) { + PyErr_SetString(PyExc_ValueError, + "negative buffer size in recvmsg()"); + return NULL; + } + if ((buf = PyBytes_FromStringAndSize(NULL, bufsize)) == NULL) + return NULL; + iov.iov_base = PyBytes_AS_STRING(buf); + iov.iov_len = bufsize; + + /* Note that we're passing a pointer to *our pointer* to the + bytes object here (&buf); makeval_recvmsg() may incref the + object, or deallocate it and set our pointer to NULL. */ + retval = sock_recvmsg_guts(s, &iov, 1, flags, ancbufsize, + &makeval_recvmsg, &buf); + Py_XDECREF(buf); + return retval; +} + +PyDoc_STRVAR(recvmsg_doc, +"recvmsg(bufsize[, ancbufsize[, flags]]) -> (data, ancdata, msg_flags, address)\n\ +\n\ +Receive normal data (up to bufsize bytes) and ancillary data from the\n\ +socket. The ancbufsize argument sets the size in bytes of the\n\ +internal buffer used to receive the ancillary data; it defaults to 0,\n\ +meaning that no ancillary data will be received. Appropriate buffer\n\ +sizes for ancillary data can be calculated using CMSG_LEN() or\n\ +CMSG_SPACE(), and items which do not fit into the buffer might be\n\ +truncated or discarded. The flags argument defaults to 0 and has the\n\ +same meaning as for recv().\n\ +\n\ +The return value is a 4-tuple: (data, ancdata, msg_flags, address).\n\ +The data item is a bytes object holding the non-ancillary data\n\ +received. The ancdata item is a list of zero or more tuples\n\ +(cmsg_level, cmsg_type, cmsg_data) representing the ancillary data\n\ +(control messages) received: cmsg_level and cmsg_type are integers\n\ +specifying the protocol level and protocol-specific type respectively,\n\ +and cmsg_data is a bytes object holding the associated data. The\n\ +msg_flags item is the bitwise OR of various flags indicating\n\ +conditions on the received message; see your system documentation for\n\ +details. If the receiving socket is unconnected, address is the\n\ +address of the sending socket, if available; otherwise, its value is\n\ +unspecified.\n\ +\n\ +If recvmsg() encounters an error after the system call returns, it\n\ +will attempt to close any file descriptors received via the SCM_RIGHTS\n\ +mechanism before raising an exception."); + + +static PyObject * +makeval_recvmsg_into(ssize_t received, void *data) +{ + return PyLong_FromSsize_t(received); +} + +/* s.recvmsg_into(buffers[, ancbufsize[, flags]]) method */ + +static PyObject * +sock_recvmsg_into(PySocketSockObject *s, PyObject *args) +{ + Py_ssize_t ancbufsize = 0; + int flags = 0; + struct iovec *iovs = NULL; + Py_ssize_t i, nitems, nbufs = 0; + Py_buffer *bufs = NULL; + PyObject *buffers_arg, *fast, *retval = NULL; + + if (!PyArg_ParseTuple(args, "O|ni:recvmsg_into", + &buffers_arg, &ancbufsize, &flags)) + return NULL; + + if ((fast = PySequence_Fast(buffers_arg, "recvmsg_into() argument 1 " + "must be an iterable")) == NULL) + return NULL; + nitems = PySequence_Fast_GET_SIZE(fast); + if (nitems > INT_MAX) { + PyErr_SetString(socket_error, + "recvmsg_into() argument 1 is too long"); + goto finally; + } + + /* Fill in an iovec for each item, and save the Py_buffer + structs to release afterwards. */ + if (nitems > 0 && + ((iovs = PyMem_New(struct iovec, nitems)) == NULL || + (bufs = PyMem_New(Py_buffer, nitems)) == NULL)) { + PyErr_NoMemory(); + goto finally; + } + for (; nbufs < nitems; nbufs++) { + if (!PyArg_Parse(PySequence_Fast_GET_ITEM(fast, nbufs), + "w*;recvmsg_into() argument 1 must be an " + "iterable of single-segment read-write " + "buffers", + &bufs[nbufs])) + goto finally; + iovs[nbufs].iov_base = bufs[nbufs].buf; + iovs[nbufs].iov_len = bufs[nbufs].len; + } + + retval = sock_recvmsg_guts(s, iovs, nitems, flags, ancbufsize, + &makeval_recvmsg_into, NULL); +finally: + for (i = 0; i < nbufs; i++) + PyBuffer_Release(&bufs[i]); + PyMem_Free(bufs); + PyMem_Free(iovs); + Py_DECREF(fast); + return retval; +} + +PyDoc_STRVAR(recvmsg_into_doc, +"recvmsg_into(buffers[, ancbufsize[, flags]]) -> (nbytes, ancdata, msg_flags, address)\n\ +\n\ +Receive normal data and ancillary data from the socket, scattering the\n\ +non-ancillary data into a series of buffers. The buffers argument\n\ +must be an iterable of objects that export writable buffers\n\ +(e.g. bytearray objects); these will be filled with successive chunks\n\ +of the non-ancillary data until it has all been written or there are\n\ +no more buffers. The ancbufsize argument sets the size in bytes of\n\ +the internal buffer used to receive the ancillary data; it defaults to\n\ +0, meaning that no ancillary data will be received. Appropriate\n\ +buffer sizes for ancillary data can be calculated using CMSG_LEN() or\n\ +CMSG_SPACE(), and items which do not fit into the buffer might be\n\ +truncated or discarded. The flags argument defaults to 0 and has the\n\ +same meaning as for recv().\n\ +\n\ +The return value is a 4-tuple: (nbytes, ancdata, msg_flags, address).\n\ +The nbytes item is the total number of bytes of non-ancillary data\n\ +written into the buffers. The ancdata item is a list of zero or more\n\ +tuples (cmsg_level, cmsg_type, cmsg_data) representing the ancillary\n\ +data (control messages) received: cmsg_level and cmsg_type are\n\ +integers specifying the protocol level and protocol-specific type\n\ +respectively, and cmsg_data is a bytes object holding the associated\n\ +data. The msg_flags item is the bitwise OR of various flags\n\ +indicating conditions on the received message; see your system\n\ +documentation for details. If the receiving socket is unconnected,\n\ +address is the address of the sending socket, if available; otherwise,\n\ +its value is unspecified.\n\ +\n\ +If recvmsg_into() encounters an error after the system call returns,\n\ +it will attempt to close any file descriptors received via the\n\ +SCM_RIGHTS mechanism before raising an exception."); +#endif /* CMSG_LEN */ + + /* s.send(data [,flags]) method */ static PyObject * @@ -2672,6 +3092,237 @@ For IP sockets, the address is a pair (hostaddr, port)."); +#ifdef CMSG_LEN +/* s.sendmsg(buffers[, ancdata[, flags[, address]]]) method */ + +static PyObject * +sock_sendmsg(PySocketSockObject *s, PyObject *args) +{ + Py_ssize_t i, ndataparts, ndatabufs = 0, ncmsgs, ncmsgbufs = 0; + Py_buffer *databufs = NULL; + sock_addr_t addrbuf; + static const struct msghdr msg_blank; + struct msghdr msg; + struct cmsginfo { + int level; + int type; + Py_buffer data; + } *cmsgs = NULL; + size_t controllen, controllen_last; + ssize_t bytes_sent; + int addrlen, timeout, flags = 0; + PyObject *data_arg, *cmsg_arg = NULL, *addr_arg = NULL, + *data_fast = NULL, *cmsg_fast = NULL, *retval = NULL; + + if (!PyArg_ParseTuple(args, "O|OiO:sendmsg", &data_arg, &cmsg_arg, + &flags, &addr_arg)) + return NULL; + + msg = msg_blank; /* Set all members to 0 or NULL */ + + /* Fill in an iovec for each message part, and save the + Py_buffer structs to release afterwards. */ + if ((data_fast = PySequence_Fast(data_arg, "sendmsg() argument 1 must " + "be an iterable")) == NULL) + goto finally; + ndataparts = PySequence_Fast_GET_SIZE(data_fast); + if (ndataparts > INT_MAX) { + PyErr_SetString(socket_error, + "sendmsg() argument 1 is too long"); + goto finally; + } + msg.msg_iovlen = ndataparts; + if (ndataparts > 0 && + ((msg.msg_iov = PyMem_New(struct iovec, ndataparts)) == NULL || + (databufs = PyMem_New(Py_buffer, ndataparts)) == NULL)) { + PyErr_NoMemory(); + goto finally; + } + for (; ndatabufs < ndataparts; ndatabufs++) { + if (!PyArg_Parse( + PySequence_Fast_GET_ITEM(data_fast, ndatabufs), + "y*;sendmsg() argument 1 must be an iterable of " + "buffer-compatible objects", + &databufs[ndatabufs])) + goto finally; + msg.msg_iov[ndatabufs].iov_base = databufs[ndatabufs].buf; + msg.msg_iov[ndatabufs].iov_len = databufs[ndatabufs].len; + } + + if (cmsg_arg == NULL) + ncmsgs = 0; + else { + if ((cmsg_fast = PySequence_Fast(cmsg_arg, + "sendmsg() argument 2 must " + "be an iterable")) == NULL) + goto finally; + ncmsgs = PySequence_Fast_GET_SIZE(cmsg_fast); + } + +#ifndef CMSG_SPACE + if (ncmsgs > 1) { + PyErr_SetString(socket_error, + "sending multiple control messages is not " + "supported on this system"); + goto finally; + } +#endif + /* Save level, type and Py_buffer for each control message, + and calculate total size. */ + if (ncmsgs > 0 && + (cmsgs = PyMem_New(struct cmsginfo, ncmsgs)) == NULL) { + PyErr_NoMemory(); + goto finally; + } + controllen = controllen_last = 0; + while (ncmsgbufs < ncmsgs) { + size_t bufsize, space; + + if (!PyArg_Parse( + PySequence_Fast_GET_ITEM(cmsg_fast, ncmsgbufs), + "(iiy*):[sendmsg() ancillary data items]", + &cmsgs[ncmsgbufs].level, + &cmsgs[ncmsgbufs].type, + &cmsgs[ncmsgbufs].data)) + goto finally; + bufsize = cmsgs[ncmsgbufs++].data.len; + +#ifdef CMSG_SPACE + if (!get_CMSG_SPACE(bufsize, &space)) { +#else + if (!get_CMSG_LEN(bufsize, &space)) { +#endif + PyErr_SetString(socket_error, + "ancillary data item too large"); + goto finally; + } + controllen += space; + if (controllen > SOCKLEN_T_LIMIT || + controllen < controllen_last) { + PyErr_SetString(socket_error, + "too much ancillary data"); + goto finally; + } + controllen_last = controllen; + } + + /* Construct ancillary data block from control message info. */ + msg.msg_controllen = controllen; + if (controllen > 0) { + struct cmsghdr *cmsgh = NULL; + + if ((msg.msg_control = PyMem_Malloc(controllen)) == NULL) { + PyErr_NoMemory(); + goto finally; + } + /* Need to zero out the buffer as a workaround for + glibc's CMSG_NXTHDR() implementation. After + getting the pointer to the next header, it checks + its (uninitialized) cmsg_len member to see if the + "message" fits in the buffer, and returns NULL if + it doesn't. Zero-filling the buffer ensures that + that doesn't happen. */ + memset(msg.msg_control, 0, controllen); + + for (i = 0; i < ncmsgbufs; i++) { + size_t msg_len, data_len = cmsgs[i].data.len; + int enough_space = 0; + + cmsgh = (i == 0) ? + CMSG_FIRSTHDR(&msg) : CMSG_NXTHDR(&msg, cmsgh); + if (cmsgh == NULL) { + PyErr_Format(PyExc_RuntimeError, + "unexpected NULL result from " + "%s()", (i == 0) ? + "CMSG_FIRSTHDR" : "CMSG_NXTHDR"); + goto finally; + } + if (!get_CMSG_LEN(data_len, &msg_len)) { + PyErr_SetString(PyExc_RuntimeError, + "item size out of range for " + "CMSG_LEN()"); + goto finally; + } + if (cmsg_len_in_buffer(&msg, cmsgh)) { + cmsgh->cmsg_len = msg_len; + enough_space = cmsg_valid_with_space(&msg, + cmsgh, + data_len); + } + if (!enough_space) { + PyErr_SetString(PyExc_RuntimeError, + "ancillary data does not fit " + "in calculated space"); + goto finally; + } + cmsgh->cmsg_level = cmsgs[i].level; + cmsgh->cmsg_type = cmsgs[i].type; + memcpy(CMSG_DATA(cmsgh), cmsgs[i].data.buf, data_len); + } + } + + /* Parse destination address. */ + if (addr_arg != NULL && addr_arg != Py_None) { + if (!getsockaddrarg(s, addr_arg, SAS2SA(&addrbuf), &addrlen)) + goto finally; + msg.msg_name = &addrbuf; + msg.msg_namelen = addrlen; + } + + /* Make the system call. */ + if (!IS_SELECTABLE(s)) { + select_error(); + goto finally; + } + Py_BEGIN_ALLOW_THREADS; + timeout = internal_select(s, 1); + if (!timeout) + bytes_sent = sendmsg(s->sock_fd, &msg, flags); + Py_END_ALLOW_THREADS; + + if (timeout) { + PyErr_SetString(socket_timeout, "timed out"); + goto finally; + } + else if (bytes_sent < 0) { + s->errorhandler(); + goto finally; + } + retval = PyLong_FromSsize_t(bytes_sent); + +finally: + PyMem_Free(msg.msg_control); + for (i = 0; i < ncmsgbufs; i++) + PyBuffer_Release(&cmsgs[i].data); + PyMem_Free(cmsgs); + Py_XDECREF(cmsg_fast); + for (i = 0; i < ndatabufs; i++) + PyBuffer_Release(&databufs[i]); + PyMem_Free(databufs); + PyMem_Free(msg.msg_iov); + Py_XDECREF(data_fast); + return retval; +} + +PyDoc_STRVAR(sendmsg_doc, +"sendmsg(buffers[, ancdata[, flags[, address]]]) -> count\n\ +\n\ +Send normal and ancillary data to the socket, gathering the\n\ +non-ancillary data from a series of buffers and concatenating it into\n\ +a single message. The buffers argument specifies the non-ancillary\n\ +data as an iterable of buffer-compatible objects (e.g. bytes objects).\n\ +The ancdata argument specifies the ancillary data (control messages)\n\ +as an iterable of zero or more tuples (cmsg_level, cmsg_type,\n\ +cmsg_data), where cmsg_level and cmsg_type are integers specifying the\n\ +protocol level and protocol-specific type respectively, and cmsg_data\n\ +is a buffer-compatible object holding the associated data. The flags\n\ +argument defaults to 0 and has the same meaning as for send(). If\n\ +address is supplied and not None, it sets a destination address for\n\ +the message. The return value is the number of bytes of non-ancillary\n\ +data sent."); +#endif /* CMSG_LEN */ + + /* s.shutdown(how) method */ static PyObject * @@ -2796,6 +3447,14 @@ setsockopt_doc}, {"shutdown", (PyCFunction)sock_shutdown, METH_O, shutdown_doc}, +#ifdef CMSG_LEN + {"recvmsg", (PyCFunction)sock_recvmsg, METH_VARARGS, + recvmsg_doc}, + {"recvmsg_into", (PyCFunction)sock_recvmsg_into, METH_VARARGS, + recvmsg_into_doc,}, + {"sendmsg", (PyCFunction)sock_sendmsg, METH_VARARGS, + sendmsg_doc}, +#endif {NULL, NULL} /* sentinel */ }; @@ -4044,6 +4703,70 @@ When the socket module is first imported, the default is None."); +#ifdef CMSG_LEN +/* Python interface to CMSG_LEN(length). */ + +static PyObject * +socket_CMSG_LEN(PyObject *self, PyObject *args) +{ + Py_ssize_t length; + size_t result; + + if (!PyArg_ParseTuple(args, "n:CMSG_LEN", &length)) + return NULL; + if (length < 0 || !get_CMSG_LEN(length, &result)) { + PyErr_Format(PyExc_OverflowError, + "CMSG_LEN() argument out of range"); + return NULL; + } + return PyLong_FromSize_t(result); +} + +PyDoc_STRVAR(CMSG_LEN_doc, +"CMSG_LEN(length) -> control message length\n\ +\n\ +Return the value which would be stored in the cmsg_len field of a\n\ +control message with associated data of the given length. This value\n\ +can often be used with recvmsg() as the ancillary data buffer size to\n\ +receive a single control message without trailing padding, but RFC\n\ +3542 requires portable applications to use CMSG_SPACE() and include\n\ +space for padding, even when the message is the last in the buffer.\n\ +Raises OverflowError if length is outside the permissible range of\n\ +values."); + + +#ifdef CMSG_SPACE +/* Python interface to CMSG_SPACE(length). */ + +static PyObject * +socket_CMSG_SPACE(PyObject *self, PyObject *args) +{ + Py_ssize_t length; + size_t result; + + if (!PyArg_ParseTuple(args, "n:CMSG_SPACE", &length)) + return NULL; + if (length < 0 || !get_CMSG_SPACE(length, &result)) { + PyErr_SetString(PyExc_OverflowError, + "CMSG_SPACE() argument out of range"); + return NULL; + } + return PyLong_FromSize_t(result); +} + +PyDoc_STRVAR(CMSG_SPACE_doc, +"CMSG_SPACE(length) -> buffer size\n\ +\n\ +Return the number of bytes needed to hold a control message with\n\ +associated data of the given length, along with any trailing padding\n\ +needed. RFC 3542 specifies the ancillary data buffer size needed to\n\ +receive a set of control messages as the sum of the CMSG_SPACE()\n\ +values for the associated data items. Raises OverflowError if length\n\ +is outside the permissible range of values."); +#endif /* CMSG_SPACE */ +#endif /* CMSG_LEN */ + + /* List of functions exported by this module. */ static PyMethodDef socket_methods[] = { @@ -4095,6 +4818,14 @@ METH_NOARGS, getdefaulttimeout_doc}, {"setdefaulttimeout", socket_setdefaulttimeout, METH_O, setdefaulttimeout_doc}, +#ifdef CMSG_LEN + {"CMSG_LEN", socket_CMSG_LEN, + METH_VARARGS, CMSG_LEN_doc}, +#ifdef CMSG_SPACE + {"CMSG_SPACE", socket_CMSG_SPACE, + METH_VARARGS, CMSG_SPACE_doc}, +#endif +#endif {NULL, NULL} /* Sentinel */ }; @@ -4551,6 +5282,15 @@ #ifdef SO_TYPE PyModule_AddIntConstant(m, "SO_TYPE", SO_TYPE); #endif +#ifdef SO_PASSCRED + PyModule_AddIntConstant(m, "SO_PASSCRED", SO_PASSCRED); +#endif +#ifdef SO_PEERCRED + PyModule_AddIntConstant(m, "SO_PEERCRED", SO_PEERCRED); +#endif +#ifdef LOCAL_PEERCRED + PyModule_AddIntConstant(m, "LOCAL_PEERCRED", LOCAL_PEERCRED); +#endif /* Maximum number of connections for "listen" */ #ifdef SOMAXCONN @@ -4559,6 +5299,17 @@ PyModule_AddIntConstant(m, "SOMAXCONN", 5); /* Common value */ #endif + /* Ancilliary message types */ +#ifdef SCM_RIGHTS + PyModule_AddIntConstant(m, "SCM_RIGHTS", SCM_RIGHTS); +#endif +#ifdef SCM_CREDENTIALS + PyModule_AddIntConstant(m, "SCM_CREDENTIALS", SCM_CREDENTIALS); +#endif +#ifdef SCM_CREDS + PyModule_AddIntConstant(m, "SCM_CREDS", SCM_CREDS); +#endif + /* Flags for send, recv */ #ifdef MSG_OOB PyModule_AddIntConstant(m, "MSG_OOB", MSG_OOB); @@ -4590,6 +5341,33 @@ #ifdef MSG_ETAG PyModule_AddIntConstant(m, "MSG_ETAG", MSG_ETAG); #endif +#ifdef MSG_NOSIGNAL + PyModule_AddIntConstant(m, "MSG_NOSIGNAL", MSG_NOSIGNAL); +#endif +#ifdef MSG_NOTIFICATION + PyModule_AddIntConstant(m, "MSG_NOTIFICATION", MSG_NOTIFICATION); +#endif +#ifdef MSG_CMSG_CLOEXEC + PyModule_AddIntConstant(m, "MSG_CMSG_CLOEXEC", MSG_CMSG_CLOEXEC); +#endif +#ifdef MSG_ERRQUEUE + PyModule_AddIntConstant(m, "MSG_ERRQUEUE", MSG_ERRQUEUE); +#endif +#ifdef MSG_CONFIRM + PyModule_AddIntConstant(m, "MSG_CONFIRM", MSG_CONFIRM); +#endif +#ifdef MSG_MORE + PyModule_AddIntConstant(m, "MSG_MORE", MSG_MORE); +#endif +#ifdef MSG_EOF + PyModule_AddIntConstant(m, "MSG_EOF", MSG_EOF); +#endif +#ifdef MSG_BCAST + PyModule_AddIntConstant(m, "MSG_BCAST", MSG_BCAST); +#endif +#ifdef MSG_MCAST + PyModule_AddIntConstant(m, "MSG_MCAST", MSG_MCAST); +#endif /* Protocol level and numbers, usable for [gs]etsockopt */ #ifdef SOL_SOCKET @@ -4729,6 +5507,9 @@ #ifdef IPPROTO_VRRP PyModule_AddIntConstant(m, "IPPROTO_VRRP", IPPROTO_VRRP); #endif +#ifdef IPPROTO_SCTP + PyModule_AddIntConstant(m, "IPPROTO_SCTP", IPPROTO_SCTP); +#endif #ifdef IPPROTO_BIP PyModule_AddIntConstant(m, "IPPROTO_BIP", IPPROTO_BIP); #endif