diff --git a/Lib/ssl.py b/Lib/ssl.py index 8ad4a33..27842b3 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -749,9 +749,9 @@ class SSLSocket(socket): self.ssl_version = ssl_version self.ca_certs = ca_certs self.ciphers = ciphers - # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get - # mixed in. - if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM: + if sock is not None: + type = sock.type + if type & SOCK_STREAM != SOCK_STREAM: raise NotImplementedError("only stream sockets are supported") if server_side: if server_hostname: diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index d203cdd..c80c782 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -181,6 +181,26 @@ class BasicSocketTests(unittest.TestCase): ctx = ssl.SSLContext(proto) self.assertIs(ctx.protocol, proto) + def test_constructor(self): + s = socket.socket() + sfd = s.fileno() + + # Create secure socket from socket + ss = ssl.SSLSocket(sock=s) + self.assertTrue(ss.fileno() == sfd) + + s = socket.socket() + sfd = s.detach() + + # Create secure socket from scratch + ss = ssl.SSLSocket() + self.assertTrue(ss.fileno() >= 0) + self.assertTrue(ss.fileno() != sfd) + + # Create secure socket from fileno + ss = ssl.SSLSocket(fileno=sfd) + self.assertTrue(ss.fileno() == sfd) + def test_random(self): v = ssl.RAND_status() if support.verbose: