diff --git a/Lib/ssl.py b/Lib/ssl.py --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -861,6 +861,15 @@ class SSLSocket(socket): return None return self._sslobj.tls_unique_cb() + def version(self): + """ + Return a string identifying the protocol version used by the + current SSL channel, or None if there is no established channel. + """ + if self._sslobj is None: + return None + return self._sslobj.version() + def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -1942,7 +1942,8 @@ else: 'compression': s.compression(), 'cipher': s.cipher(), 'peercert': s.getpeercert(), - 'client_npn_protocol': s.selected_npn_protocol() + 'client_npn_protocol': s.selected_npn_protocol(), + 'version': s.version(), }) s.close() stats['server_npn_protocols'] = server.selected_protocols @@ -1950,6 +1951,13 @@ else: def try_protocol_combo(server_protocol, client_protocol, expect_success, certsreqs=None, server_options=0, client_options=0): + """ + Try to SSL-connect using *client_protocol* to *server_protocol*. + If *expect_success* is true, assert that the connection succeeds, + if it's false, assert that the connection fails. + Also, if *expect_success* is a string, assert that it is the protocol + version actually used by the connection. + """ if certsreqs is None: certsreqs = ssl.CERT_NONE certtype = { @@ -1979,8 +1987,8 @@ else: ctx.load_cert_chain(CERTFILE) ctx.load_verify_locations(CERTFILE) try: - server_params_test(client_context, server_context, - chatty=False, connectionchatty=False) + stats = server_params_test(client_context, server_context, + chatty=False, connectionchatty=False) # Protocol mismatch can result in either an SSLError, or a # "Connection reset by peer" error. except ssl.SSLError: @@ -1995,6 +2003,10 @@ else: "Client protocol %s succeeded with server protocol %s!" % (ssl.get_protocol_name(client_protocol), ssl.get_protocol_name(server_protocol))) + elif (expect_success is not True + and expect_success != stats['version']): + raise AssertionError("version mismatch: expected %r, got %r" + % (expect_success, stats['version'])) class ThreadedTests(unittest.TestCase): @@ -2225,17 +2237,17 @@ else: sys.stdout.write( " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" % str(x)) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3') try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) # Server with specific SSL options try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, @@ -2252,9 +2264,9 @@ else: """Connecting to an SSLv3 server with various client options""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3') + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, @@ -2262,7 +2274,7 @@ else: try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) if no_sslv2_implies_sslv3_hello(): # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, True, + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, 'SSLv3', client_options=ssl.OP_NO_SSLv2) @skip_if_broken_ubuntu_ssl @@ -2270,9 +2282,9 @@ else: """Connecting to a TLSv1 server with various client options""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1') + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) @@ -2287,14 +2299,14 @@ else: Testing against older TLS versions.""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1_1) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) @@ -2307,7 +2319,7 @@ else: Testing against older TLS versions.""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True, + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2', server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) if hasattr(ssl, 'PROTOCOL_SSLv2'): @@ -2316,7 +2328,7 @@ else: try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1_2) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) @@ -2697,6 +2709,21 @@ else: s.connect((HOST, server.port)) self.assertIn("no shared cipher", str(server.conn_errors[0])) + def test_version_basic(self): + """ + Basic tests for SSLSocket.version(). + More tests are done in the test_protocol_*() methods. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + with ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_TLSv1, + chatty=False) as server: + with context.wrap_socket(socket.socket()) as s: + self.assertIs(s.version(), None) + s.connect((HOST, server.port)) + self.assertEqual(s.version(), "TLSv1") + self.assertIs(s.version(), None) + @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") def test_default_ecdh_curve(self): # Issue #21015: elliptic curve-based Diffie Hellman key exchange diff --git a/Modules/_ssl.c b/Modules/_ssl.c --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -1402,6 +1402,18 @@ static PyObject *PySSL_cipher (PySSLSock return NULL; } +static PyObject *PySSL_version(PySSLSocket *self) +{ + const char *version; + + if (self->ssl == NULL) + Py_RETURN_NONE; + version = SSL_get_version(self->ssl); + if (!strcmp(version, "unknown")) + Py_RETURN_NONE; + return PyUnicode_FromString(version); +} + #ifdef OPENSSL_NPN_NEGOTIATED static PyObject *PySSL_selected_npn_protocol(PySSLSocket *self) { const unsigned char *out; @@ -1939,6 +1951,7 @@ static PyMethodDef PySSLMethods[] = { {"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS, PySSL_peercert_doc}, {"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS}, + {"version", (PyCFunction)PySSL_version, METH_NOARGS}, #ifdef OPENSSL_NPN_NEGOTIATED {"selected_npn_protocol", (PyCFunction)PySSL_selected_npn_protocol, METH_NOARGS}, #endif