diff --git a/Lib/httplib.py b/Lib/httplib.py --- a/Lib/httplib.py +++ b/Lib/httplib.py @@ -700,17 +700,33 @@ self._tunnel_host = None self._tunnel_port = None self._tunnel_headers = {} - - self._set_hostport(host, port) if strict is not None: self.strict = strict + (self.host, self.port) = self._get_hostport(host, port) + + # This is stored as an instance variable to allow unittests + # to replace with a suitable mock + self._create_connection = socket.create_connection + def set_tunnel(self, host, port=None, headers=None): - """ Sets up the host and the port for the HTTP CONNECT Tunnelling. + """ Set up host and port for HTTP CONNECT tunnelling. + + In a connection that uses HTTP Connect tunneling, the host passed to the + constructor is used as proxy server that relays all communication to the + endpoint passed to set_tunnel. This is done by sending a HTTP CONNECT + request to the proxy server when the connection is established. + + This method must be called before the HTML connection has been + established. The headers argument should be a mapping of extra HTTP headers to send with the CONNECT request. """ + # Verify if this is required. + if self.sock: + raise RuntimeError("Can't setup tunnel for established connection.") + self._tunnel_host = host self._tunnel_port = port if headers: @@ -718,7 +734,7 @@ else: self._tunnel_headers.clear() - def _set_hostport(self, host, port): + def _get_hostport(self, host, port): if port is None: i = host.rfind(':') j = host.rfind(']') # ipv6 addresses have [...] @@ -735,15 +751,14 @@ port = self.default_port if host and host[0] == '[' and host[-1] == ']': host = host[1:-1] - self.host = host - self.port = port + return (host, port) def set_debuglevel(self, level): self.debuglevel = level def _tunnel(self): - self._set_hostport(self._tunnel_host, self._tunnel_port) - self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port)) + (host, port) = self._get_hostport(self._tunnel_host, self._tunnel_port) + self.send("CONNECT %s:%d HTTP/1.0\r\n" % (host, port)) for header, value in self._tunnel_headers.iteritems(): self.send("%s: %s\r\n" % (header, value)) self.send("\r\n") @@ -768,8 +783,8 @@ def connect(self): """Connect to the host and port specified in __init__.""" - self.sock = socket.create_connection((self.host,self.port), - self.timeout, self.source_address) + self.sock = self._create_connection((self.host,self.port), + self.timeout, self.source_address) if self._tunnel_host: self._tunnel() @@ -907,17 +922,24 @@ netloc_enc = netloc.encode("idna") self.putheader('Host', netloc_enc) else: + if self._tunnel_host: + host = self._tunnel_host + port = self._tunnel_port + else: + host = self.host + port = self.port + try: - host_enc = self.host.encode("ascii") + host_enc = host.encode("ascii") except UnicodeEncodeError: - host_enc = self.host.encode("idna") + host_enc = host.encode("idna") # Wrap the IPv6 Host Header with [] (RFC 2732) if host_enc.find(':') >= 0: host_enc = "[" + host_enc + "]" - if self.port == self.default_port: + if port == self.default_port: self.putheader('Host', host_enc) else: - self.putheader('Host', "%s:%s" % (host_enc, self.port)) + self.putheader('Host', "%s:%s" % (host_enc, port)) # note: we are assuming that clients will not attempt to set these # headers since *this* library must deal with the @@ -1168,8 +1190,8 @@ def connect(self): "Connect to a host on a given (SSL) port." - sock = socket.create_connection((self.host, self.port), - self.timeout, self.source_address) + sock = self._create_connection((self.host, self.port), + self.timeout, self.source_address) if self._tunnel_host: self.sock = sock self._tunnel() diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -13,10 +13,12 @@ HOST = test_support.HOST class FakeSocket: - def __init__(self, text, fileclass=StringIO.StringIO): + def __init__(self, text, fileclass=StringIO.StringIO, host=None, port=None): self.text = text self.fileclass = fileclass self.data = '' + self.host = host + self.port = port def sendall(self, data): self.data += ''.join(data) @@ -26,6 +28,9 @@ raise httplib.UnimplementedFileMode() return self.fileclass(self.text) + def close(self): + pass + class EPipeSocket(FakeSocket): def __init__(self, text, pipe_trigger): @@ -526,9 +531,48 @@ self.fail("Port incorrectly parsed: %s != %s" % (p, c.host)) +class TunnelTests(TestCase): + def test_connect(self): + response_text = ( + 'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT + 'HTTP/1.1 200 OK\r\n' # Reply to HEAD + 'Content-Length: 42\r\n\r\n' + ) + + def create_connection(address, timeout=None, source_address=None): + return FakeSocket(response_text, host=address[0], port=address[1]) + + conn = httplib.HTTPConnection('proxy.com') + conn._create_connection = create_connection + + # Once connected, we should not be able to tunnel anymore + conn.connect() + self.assertRaises(RuntimeError, conn.set_tunnel, 'destination.com') + + # But if close the connection, we are good. + conn.close() + conn.set_tunnel('destination.com') + conn.request('HEAD', '/', '') + + self.assertEqual(conn.sock.host, 'proxy.com') + self.assertEqual(conn.sock.port, 80) + self.assertTrue('CONNECT destination.com' in conn.sock.data) + self.assertTrue('Host: destination.com' in conn.sock.data) + + self.assertTrue('Host: proxy.com' not in conn.sock.data) + + conn.close() + + conn.request('PUT', '/', '') + self.assertEqual(conn.sock.host, 'proxy.com') + self.assertEqual(conn.sock.port, 80) + self.assertTrue('CONNECT destination.com' in conn.sock.data) + self.assertTrue('Host: destination.com' in conn.sock.data) + + def test_main(verbose=None): test_support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest, - HTTPSTimeoutTest, SourceAddressTest) + HTTPSTimeoutTest, SourceAddressTest, TunnelTests) if __name__ == '__main__': test_main() diff --git a/Misc/NEWS b/Misc/NEWS --- a/Misc/NEWS +++ b/Misc/NEWS @@ -49,6 +49,10 @@ Library ------- +- Issue #7776: Backport Fix ``Host:'' header and reconnection when using + http.client.HTTPConnection.set_tunnel() from Python 3. + Patch by Nikolaus Rath. + - Issue #21306: Backport hmac.compare_digest from Python 3. This is part of PEP 466.