Index: Lib/http/client.py =================================================================== --- Lib/http/client.py (révision 85322) +++ Lib/http/client.py (copie de travail) @@ -1047,13 +1047,29 @@ default_port = HTTPS_PORT + # XXX Should key_file and cert_file be deprecated in favour of context? + def __init__(self, host, port=None, key_file=None, cert_file=None, strict=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, - source_address=None): + source_address=None, *, context=None, check_hostname=None): super(HTTPSConnection, self).__init__(host, port, strict, timeout, source_address) self.key_file = key_file self.cert_file = cert_file + if context is None: + # Some reasonable defaults + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.options |= ssl.OP_NO_SSLv2 + will_verify = context.verify_mode != ssl.CERT_NONE + if check_hostname is None: + check_hostname = will_verify + elif check_hostname and not will_verify: + raise ValueError("check_hostname needs a SSL context with " + "either CERT_OPTIONAL or CERT_REQUIRED") + if key_file or cert_file: + context.load_cert_chain(certfile, keyfile) + self._context = context + self._check_hostname = check_hostname def connect(self): "Connect to a host on a given (SSL) port." @@ -1065,7 +1081,9 @@ self.sock = sock self._tunnel() - self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file) + self.sock = self._context.wrap_socket(sock) + if self._check_hostname: + ssl.match_hostname(self.sock.getpeercert(), self.host) __all__.append("HTTPSConnection") Index: Lib/test/test_httplib.py =================================================================== --- Lib/test/test_httplib.py (révision 85322) +++ Lib/test/test_httplib.py (copie de travail) @@ -1,6 +1,7 @@ import errno from http import client import io +import os import array import socket @@ -370,15 +371,60 @@ self.assertEqual(httpConn.sock.gettimeout(), 30) httpConn.close() -class HTTPSTimeoutTest(TestCase): -# XXX Here should be tests for HTTPS, there isn't any right now! +class HTTPSTest(TestCase): + def setUp(self): + if not hasattr(client, 'HTTPSConnection'): + self.skipTest('ssl support required') + def test_attributes(self): - # simple test to check it's storing it - if hasattr(client, 'HTTPSConnection'): - h = client.HTTPSConnection(HOST, TimeoutTest.PORT, timeout=30) - self.assertEqual(h.timeout, 30) + # simple test to check it's storing the timeout + h = client.HTTPSConnection(HOST, TimeoutTest.PORT, timeout=30) + self.assertEqual(h.timeout, 30) + def _check_svn_python_org(self, resp): + # Just a simple check that everything went fine + server_string = resp.getheader('server') + self.assertIn('Apache', server_string) + + def test_networked(self): + # Default settings: no cert verification is done + support.requires('network') + with support.transient_internet('svn.python.org'): + h = client.HTTPSConnection('svn.python.org', 443) + h.request('GET', '/') + resp = h.getresponse() + self._check_svn_python_org(resp) + + def test_networked_goodcert(self): + # We feed a CA cert that validates the server's cert + import ssl + support.requires('network') + with support.transient_internet('svn.python.org'): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + here = os.path.dirname(__file__) + context.load_verify_locations( + os.path.join(here, 'https_svn_python_org_root.pem')) + h = client.HTTPSConnection('svn.python.org', 443, context=context) + h.request('GET', '/') + resp = h.getresponse() + self._check_svn_python_org(resp) + + def test_networked_badcert(self): + # We feed a "CA" cert that is unrelated to the server's cert + import ssl + support.requires('network') + with support.transient_internet('svn.python.org'): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + here = os.path.dirname(__file__) + context.load_verify_locations(os.path.join(here, 'keycert.pem')) + h = client.HTTPSConnection('svn.python.org', 443, context=context) + with self.assertRaises(ssl.SSLError): + h.request('GET', '/') + + class RequestBodyTest(TestCase): """Test cases where a request includes a message body.""" @@ -488,7 +534,7 @@ def test_main(verbose=None): support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest, - HTTPSTimeoutTest, RequestBodyTest, SourceAddressTest, + HTTPSTest, RequestBodyTest, SourceAddressTest, HTTPResponseTest) if __name__ == '__main__':