diff -r 53e94a687570 Lib/collections/__init__.py --- a/Lib/collections/__init__.py Tue Jan 27 11:10:18 2015 -0500 +++ b/Lib/collections/__init__.py Tue Jan 27 17:48:49 2015 -0800 @@ -1,5 +1,6 @@ __all__ = ['deque', 'defaultdict', 'namedtuple', 'UserDict', 'UserList', - 'UserString', 'Counter', 'OrderedDict', 'ChainMap'] + 'UserString', 'Counter', 'OrderedDict', 'ChainMap', + 'TransformDict'] # For backwards compatibility, continue to make the collections ABCs # available through the collections module. @@ -913,6 +914,136 @@ def clear(self): 'Clear maps[0], leaving maps[1:] intact.' self.maps[0].clear() + + +######################################################################## +### TransformDict +######################################################################## + +_sentinel = object() + +class TransformDict(MutableMapping): + '''Dictionary that calls a transformation function when looking + up keys, but preserves the original keys. + + >>> d = TransformDict(str.lower) + >>> d['Foo'] = 5 + >>> d['foo'] == d['FOO'] == d['Foo'] == 5 + True + >>> set(d.keys()) + {'Foo'} + ''' + + __slots__ = ('_transform', '_original', '_data') + + def __init__(self, transform, init_dict=None, **kwargs): + '''Create a new TransformDict with the given *transform* function. + *init_dict* and *kwargs* are optional initializers, as in the + dict constructor. + ''' + if not callable(transform): + raise TypeError("expected a callable, got %r" % transform.__class__) + self._transform = transform + # transformed => original + self._original = {} + self._data = {} + if init_dict: + self.update(init_dict) + if kwargs: + self.update(kwargs) + + def getitem(self, key): + 'D.getitem(key) -> (stored key, value)' + transformed = self._transform(key) + original = self._original[transformed] + value = self._data[transformed] + return original, value + + @property + def transform_func(self): + "This TransformDict's transformation function" + return self._transform + + # Minimum set of methods required for MutableMapping + + def __len__(self): + return len(self._data) + + def __iter__(self): + return iter(self._original.values()) + + def __getitem__(self, key): + return self._data[self._transform(key)] + + def __setitem__(self, key, value): + transformed = self._transform(key) + self._data[transformed] = value + self._original.setdefault(transformed, key) + + def __delitem__(self, key): + transformed = self._transform(key) + del self._data[transformed] + del self._original[transformed] + + # Methods overriden to mitigate the performance overhead. + + def clear(self): + 'D.clear() -> None. Remove all items from D.' + self._data.clear() + self._original.clear() + + def __contains__(self, key): + return self._transform(key) in self._data + + def get(self, key, default=None): + 'D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.' + return self._data.get(self._transform(key), default) + + def pop(self, key, default=_sentinel): + '''D.pop(k[,d]) -> v, remove specified key and return the corresponding value. + If key is not found, d is returned if given, otherwise KeyError is raised. + ''' + transformed = self._transform(key) + if default is _sentinel: + del self._original[transformed] + return self._data.pop(transformed) + else: + self._original.pop(transformed, None) + return self._data.pop(transformed, default) + + def popitem(self): + '''D.popitem() -> (k, v), remove and return some (key, value) pair + as a 2-tuple; but raise KeyError if D is empty. + ''' + transformed, value = self._data.popitem() + return self._original.pop(transformed), value + + # Other methods + + def copy(self): + 'D.copy() -> a shallow copy of D' + other = self.__class__(self._transform) + other._original = self._original.copy() + other._data = self._data.copy() + return other + + __copy__ = copy + + def __getstate__(self): + return (self._transform, self._data, self._original) + + def __setstate__(self, state): + self._transform, self._data, self._original = state + + def __repr__(self): + try: + equiv = dict(self) + except TypeError: + # Some keys are unhashable, fall back on .items() + equiv = list(self.items()) + return '%s(%r, %s)' % (self.__class__.__name__, + self._transform, repr(equiv)) + ################################################################################ diff -r 53e94a687570 Lib/http/client.py --- a/Lib/http/client.py Tue Jan 27 11:10:18 2015 -0500 +++ b/Lib/http/client.py Tue Jan 27 17:48:49 2015 -0800 @@ -73,6 +73,8 @@ import os import socket import collections + +from http import protocol from urllib.parse import urlsplit __all__ = ["HTTPResponse", "HTTPConnection", @@ -108,6 +110,34 @@ _MAXHEADERS = 100 +class HTTPRequest: + def __init__(self, path, method='GET', headers=None, body=None): + self.path = path + self.method = method + self.body = body + self.headers = collections.TransformDict( + lambda k: k.lower(), headers or {}) + + @property + def content_length(self): + if isinstance(self.body, str) or isinstance(self.body, bytes): + return len(self.body) + elif isinstance(self.body, list): + size = 0 + for line in self.body: + size += len(line) + return size + elif hasattr(self.body, 'read'): + try: + return os.fstat(self.body.fileno()).st_size + except (AttributeError, OSError): + logger.warn('Unable to stat file') + + raise ValueError( + 'Unable to determine content length for {}'.format( + self.path)) + + class HTTPMessage(email.message.Message): # XXX The only usage of this method is in # http.server.CGIHTTPRequestHandler. Maybe move the code there so @@ -683,17 +713,18 @@ debuglevel = 0 def __init__(self, host, port=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, - source_address=None): + source_address=None, proto=None): self.timeout = timeout self.source_address = source_address self.sock = None - self._buffer = [] self.__response = None self.__state = _CS_IDLE self._method = None self._tunnel_host = None self._tunnel_port = None self._tunnel_headers = {} + self._protocol = proto or protocol.HTTP11() + self._request = None (self.host, self.port) = self._get_hostport(host, port) @@ -843,28 +874,7 @@ raise TypeError("data should be a bytes-like object " "or an iterable, got %r" % type(data)) - def _output(self, s): - """Add a line of output to the current request buffer. - - Assumes that the line does *not* end with \\r\\n. - """ - self._buffer.append(s) - - def _send_output(self, message_body=None): - """Send the currently buffered request and clear the buffer. - - Appends an extra \\r\\n to the buffer. - A message_body may be specified, to be appended to the request. - """ - self._buffer.extend((b"", b"")) - msg = b"\r\n".join(self._buffer) - del self._buffer[:] - - self.send(msg) - if message_body is not None: - self.send(message_body) - - def putrequest(self, method, url, skip_host=0, skip_accept_encoding=0): + def putrequest(self, method, url): """Send a request to the server. `method' specifies an HTTP request method, e.g. 'GET'. @@ -878,7 +888,6 @@ if self.__response and self.__response.isclosed(): self.__response = None - # in certain cases, we cannot issue another request on this connection. # this occurs when: # 1) we are in the process of sending a request. (_CS_REQ_STARTED) @@ -902,90 +911,7 @@ else: raise CannotSendRequest(self.__state) - # Save the method we use, we need it later in the response phase - self._method = method - if not url: - url = '/' - request = '%s %s %s' % (method, url, self._http_vsn_str) - - # Non-ASCII characters should have been eliminated earlier - self._output(request.encode('ascii')) - - if self._http_vsn == 11: - # Issue some standard headers for better HTTP/1.1 compliance - - if not skip_host: - # this header is issued *only* for HTTP/1.1 - # connections. more specifically, this means it is - # only issued when the client uses the new - # HTTPConnection() class. backwards-compat clients - # will be using HTTP/1.0 and those clients may be - # issuing this header themselves. we should NOT issue - # it twice; some web servers (such as Apache) barf - # when they see two Host: headers - - # If we need a non-standard port,include it in the - # header. If the request is going through a proxy, - # but the host of the actual URL, not the host of the - # proxy. - - netloc = '' - if url.startswith('http'): - nil, netloc, nil, nil, nil = urlsplit(url) - - if netloc: - try: - netloc_enc = netloc.encode("ascii") - except UnicodeEncodeError: - netloc_enc = netloc.encode("idna") - self.putheader('Host', netloc_enc) - else: - if self._tunnel_host: - host = self._tunnel_host - port = self._tunnel_port - else: - host = self.host - port = self.port - - try: - host_enc = host.encode("ascii") - except UnicodeEncodeError: - host_enc = host.encode("idna") - - # As per RFC 273, IPv6 address should be wrapped with [] - # when used as Host header - - if host.find(':') >= 0: - host_enc = b'[' + host_enc + b']' - - if port == self.default_port: - self.putheader('Host', host_enc) - else: - host_enc = host_enc.decode("ascii") - self.putheader('Host', "%s:%s" % (host_enc, port)) - - # note: we are assuming that clients will not attempt to set these - # headers since *this* library must deal with the - # consequences. this also means that when the supporting - # libraries are updated to recognize other forms, then this - # code should be changed (removed or updated). - - # we only want a Content-Encoding of "identity" since we don't - # support encodings such as x-gzip or x-deflate. - if not skip_accept_encoding: - self.putheader('Accept-Encoding', 'identity') - - # we can accept "chunked" Transfer-Encodings, but no others - # NOTE: no TE header implies *only* "chunked" - #self.putheader('TE', 'chunked') - - # if TE is supplied in the header, then it must appear in a - # Connection header. - #self.putheader('Connection', 'TE') - - else: - # For HTTP/1.0, the server will assume "not chunked" - pass + self._request = HTTPRequest(url, method=method) def putheader(self, header, *values): """Send a request header line to the server. @@ -995,17 +921,7 @@ if self.__state != _CS_REQ_STARTED: raise CannotSendHeader() - if hasattr(header, 'encode'): - header = header.encode('ascii') - values = list(values) - for i, one_value in enumerate(values): - if hasattr(one_value, 'encode'): - values[i] = one_value.encode('latin-1') - elif isinstance(one_value, int): - values[i] = str(one_value).encode('ascii') - value = b'\r\n\t'.join(values) - header = header + b': ' + value - self._output(header) + self._request.headers[header] = list(values) def endheaders(self, message_body=None): """Indicate that the last header line has been sent to the server. @@ -1020,48 +936,44 @@ self.__state = _CS_REQ_SENT else: raise CannotSendHeader() - self._send_output(message_body) + + self._request.body = message_body + + for chunk in self._protocol.serialize( + self._get_host(), self._request): + self.send(chunk) + + def _get_host(self): + netloc = '' + if self._request.path.startswith('http'): + nil, netloc, nil, nil, nil = urlsplit(url) + + if netloc: + host = netloc + else: + if self._tunnel_host: + host = self._tunnel_host + port = self._tunnel_port + else: + host = self.host + port = self.port + + # As per RFC 273, IPv6 address should be wrapped with [] + # when used as Host header + + if host.find(':') >= 0: + host = '[' + host + ']' + + if port != self.default_port: + host = '{}:{}'.format(host, str(port)) + + return host def request(self, method, url, body=None, headers={}): """Send a complete request to the server.""" - self._send_request(method, url, body, headers) - - def _set_content_length(self, body): - # Set the content-length based on the body. - thelen = None - try: - thelen = str(len(body)) - except TypeError as te: - # If this is a file-like object, try to - # fstat its file descriptor - try: - thelen = str(os.fstat(body.fileno()).st_size) - except (AttributeError, OSError): - # Don't send a length if this failed - if self.debuglevel > 0: print("Cannot stat!!") - - if thelen is not None: - self.putheader('Content-Length', thelen) - - def _send_request(self, method, url, body, headers): - # Honor explicitly requested Host: and Accept-Encoding: headers. - header_names = dict.fromkeys([k.lower() for k in headers]) - skips = {} - if 'host' in header_names: - skips['skip_host'] = 1 - if 'accept-encoding' in header_names: - skips['skip_accept_encoding'] = 1 - - self.putrequest(method, url, **skips) - - if body is not None and ('content-length' not in header_names): - self._set_content_length(body) - for hdr, value in headers.items(): - self.putheader(hdr, value) - if isinstance(body, str): - # RFC 2616 Section 3.7.1 says that text default has a - # default charset of iso-8859-1. - body = body.encode('iso-8859-1') + self.putrequest(method, url) + for key, vals in headers.items(): + self.putheader(key, vals) self.endheaders(body) def getresponse(self): diff -r 53e94a687570 Lib/test/test_httplib.py --- a/Lib/test/test_httplib.py Tue Jan 27 11:10:18 2015 -0500 +++ b/Lib/test/test_httplib.py Tue Jan 27 17:48:49 2015 -0800 @@ -8,6 +8,7 @@ import unittest TestCase = unittest.TestCase +from datetime import datetime from test import support here = os.path.dirname(__file__) @@ -111,65 +112,47 @@ # Some headers are added automatically, but should not be added by # .request() if they are explicitly set. - class HeaderCountingBuffer(list): - def __init__(self): - self.count = {} - def append(self, item): - kv = item.split(b':') - if len(kv) > 1: - # item is a 'Key: Value' header string - lcKey = kv[0].decode('ascii').lower() - self.count.setdefault(lcKey, 0) - self.count[lcKey] += 1 - list.append(self, item) - for explicit_header in True, False: - for header in 'Content-length', 'Host', 'Accept-encoding': + for header in 'Content-Length', 'Host', 'Accept-Encoding': conn = client.HTTPConnection('example.com') conn.sock = FakeSocket('blahblahblah') - conn._buffer = HeaderCountingBuffer() body = 'spamspamspam' headers = {} if explicit_header: headers[header] = str(len(body)) conn.request('POST', '/', body, headers) - self.assertEqual(conn._buffer.count[header.lower()], 1) + + self.assertIn(header.encode('ascii'), conn.sock.data) def test_content_length_0(self): - class ContentLengthChecker(list): - def __init__(self): - list.__init__(self) - self.content_length = None - def append(self, item): - kv = item.split(b':', 1) - if len(kv) > 1 and kv[0].lower() == b'content-length': - self.content_length = kv[1].strip() - list.append(self, item) - # POST with empty body conn = client.HTTPConnection('example.com') conn.sock = FakeSocket(None) - conn._buffer = ContentLengthChecker() conn.request('POST', '/', '') - self.assertEqual(conn._buffer.content_length, b'0', - 'Header Content-Length not set') + self.assertIn(b'Content-Length: 0', conn.sock.data) # PUT request with empty body conn = client.HTTPConnection('example.com') conn.sock = FakeSocket(None) - conn._buffer = ContentLengthChecker() conn.request('PUT', '/', '') - self.assertEqual(conn._buffer.content_length, b'0', - 'Header Content-Length not set') + self.assertIn(b'Content-Length: 0', conn.sock.data) def test_putheader(self): conn = client.HTTPConnection('example.com') conn.sock = FakeSocket(None) conn.putrequest('GET','/') - conn.putheader('Content-length', 42) - self.assertIn(b'Content-length: 42', conn._buffer) + conn.putheader('Content-Length', 42) + conn.endheaders() + self.assertIn(b'Content-Length: 42', conn.sock.data) + + def test_invalid_header_value(self): + conn = client.HTTPConnection('example.com') + conn.sock = FakeSocket(None) + conn.putrequest('GET', '/') + conn.putheader('Content-Length', 42, datetime.now()) + self.assertRaises(ValueError, conn.endheaders) def test_ipv6host_header(self): # Default host header on IPv6 transaction should wrapped by [] if