diff -r 6b0ca4cb7e4e Lib/test/test_wsgiref.py --- a/Lib/test/test_wsgiref.py Sat Feb 23 19:37:01 2013 +0100 +++ b/Lib/test/test_wsgiref.py Mon Feb 25 13:10:18 2013 -0600 @@ -330,6 +330,7 @@ for alt in hop, hop.title(), hop.upper(), hop.lower(): self.assertFalse(util.is_hop_by_hop(alt)) + class HeaderTests(TestCase): def testMappingInterface(self): @@ -366,7 +367,6 @@ def testRequireList(self): self.assertRaises(TypeError, Headers, "foo") - def testExtras(self): h = Headers([]) self.assertEqual(str(h),'\r\n') @@ -385,6 +385,16 @@ '\r\n' ) + def testDisallowNewlines(self): + h = Headers([]) + self.assertRaises(AssertionError, h.add_header, 'foo', 'bar\rbaz: bat') + self.assertRaises(AssertionError, h.add_header, 'foo', 'bar\nbaz: bat') + self.assertRaises(AssertionError, h.add_header, 'foo:\rbar', 'baz') + self.assertRaises(AssertionError, h.add_header, 'foo:\nbar', 'baz') + self.assertRaises(AssertionError, h.add_header, 'foo', 'bar', baz='bat\rqux: spam') + self.assertRaises(AssertionError, h.add_header, 'foo', 'bar', baz='bat\nqux: spam') + + class ErrorHandler(BaseCGIHandler): """Simple handler subclass for testing BaseHandler""" diff -r 6b0ca4cb7e4e Lib/wsgiref/headers.py --- a/Lib/wsgiref/headers.py Sat Feb 23 19:37:01 2013 +0100 +++ b/Lib/wsgiref/headers.py Mon Feb 25 13:10:18 2013 -0600 @@ -10,6 +10,10 @@ import re tspecials = re.compile(r'[ \(\)<>@,;:\\"/\[\]\?=]') +# Regular expression that matches characters not allowed in headers, per +# the WSGI spec. +bad_header_value_re = re.compile(r'[\000-\037]') + def _formatparam(param, value=None, quote=1): """Convenience function to format and return a key=value pair. @@ -40,10 +44,14 @@ def _convert_string_type(self, value): """Convert/check value type.""" - if type(value) is str: - return value - raise AssertionError("Header names/values must be" - " of type str (got {0})".format(repr(value))) + if type(value) is not str: + raise AssertionError("Header names/values must be" + " of type str (got {0})".format(repr(value))) + bad_match = bad_header_value_re.search(value) + if bad_match: + error_str = "Bad header value: {0!r} (bad char: {1!r})" + raise AssertionError(error_str.format(value, bad_match.group(0))) + return value def __len__(self): """Return the total number of headers, including duplicates.""" @@ -52,8 +60,7 @@ def __setitem__(self, name, val): """Set the value of a header.""" del self[name] - self._headers.append( - (self._convert_string_type(name), self._convert_string_type(val))) + self.add_header(name, val) def __delitem__(self,name): """Delete all occurrences of a header, if present. @@ -148,8 +155,7 @@ and value 'value'.""" result = self.get(name) if result is None: - self._headers.append((self._convert_string_type(name), - self._convert_string_type(value))) + self.add_header(name, value) return value else: return result @@ -181,4 +187,5 @@ else: v = self._convert_string_type(v) parts.append(_formatparam(k.replace('_', '-'), v)) - self._headers.append((self._convert_string_type(_name), "; ".join(parts))) + self._headers.append((self._convert_string_type(_name), + "; ".join(parts)))