diff --git a/Lib/socket.py b/Lib/socket.py
index ac2e3dd..01009aa4 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -450,6 +450,29 @@ def fromfd(fd, family, type, proto=0):
     nfd = dup(fd)
     return socket(family, type, proto, nfd)
 
+if hasattr(_socket, "fdtype"):
+    def fdtype(fd):
+        """fdtype(fd) -> (family, type, proto)
+
+        Return (family, type, proto) for a socket given a file descriptor.
+        Raises ValueError if fd is not a socket, OSError if a system call fails.
+        """
+        family, type, proto = _socket.fdtype(fd)
+        return (_intenum_converter(family, AddressFamily),
+                _intenum_converter(type, SocketKind),
+                proto)
+
+    def fromfd2(fd):
+        """fromfd2(fd) -> socket object
+
+        Create a socket object from the given file descriptor. Unlike fromfd,
+        the descriptor is not duplicated. The family, type and protocol of
+        the socket are determined using fdtype(). Raises ValueError if the
+        file descriptor is not a socket.
+        """
+        family, type, proto = _socket.fdtype(fd)
+        return socket(family, type, proto, fd)
+
 if hasattr(_socket.socket, "share"):
     def fromshare(info):
         """ fromshare(info) -> socket object
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 1ddd604..d06faaf 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -4938,6 +4938,40 @@ class NonblockConstantTest(unittest.TestCase):
         socket.setdefaulttimeout(t)
 
 
+@unittest.skipUnless(hasattr(socket, 'fdtype'), 'test needs socket.fdtype')
+class FdTypeTests(unittest.TestCase):
+    TYPES = [
+        (socket.AF_INET, socket.SOCK_STREAM),
+        (socket.AF_INET, socket.SOCK_DGRAM),
+    ]
+    if hasattr(socket, 'AF_UNIX'):
+        TYPES.extend([
+            (socket.AF_UNIX, socket.SOCK_STREAM),
+            (socket.AF_UNIX, socket.SOCK_DGRAM),
+        ])
+
+    def test_fdtype(self):
+        for family, kind in self.TYPES:
+            s = socket.socket(family, kind)
+            with s:
+                self.assertEqual(socket.fdtype(s.fileno()),
+                                 (family, kind, 0))
+
+    def test_fromfd2(self):
+        for family, kind in self.TYPES:
+            s = socket.socket(family, kind)
+            with s:
+                s2 = socket.fromfd2(s.fileno())
+                try:
+                    self.assertEqual(s.family, s2.family)
+                    self.assertEqual(s.type, s2.type)
+                    self.assertEqual(s.fileno(), s2.fileno())
+                finally:
+                    # fromfd2() does not dup the descriptor, so s and s2
+                    # share it; detach s2 so that only s closes it.
+                    s2.detach()
+
+
 @unittest.skipUnless(os.name == "nt", "Windows specific")
 @unittest.skipUnless(multiprocessing, "need multiprocessing")
 class TestSocketSharing(SocketTCPTest):
diff --git a/Modules/socketmodule.c b/Modules/socketmodule.c
index dc57810..158484e 100644
--- a/Modules/socketmodule.c
+++ b/Modules/socketmodule.c
@@ -447,6 +447,18 @@ const char *inet_ntop(int af, const void *src, char *dst, socklen_t size);
 #define INADDR_NONE (-1)
 #endif
 
+union sockaddr_union {
+    struct sockaddr sa;
+    struct sockaddr_in in4;
+#ifdef ENABLE_IPV6
+    struct sockaddr_in6 in6;
+#endif
+#if defined(AF_UNIX)
+    struct sockaddr_un un;
+#endif
+    struct sockaddr_storage storage;
+};
+
 /* XXX There's a problem here: *static* functions are not supposed to have
    a Py prefix (or use CapitalizedWords).  Later... */
 
@@ -5136,6 +5148,68 @@ AF_UNIX if defined on the platform; otherwise, the default is AF_INET.");
 #endif /* HAVE_SOCKETPAIR */
 
 
+/* socket.fdtype() function */
+
+#if defined(S_ISSOCK) && defined(SO_TYPE)
+#define HAVE_FDTYPE /* have enough to implement fdtype() */
+#endif
+
+#ifdef HAVE_FDTYPE
+/* Return a (family, type, protocol) tuple for the socket behind the file
+   descriptor fdobj.  Sets ValueError if the descriptor is not a socket and
+   reports failing system calls through set_error(). */
+static PyObject *
+socket_fdtype(PyObject *self, PyObject *fdobj)
+{
+    SOCKET_T fd;
+    struct stat st_fd;
+    int sock_type;
+    union sockaddr_union sockaddr;
+    socklen_t l;
+    int protocol;
+
+    fd = PyLong_AsSocket_t(fdobj);
+    if (fd == (SOCKET_T)(-1) && PyErr_Occurred())
+        return NULL;
+
+    if (fstat(fd, &st_fd) < 0) {
+        return set_error();
+    }
+
+    if (!S_ISSOCK(st_fd.st_mode)) {
+        PyErr_SetString(PyExc_ValueError, "fdtype: "
+                        "file descriptor is not a socket");
+        return NULL;
+    }
+
+    l = sizeof(sock_type);
+    if (getsockopt(fd, SOL_SOCKET, SO_TYPE, &sock_type, &l) < 0) {
+        return set_error();
+    }
+
+    /* Clear the buffer with memset(); an empty "{}" initializer is not
+       valid C before C23. */
+    memset(&sockaddr, 0, sizeof(sockaddr));
+    l = sizeof(sockaddr);
+    if (getsockname(fd, &sockaddr.sa, &l) < 0) {
+        return set_error();
+    }
+    /* FIXME: assume zero, how to find protocol? */
+    protocol = 0;
+    return Py_BuildValue("iii",
+                         sockaddr.sa.sa_family,
+                         sock_type,
+                         protocol);
+}
+
+PyDoc_STRVAR(fdtype_doc,
+"fdtype(integer) -> (family, type, protocol)\n\
+\n\
+Return the family, type and protocol for a socket given its file descriptor.\
+");
+#endif /* HAVE_FDTYPE */
+
+
 static PyObject *
 socket_ntohs(PyObject *self, PyObject *args)
 {
@@ -6022,6 +6096,10 @@ static PyMethodDef socket_methods[] = {
     {"socketpair",              socket_socketpair,
      METH_VARARGS, socketpair_doc},
 #endif
+#ifdef HAVE_FDTYPE
+    {"fdtype",                  socket_fdtype,
+     METH_O, fdtype_doc},
+#endif
     {"ntohs",                   socket_ntohs,
      METH_VARARGS, ntohs_doc},
     {"ntohl",                   socket_ntohl,