diff --git a/Doc/library/ssl.rst b/Doc/library/ssl.rst --- a/Doc/library/ssl.rst +++ b/Doc/library/ssl.rst @@ -470,6 +470,16 @@ Constants .. versionadded:: 3.2 +.. data:: HAS_NPN + + Whether the OpenSSL library has built-in support for *Next Protocol + Negotiation* as described in the `NPN draft specification + `_. When true, + you can use the :meth:`SSLContext.set_npn_protocols` method to advertise + which protocols you want to support. + + .. versionadded:: 3.3 + .. data:: CHANNEL_BINDING_TYPES List of supported TLS channel binding types. Strings in this list @@ -609,6 +619,15 @@ SSL sockets also have the following addi .. versionadded:: 3.3 +.. method:: SSLSocket.selected_npn_protocol() + + Returns the protocol that was selected during the TLS/SSL handshake. If + :meth:`SSLContext.set_npn_protocols` was not called, or if the other party + does not support NPN, or if the handshake has not yet happened, this will + return ``None``. + + .. versionadded:: 3.3 + .. method:: SSLSocket.unwrap() Performs the SSL shutdown handshake, which removes the TLS layer from the @@ -617,7 +636,6 @@ SSL sockets also have the following addi returned socket should always be used for further communication with the other side of the connection, rather than the original socket. - .. attribute:: SSLSocket.context The :class:`SSLContext` object this SSL socket is tied to. If the SSL @@ -715,6 +733,21 @@ to speed up repeated connections from th when connected, the :meth:`SSLSocket.cipher` method of SSL sockets will give the currently selected cipher. +.. method:: SSLContext.set_npn_protocols(protocols) + + Specify which protocols the socket should avertise during the SSL/TLS + handshake. It should be a list of strings, like ``['http/1.1', 'spdy/2']``, + ordered by preference. The selection of a protocol will happen during the + handshake, and will play out according to the `NPN draft specification + `_. After a + successful handshake, the :meth:`SSLSocket.selected_protocol` method will + return the agreed-upon protocol. + + This method will raise :exc:`NotImplementedError` if :data:`HAS_NPN` is + False. + + .. versionadded:: 3.3 + .. method:: SSLContext.load_dh_params(dhfile) Load the key generation parameters for Diffie-Helman (DH) key exchange. diff --git a/Lib/ssl.py b/Lib/ssl.py --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -90,7 +90,7 @@ from _ssl import ( SSL_ERROR_EOF, SSL_ERROR_INVALID_ERROR_CODE, ) -from _ssl import HAS_SNI, HAS_ECDH +from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN from _ssl import (PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1) from _ssl import _OPENSSL_API_VERSION @@ -209,6 +209,17 @@ class SSLContext(_SSLContext): server_hostname=server_hostname, _context=self) + def set_npn_protocols(self, npn_protocols): + protos = bytearray() + for protocol in npn_protocols: + b = bytes(protocol, 'ascii') + if len(b) == 0 or len(b) > 255: + raise SSLError('NPN protocols must be 1 to 255 in length') + protos.append(len(b)) + protos.extend(b) + + self._set_npn_protocols(protos) + class SSLSocket(socket): """This class implements a subtype of socket.socket that wraps @@ -220,7 +231,7 @@ class SSLSocket(socket): ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, - suppress_ragged_eofs=True, ciphers=None, + suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, server_hostname=None, _context=None): @@ -240,6 +251,8 @@ class SSLSocket(socket): self.context.load_verify_locations(ca_certs) if certfile: self.context.load_cert_chain(certfile, keyfile) + if npn_protocols: + self.context.set_npn_protocols(npn_protocols) if ciphers: self.context.set_ciphers(ciphers) self.keyfile = keyfile @@ -340,6 +353,13 @@ class SSLSocket(socket): self._checkClosed() return self._sslobj.peer_certificate(binary_form) + def selected_npn_protocol(self): + self._checkClosed() + if not self._sslobj or not _ssl.HAS_NPN: + return None + else: + return self._sslobj.selected_npn_protocol() + def cipher(self): self._checkClosed() if not self._sslobj: @@ -568,7 +588,8 @@ def wrap_socket(sock, keyfile=None, cert server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True, ciphers=None): + suppress_ragged_eofs=True, + ciphers=None): return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile, server_side=server_side, cert_reqs=cert_reqs, diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -879,6 +879,7 @@ else: try: self.sslconn = self.server.context.wrap_socket( self.sock, server_side=True) + self.server.selected_protocols.append(self.sslconn.selected_npn_protocol()) except ssl.SSLError as e: # XXX Various errors can have happened here, for example # a mismatching protocol version, an invalid certificate, @@ -901,6 +902,8 @@ else: cipher = self.sslconn.cipher() if support.verbose and self.server.chatty: sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") + sys.stdout.write(" server: selected protocol is now " + + str(self.sslconn.selected_npn_protocol()) + "\n") return True def read(self): @@ -979,7 +982,7 @@ else: def __init__(self, certificate=None, ssl_version=None, certreqs=None, cacerts=None, chatty=True, connectionchatty=False, starttls_server=False, - ciphers=None, context=None): + npn_protocols=None, ciphers=None, context=None): if context: self.context = context else: @@ -992,6 +995,8 @@ else: self.context.load_verify_locations(cacerts) if certificate: self.context.load_cert_chain(certificate) + if npn_protocols: + self.context.set_npn_protocols(npn_protocols) if ciphers: self.context.set_ciphers(ciphers) self.chatty = chatty @@ -1001,6 +1006,7 @@ else: self.port = support.bind_port(self.sock) self.flag = None self.active = False + self.selected_protocols = [] self.conn_errors = [] threading.Thread.__init__(self) self.daemon = True @@ -1195,6 +1201,7 @@ else: Launch a server, connect a client to it and try various reads and writes. """ + stats = {} server = ThreadedEchoServer(context=server_context, chatty=chatty, connectionchatty=False) @@ -1220,12 +1227,14 @@ else: if connectionchatty: if support.verbose: sys.stdout.write(" client: closing connection.\n") - stats = { + stats.update({ 'compression': s.compression(), 'cipher': s.cipher(), - } + 'client_npn_protocol': s.selected_npn_protocol() + }) s.close() - return stats + stats['server_npn_protocols'] = server.selected_protocols + return stats def try_protocol_combo(server_protocol, client_protocol, expect_success, certsreqs=None, server_options=0, client_options=0): @@ -1853,6 +1862,43 @@ else: if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: self.fail("Non-DH cipher: " + cipher[0]) + def test_selected_npn_protocol(self): + # selected_npn_protocol() is None unless NPN is used + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_npn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test") + def test_npn_protocols(self): + server_protocols = ['http/1.1', 'spdy/2'] + protocol_tests = [ + (['http/1.1', 'spdy/2'], 'http/1.1'), + (['spdy/2', 'http/1.1'], 'http/1.1'), + (['spdy/2', 'test'], 'spdy/2'), + (['abc', 'def'], 'abc') + ] + for client_protocols, expected in protocol_tests: + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(CERTFILE) + server_context.set_npn_protocols(server_protocols) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.load_cert_chain(CERTFILE) + client_context.set_npn_protocols(client_protocols) + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True) + + msg = "failed trying %s (s) and %s (c).\n" \ + "was expecting %s, but got %%s from the %%s" \ + % (str(server_protocols), str(client_protocols), + str(expected)) + client_result = stats['client_npn_protocol'] + self.assertEqual(client_result, expected, msg % (client_result, "client")) + server_result = stats['server_npn_protocols'][-1] \ + if len(stats['server_npn_protocols']) else 'nothing' + self.assertEqual(server_result, expected, msg % (server_result, "server")) + def test_main(verbose=False): if support.verbose: diff --git a/Modules/_ssl.c b/Modules/_ssl.c --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -159,6 +159,10 @@ static unsigned int _ssl_locks_count = 0 typedef struct { PyObject_HEAD SSL_CTX *ctx; +#ifdef OPENSSL_NPN_NEGOTIATED + char *npn_protocols; + int npn_protocols_len; +#endif } PySSLContext; typedef struct { @@ -1015,6 +1019,20 @@ static PyObject *PySSL_cipher (PySSLSock return NULL; } +#ifdef OPENSSL_NPN_NEGOTIATED +static PyObject *PySSL_selected_npn_protocol(PySSLSocket *self) { + const unsigned char *out; + unsigned int outlen; + + SSL_get0_next_proto_negotiated(self->ssl, + &out, &outlen); + + if (out == NULL) + Py_RETURN_NONE; + return PyUnicode_FromStringAndSize((char *) out, outlen); +} +#endif + static PyObject *PySSL_compression(PySSLSocket *self) { #ifdef OPENSSL_NO_COMP Py_RETURN_NONE; @@ -1487,6 +1505,9 @@ static PyMethodDef PySSLMethods[] = { {"peer_certificate", (PyCFunction)PySSL_peercert, METH_VARARGS, PySSL_peercert_doc}, {"cipher", (PyCFunction)PySSL_cipher, METH_NOARGS}, +#ifdef OPENSSL_NPN_NEGOTIATED + {"selected_npn_protocol", (PyCFunction)PySSL_selected_npn_protocol, METH_NOARGS}, +#endif {"compression", (PyCFunction)PySSL_compression, METH_NOARGS}, {"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS, PySSL_SSLshutdown_doc}, @@ -1597,6 +1618,9 @@ static void context_dealloc(PySSLContext *self) { SSL_CTX_free(self->ctx); +#ifdef OPENSSL_NPN_NEGOTIATED + PyMem_Free(self->npn_protocols); +#endif Py_TYPE(self)->tp_free(self); } @@ -1621,6 +1645,87 @@ set_ciphers(PySSLContext *self, PyObject Py_RETURN_NONE; } +#ifdef OPENSSL_NPN_NEGOTIATED +/* this callback gets passed to SSL_CTX_set_next_protos_advertise_cb */ +static int +_advertiseNPN_cb(SSL *s, + const unsigned char **data, unsigned int *len, + void *args) +{ + PySSLContext *ssl_ctx = (PySSLContext *) args; + + if (ssl_ctx->npn_protocols == NULL) { + *data = (unsigned char *) ""; + *len = 0; + } else { + *data = (unsigned char *) ssl_ctx->npn_protocols; + *len = ssl_ctx->npn_protocols_len; + } + + return SSL_TLSEXT_ERR_OK; +} +/* this callback gets passed to SSL_CTX_set_next_proto_select_cb */ +static int +_selectNPN_cb(SSL *s, + unsigned char **out, unsigned char *outlen, + const unsigned char *server, unsigned int server_len, + void *args) +{ + PySSLContext *ssl_ctx = (PySSLContext *) args; + + unsigned char *client = (unsigned char *) ssl_ctx->npn_protocols; + int client_len; + + if (client == NULL) { + client = (unsigned char *) ""; + client_len = 0; + } else { + client_len = ssl_ctx->npn_protocols_len; + } + + SSL_select_next_proto(out, outlen, + server, server_len, + client, client_len); + + return SSL_TLSEXT_ERR_OK; +} +#endif + +static PyObject * +_set_npn_protocols(PySSLContext *self, PyObject *args) +{ +#ifdef OPENSSL_NPN_NEGOTIATED + Py_buffer protos; + + if (!PyArg_ParseTuple(args, "y*:set_npn_protocols", &protos)) + return NULL; + + self->npn_protocols = PyMem_Malloc(protos.len); + if (self->npn_protocols == NULL) { + PyBuffer_Release(&protos); + return PyErr_NoMemory(); + } + memcpy(self->npn_protocols, protos.buf, protos.len); + self->npn_protocols_len = (int) protos.len; + + /* set both server and client callbacks, because the context can + * be used to create both types of sockets */ + SSL_CTX_set_next_protos_advertised_cb(self->ctx, + _advertiseNPN_cb, + self); + SSL_CTX_set_next_proto_select_cb(self->ctx, + _selectNPN_cb, + self); + + PyBuffer_Release(&protos); + Py_RETURN_NONE; +#else + PyErr_SetString(PyExc_NotImplementedError, + "The NPN extension requires OpenSSL 1.0.1 or later."); + return NULL; +#endif +} + static PyObject * get_verify_mode(PySSLContext *self, void *c) { @@ -2097,6 +2202,8 @@ static struct PyMethodDef context_method METH_VARARGS | METH_KEYWORDS, NULL}, {"set_ciphers", (PyCFunction) set_ciphers, METH_VARARGS, NULL}, + {"_set_npn_protocols", (PyCFunction) _set_npn_protocols, + METH_VARARGS, NULL}, {"load_cert_chain", (PyCFunction) load_cert_chain, METH_VARARGS | METH_KEYWORDS, NULL}, {"load_dh_params", (PyCFunction) load_dh_params, @@ -2590,6 +2697,14 @@ PyInit__ssl(void) Py_INCREF(r); PyModule_AddObject(m, "HAS_ECDH", r); +#ifdef OPENSSL_NPN_NEGOTIATED + r = Py_True; +#else + r = Py_False; +#endif + Py_INCREF(r); + PyModule_AddObject(m, "HAS_NPN", r); + /* OpenSSL version */ /* SSLeay() gives us the version of the library linked against, which could be different from the headers version.