Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(12808)

Side by Side Diff: Lib/unittest/case.py

Issue 15836: unittest assertRaises should verify excClass is actually a BaseException class
Patch Set: Created 4 years, 4 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
View unified diff | Download patch
« no previous file with comments | « Lib/test/test_importlib/builtin/test_loader.py ('k') | Lib/unittest/test/test_case.py » ('j') | no next file with comments »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
OLDNEW
1 """Test case implementation""" 1 """Test case implementation"""
2 2
3 import sys 3 import sys
4 import functools 4 import functools
5 import difflib 5 import difflib
6 import logging 6 import logging
7 import pprint 7 import pprint
8 import re 8 import re
9 import warnings 9 import warnings
10 import collections 10 import collections
(...skipping 101 matching lines...) Expand 10 before | Expand all | Expand 10 after
112 Skip a test unless the condition is true. 112 Skip a test unless the condition is true.
113 """ 113 """
114 if not condition: 114 if not condition:
115 return skip(reason) 115 return skip(reason)
116 return _id 116 return _id
117 117
118 def expectedFailure(test_item): 118 def expectedFailure(test_item):
119 test_item.__unittest_expecting_failure__ = True 119 test_item.__unittest_expecting_failure__ = True
120 return test_item 120 return test_item
121 121
122 def _is_subtype(expected, basetype):
123 if isinstance(expected, tuple):
124 return all(_is_subtype(e, basetype) for e in expected)
125 return isinstance(expected, type) and issubclass(expected, basetype)
122 126
123 class _BaseTestCaseContext: 127 class _BaseTestCaseContext:
124 128
125 def __init__(self, test_case): 129 def __init__(self, test_case):
126 self.test_case = test_case 130 self.test_case = test_case
127 131
128 def _raiseFailure(self, standardMsg): 132 def _raiseFailure(self, standardMsg):
129 msg = self.test_case._formatMessage(self.msg, standardMsg) 133 msg = self.test_case._formatMessage(self.msg, standardMsg)
130 raise self.test_case.failureException(msg) 134 raise self.test_case.failureException(msg)
131 135
132 class _AssertRaisesBaseContext(_BaseTestCaseContext): 136 class _AssertRaisesBaseContext(_BaseTestCaseContext):
133 137
134 def __init__(self, expected, test_case, expected_regex=None): 138 def __init__(self, expected, test_case, expected_regex=None):
135 _BaseTestCaseContext.__init__(self, test_case) 139 _BaseTestCaseContext.__init__(self, test_case)
136 self.expected = expected 140 self.expected = expected
137 self.test_case = test_case 141 self.test_case = test_case
138 if expected_regex is not None: 142 if expected_regex is not None:
139 expected_regex = re.compile(expected_regex) 143 expected_regex = re.compile(expected_regex)
140 self.expected_regex = expected_regex 144 self.expected_regex = expected_regex
141 self.obj_name = None 145 self.obj_name = None
142 self.msg = None 146 self.msg = None
143 147
144 def handle(self, name, args, kwargs): 148 def handle(self, name, args, kwargs):
145 """ 149 """
146 If args is empty, assertRaises/Warns is being used as a 150 If args is empty, assertRaises/Warns is being used as a
147 context manager, so check for a 'msg' kwarg and return self. 151 context manager, so check for a 'msg' kwarg and return self.
148 If args is not empty, call a callable passing positional and keyword 152 If args is not empty, call a callable passing positional and keyword
149 arguments. 153 arguments.
150 """ 154 """
155 if not _is_subtype(self.expected, self._base_type):
Martin Panter 2015/05/19 12:37:58 Why is this check moved from __init__() to handle(
storchaka 2015/05/19 12:47:33 Only because the name parameter is passed to handl
Martin Panter 2015/05/19 13:27:57 Fair enough, that makes sense
156 raise TypeError('%s() arg 1 must be %s' %
157 (name, self._base_type_str))
151 if args and args[0] is None: 158 if args and args[0] is None:
152 warnings.warn("callable is None", 159 warnings.warn("callable is None",
153 DeprecationWarning, 3) 160 DeprecationWarning, 3)
154 args = () 161 args = ()
155 if not args: 162 if not args:
156 self.msg = kwargs.pop('msg', None) 163 self.msg = kwargs.pop('msg', None)
157 if kwargs: 164 if kwargs:
158 warnings.warn('%r is an invalid keyword argument for ' 165 warnings.warn('%r is an invalid keyword argument for '
159 'this function' % next(iter(kwargs)), 166 'this function' % next(iter(kwargs)),
160 DeprecationWarning, 3) 167 DeprecationWarning, 3)
161 return self 168 return self
162 169
163 callable_obj, *args = args 170 callable_obj, *args = args
164 try: 171 try:
165 self.obj_name = callable_obj.__name__ 172 self.obj_name = callable_obj.__name__
166 except AttributeError: 173 except AttributeError:
167 self.obj_name = str(callable_obj) 174 self.obj_name = str(callable_obj)
168 with self: 175 with self:
169 callable_obj(*args, **kwargs) 176 callable_obj(*args, **kwargs)
170 177
171 178
172 class _AssertRaisesContext(_AssertRaisesBaseContext): 179 class _AssertRaisesContext(_AssertRaisesBaseContext):
173 """A context manager used to implement TestCase.assertRaises* methods.""" 180 """A context manager used to implement TestCase.assertRaises* methods."""
181
182 _base_type = BaseException
183 _base_type_str = 'an exception type or tuple of exception types'
174 184
175 def __enter__(self): 185 def __enter__(self):
176 return self 186 return self
177 187
178 def __exit__(self, exc_type, exc_value, tb): 188 def __exit__(self, exc_type, exc_value, tb):
179 if exc_type is None: 189 if exc_type is None:
180 try: 190 try:
181 exc_name = self.expected.__name__ 191 exc_name = self.expected.__name__
182 except AttributeError: 192 except AttributeError:
183 exc_name = str(self.expected) 193 exc_name = str(self.expected)
(...skipping 14 matching lines...) Expand all
198 208
199 expected_regex = self.expected_regex 209 expected_regex = self.expected_regex
200 if not expected_regex.search(str(exc_value)): 210 if not expected_regex.search(str(exc_value)):
201 self._raiseFailure('"{}" does not match "{}"'.format( 211 self._raiseFailure('"{}" does not match "{}"'.format(
202 expected_regex.pattern, str(exc_value))) 212 expected_regex.pattern, str(exc_value)))
203 return True 213 return True
204 214
205 215
206 class _AssertWarnsContext(_AssertRaisesBaseContext): 216 class _AssertWarnsContext(_AssertRaisesBaseContext):
207 """A context manager used to implement TestCase.assertWarns* methods.""" 217 """A context manager used to implement TestCase.assertWarns* methods."""
218
219 _base_type = Warning
220 _base_type_str = 'a warning type or tuple of warning types'
208 221
209 def __enter__(self): 222 def __enter__(self):
210 # The __warningregistry__'s need to be in a pristine state for tests 223 # The __warningregistry__'s need to be in a pristine state for tests
211 # to work properly. 224 # to work properly.
212 for v in sys.modules.values(): 225 for v in sys.modules.values():
213 if getattr(v, '__warningregistry__', None): 226 if getattr(v, '__warningregistry__', None):
214 v.__warningregistry__ = {} 227 v.__warningregistry__ = {}
215 self.warnings_manager = warnings.catch_warnings(record=True) 228 self.warnings_manager = warnings.catch_warnings(record=True)
216 self.warnings = self.warnings_manager.__enter__() 229 self.warnings = self.warnings_manager.__enter__()
217 warnings.simplefilter("always", self.expected) 230 warnings.simplefilter("always", self.expected)
(...skipping 1172 matching lines...) Expand 10 before | Expand all | Expand 10 after
1390 return "{} {}".format(self.test_case.id(), self._subDescription()) 1403 return "{} {}".format(self.test_case.id(), self._subDescription())
1391 1404
1392 def shortDescription(self): 1405 def shortDescription(self):
1393 """Returns a one-line description of the subtest, or None if no 1406 """Returns a one-line description of the subtest, or None if no
1394 description has been provided. 1407 description has been provided.
1395 """ 1408 """
1396 return self.test_case.shortDescription() 1409 return self.test_case.shortDescription()
1397 1410
1398 def __str__(self): 1411 def __str__(self):
1399 return "{} {}".format(self.test_case, self._subDescription()) 1412 return "{} {}".format(self.test_case, self._subDescription())
OLDNEW
« no previous file with comments | « Lib/test/test_importlib/builtin/test_loader.py ('k') | Lib/unittest/test/test_case.py » ('j') | no next file with comments »

RSS Feeds Recent Issues | This issue
This is Rietveld 894c83f36cb7+