Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(89468)

Delta Between Two Patch Sets: Lib/test/test_functools.py

Issue 16510: Using appropriate checks in tests
Left Patch Set: Created 5 years, 12 months ago
Right Patch Set: Created 5 years, 6 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « Lib/test/test_funcattrs.py ('k') | Lib/test/test_gc.py » ('j') | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFTRIGHT
1 import abc
1 import collections 2 import collections
2 import sys 3 from itertools import permutations
3 import unittest
4 from test import support
5 from weakref import proxy
6 import pickle 4 import pickle
7 from random import choice 5 from random import choice
6 import sys
7 from test import support
8 import unittest
9 from weakref import proxy
8 10
9 import functools 11 import functools
10 12
11 original_functools = functools
12 py_functools = support.import_fresh_module('functools', blocked=['_functools']) 13 py_functools = support.import_fresh_module('functools', blocked=['_functools'])
13 c_functools = support.import_fresh_module('functools', fresh=['_functools']) 14 c_functools = support.import_fresh_module('functools', fresh=['_functools'])
14 15
15 class BaseTest(unittest.TestCase): 16 decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
16 17
17 """Base class required for testing C and Py implementations."""
18
19 def setUp(self):
20
21 # The module must be explicitly set so that the proper
22 # interaction between the c module and the python module
23 # can be controlled.
24 self.partial = self.module.partial
25 super(BaseTest, self).setUp()
26
27 class BaseTestC(BaseTest):
28 module = c_functools
29
30 class BaseTestPy(BaseTest):
31 module = py_functools
32
33 PythonPartial = py_functools.partial
34 18
35 def capture(*args, **kw): 19 def capture(*args, **kw):
36 """capture all positional and keyword arguments""" 20 """capture all positional and keyword arguments"""
37 return args, kw 21 return args, kw
38 22
23
39 def signature(part): 24 def signature(part):
40 """ return the signature of a partial object """ 25 """ return the signature of a partial object """
41 return (part.func, part.args, part.keywords, part.__dict__) 26 return (part.func, part.args, part.keywords, part.__dict__)
42 27
43 class TestPartial(object): 28
44 29 class TestPartial:
45 partial = functools.partial
46 30
47 def test_basic_examples(self): 31 def test_basic_examples(self):
48 p = self.partial(capture, 1, 2, a=10, b=20) 32 p = self.partial(capture, 1, 2, a=10, b=20)
49 self.assertTrue(callable(p)) 33 self.assertTrue(callable(p))
50 self.assertEqual(p(3, 4, b=30, c=40), 34 self.assertEqual(p(3, 4, b=30, c=40),
51 ((1, 2, 3, 4), dict(a=10, b=30, c=40))) 35 ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
52 p = self.partial(map, lambda x: x*10) 36 p = self.partial(map, lambda x: x*10)
53 self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40]) 37 self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])
54 38
55 def test_attributes(self): 39 def test_attributes(self):
56 p = self.partial(capture, 1, 2, a=10, b=20) 40 p = self.partial(capture, 1, 2, a=10, b=20)
57 # attributes should be readable 41 # attributes should be readable
58 self.assertEqual(p.func, capture) 42 self.assertEqual(p.func, capture)
59 self.assertEqual(p.args, (1, 2)) 43 self.assertEqual(p.args, (1, 2))
60 self.assertEqual(p.keywords, dict(a=10, b=20)) 44 self.assertEqual(p.keywords, dict(a=10, b=20))
61 # attributes should not be writable
62 if not isinstance(self.partial, type):
63 return
64 self.assertRaises(AttributeError, setattr, p, 'func', map)
65 self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
66 self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2) )
67
68 p = self.partial(hex)
69 try:
70 del p.__dict__
71 except TypeError:
72 pass
73 else:
74 self.fail('partial object allowed __dict__ to be deleted')
75 45
76 def test_argument_checking(self): 46 def test_argument_checking(self):
77 self.assertRaises(TypeError, self.partial) # need at least a func ar g 47 self.assertRaises(TypeError, self.partial) # need at least a func ar g
78 try: 48 try:
79 self.partial(2)() 49 self.partial(2)()
80 except TypeError: 50 except TypeError:
81 pass 51 pass
82 else: 52 else:
83 self.fail('First arg not checked for callability') 53 self.fail('First arg not checked for callability')
84 54
(...skipping 29 matching lines...) Expand all
114 self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) 84 self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
115 # keyword args in the call override those in the partial object 85 # keyword args in the call override those in the partial object
116 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) 86 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))
117 87
118 def test_positional(self): 88 def test_positional(self):
119 # make sure positional arguments are captured correctly 89 # make sure positional arguments are captured correctly
120 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: 90 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
121 p = self.partial(capture, *args) 91 p = self.partial(capture, *args)
122 expected = args + ('x',) 92 expected = args + ('x',)
123 got, empty = p('x') 93 got, empty = p('x')
124 self.assertTrue(expected == got and empty == {}) 94 self.assertEqual(got, expected)
95 self.assertEqual(empty, {})
125 96
126 def test_keyword(self): 97 def test_keyword(self):
127 # make sure keyword arguments are captured correctly 98 # make sure keyword arguments are captured correctly
128 for a in ['a', 0, None, 3.5]: 99 for a in ['a', 0, None, 3.5]:
129 p = self.partial(capture, a=a) 100 p = self.partial(capture, a=a)
130 expected = {'a':a,'x':None} 101 expected = {'a':a,'x':None}
131 empty, got = p(x=None) 102 empty, got = p(x=None)
132 self.assertTrue(expected == got and empty == ()) 103 self.assertEqual(got, expected)
104 self.assertEqual(empty, ())
133 105
134 def test_no_side_effects(self): 106 def test_no_side_effects(self):
135 # make sure there are no side effects that affect subsequent calls 107 # make sure there are no side effects that affect subsequent calls
136 p = self.partial(capture, 0, a=1) 108 p = self.partial(capture, 0, a=1)
137 args1, kw1 = p(1, b=2) 109 args1, kw1 = p(1, b=2)
138 self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) 110 self.assertEqual(args1, (0,1))
111 self.assertEqual(kw1, {'a':1,'b':2})
139 args2, kw2 = p() 112 args2, kw2 = p()
140 self.assertTrue(args2 == (0,) and kw2 == {'a':1}) 113 self.assertEqual(args2, (0,))
114 self.assertEqual(kw2, {'a':1})
141 115
142 def test_error_propagation(self): 116 def test_error_propagation(self):
143 def f(x, y): 117 def f(x, y):
144 x / y 118 x / y
145 self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0)) 119 self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
146 self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0) 120 self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
147 self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0) 121 self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
148 self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1) 122 self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)
149 123
150 def test_weakref(self): 124 def test_weakref(self):
151 f = self.partial(int, base=16) 125 f = self.partial(int, base=16)
152 p = proxy(f) 126 p = proxy(f)
153 self.assertEqual(f.func, p.func) 127 self.assertEqual(f.func, p.func)
154 f = None 128 f = None
155 self.assertRaises(ReferenceError, getattr, p, 'func') 129 self.assertRaises(ReferenceError, getattr, p, 'func')
156 130
157 def test_with_bound_and_unbound_methods(self): 131 def test_with_bound_and_unbound_methods(self):
158 data = list(map(str, range(10))) 132 data = list(map(str, range(10)))
159 join = self.partial(str.join, '') 133 join = self.partial(str.join, '')
160 self.assertEqual(join(data), '0123456789') 134 self.assertEqual(join(data), '0123456789')
161 join = self.partial(''.join) 135 join = self.partial(''.join)
162 self.assertEqual(join(data), '0123456789') 136 self.assertEqual(join(data), '0123456789')
163 137
138
139 @unittest.skipUnless(c_functools, 'requires the C _functools module')
140 class TestPartialC(TestPartial, unittest.TestCase):
141 if c_functools:
142 partial = c_functools.partial
143
144 def test_attributes_unwritable(self):
145 # attributes should not be writable
146 p = self.partial(capture, 1, 2, a=10, b=20)
147 self.assertRaises(AttributeError, setattr, p, 'func', map)
148 self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
149 self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2) )
150
151 p = self.partial(hex)
152 try:
153 del p.__dict__
154 except TypeError:
155 pass
156 else:
157 self.fail('partial object allowed __dict__ to be deleted')
158
164 def test_repr(self): 159 def test_repr(self):
165 args = (object(), object()) 160 args = (object(), object())
166 args_repr = ', '.join(repr(a) for a in args) 161 args_repr = ', '.join(repr(a) for a in args)
167 kwargs = {'a': object(), 'b': object()} 162 #kwargs = {'a': object(), 'b': object()}
163 kwargs = {'a': object()}
168 kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items()) 164 kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
169 if self.partial is functools.partial: 165 if self.partial is c_functools.partial:
170 name = 'functools.partial' 166 name = 'functools.partial'
171 else: 167 else:
172 name = self.partial.__name__ 168 name = self.partial.__name__
173 169
174 f = self.partial(capture) 170 f = self.partial(capture)
175 self.assertEqual('{}({!r})'.format(name, capture), 171 self.assertEqual('{}({!r})'.format(name, capture),
176 repr(f)) 172 repr(f))
177 173
178 f = self.partial(capture, *args) 174 f = self.partial(capture, *args)
179 self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr), 175 self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
180 repr(f)) 176 repr(f))
181 177
182 f = self.partial(capture, **kwargs) 178 f = self.partial(capture, **kwargs)
183 self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr), 179 self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
184 repr(f)) 180 repr(f))
185 181
186 f = self.partial(capture, *args, **kwargs) 182 f = self.partial(capture, *args, **kwargs)
187 self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwa rgs_repr), 183 self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwa rgs_repr),
188 repr(f)) 184 repr(f))
189 185
190 def test_pickle(self): 186 def test_pickle(self):
191 f = self.partial(signature, 'asdf', bar=True) 187 f = self.partial(signature, 'asdf', bar=True)
192 f.add_something_to__dict__ = True 188 f.add_something_to__dict__ = True
193 f_copy = pickle.loads(pickle.dumps(f)) 189 f_copy = pickle.loads(pickle.dumps(f))
194 self.assertEqual(signature(f), signature(f_copy)) 190 self.assertEqual(signature(f), signature(f_copy))
195 191
196 class TestPartialC(BaseTestC, TestPartial): 192 # Issue 6083: Reference counting bug
197 pass 193 def test_setstate_refcount(self):
198 194 class BadSequence:
199 class TestPartialPy(BaseTestPy, TestPartial): 195 def __len__(self):
200 196 return 4
201 def test_pickle(self): 197 def __getitem__(self, key):
202 raise unittest.SkipTest("Python implementation of partial isn't picklabl e") 198 if key == 0:
203 199 return max
204 def test_repr(self): 200 elif key == 1:
205 raise unittest.SkipTest("Python implementation of partial uses own repr" ) 201 return tuple(range(1000000))
206 202 elif key in (2, 3):
207 class TestPartialCSubclass(BaseTestC, TestPartial): 203 return {}
208 204 raise IndexError
205
206 f = self.partial(object)
207 self.assertRaisesRegex(SystemError,
208 "new style getargs format but argument is not a tuple",
209 f.__setstate__, BadSequence())
210
211
212 class TestPartialPy(TestPartial, unittest.TestCase):
213 partial = staticmethod(py_functools.partial)
214
215
216 if c_functools:
209 class PartialSubclass(c_functools.partial): 217 class PartialSubclass(c_functools.partial):
210 pass 218 pass
211 219
212 partial = staticmethod(PartialSubclass) 220
213 221 @unittest.skipUnless(c_functools, 'requires the C _functools module')
214 class TestPartialPySubclass(TestPartialPy): 222 class TestPartialCSubclass(TestPartialC):
215 223 if c_functools:
216 class PartialSubclass(c_functools.partial): 224 partial = PartialSubclass
217 pass 225
218 226
219 partial = staticmethod(PartialSubclass) 227 class TestPartialMethod(unittest.TestCase):
228
229 class A(object):
230 nothing = functools.partialmethod(capture)
231 positional = functools.partialmethod(capture, 1)
232 keywords = functools.partialmethod(capture, a=2)
233 both = functools.partialmethod(capture, 3, b=4)
234
235 nested = functools.partialmethod(positional, 5)
236
237 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
238
239 static = functools.partialmethod(staticmethod(capture), 8)
240 cls = functools.partialmethod(classmethod(capture), d=9)
241
242 a = A()
243
244 def test_arg_combinations(self):
245 self.assertEqual(self.a.nothing(), ((self.a,), {}))
246 self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
247 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
248 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
249
250 self.assertEqual(self.a.positional(), ((self.a, 1), {}))
251 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
252 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
253 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
254
255 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
256 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
257 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
258 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6} ))
259
260 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
261 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
262 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
263 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}) )
264
265 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
266
267 def test_nested(self):
268 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
269 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
270 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
271 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
272
273 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d' : 7}))
274
275 def test_over_partial(self):
276 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
277 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
278 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8 }))
279 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
280
281 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), { 'c': 6, 'd': 8}))
282
283 def test_bound_method_introspection(self):
284 obj = self.a
285 self.assertIs(obj.both.__self__, obj)
286 self.assertIs(obj.nested.__self__, obj)
287 self.assertIs(obj.over_partial.__self__, obj)
288 self.assertIs(obj.cls.__self__, self.A)
289 self.assertIs(self.A.cls.__self__, self.A)
290
291 def test_unbound_method_retrieval(self):
292 obj = self.A
293 self.assertFalse(hasattr(obj.both, "__self__"))
294 self.assertFalse(hasattr(obj.nested, "__self__"))
295 self.assertFalse(hasattr(obj.over_partial, "__self__"))
296 self.assertFalse(hasattr(obj.static, "__self__"))
297 self.assertFalse(hasattr(self.a.static, "__self__"))
298
299 def test_descriptors(self):
300 for obj in [self.A, self.a]:
301 with self.subTest(obj=obj):
302 self.assertEqual(obj.static(), ((8,), {}))
303 self.assertEqual(obj.static(5), ((8, 5), {}))
304 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
305 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
306
307 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
308 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
309 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
310 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9} ))
311
312 def test_overriding_keywords(self):
313 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
314 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
315
316 def test_invalid_args(self):
317 with self.assertRaises(TypeError):
318 class B(object):
319 method = functools.partialmethod(None, 1)
320
321 def test_repr(self):
322 self.assertEqual(repr(vars(self.A)['both']),
323 'functools.partialmethod({}, 3, b=4)'.format(capture))
324
325 def test_abstract(self):
326 class Abstract(abc.ABCMeta):
327
328 @abc.abstractmethod
329 def add(self, x, y):
330 pass
331
332 add5 = functools.partialmethod(add, 5)
333
334 self.assertTrue(Abstract.add.__isabstractmethod__)
335 self.assertTrue(Abstract.add5.__isabstractmethod__)
336
337 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nest ed, self.A.both]:
338 self.assertFalse(getattr(func, '__isabstractmethod__', False))
339
220 340
221 class TestUpdateWrapper(unittest.TestCase): 341 class TestUpdateWrapper(unittest.TestCase):
222 342
223 def check_wrapper(self, wrapper, wrapped, 343 def check_wrapper(self, wrapper, wrapped,
224 assigned=functools.WRAPPER_ASSIGNMENTS, 344 assigned=functools.WRAPPER_ASSIGNMENTS,
225 updated=functools.WRAPPER_UPDATES): 345 updated=functools.WRAPPER_UPDATES):
226 # Check attributes were assigned 346 # Check attributes were assigned
227 for name in assigned: 347 for name in assigned:
228 self.assertIs(getattr(wrapper, name), getattr(wrapped, name)) 348 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
229 # Check attributes were updated 349 # Check attributes were updated
230 for name in updated: 350 for name in updated:
231 wrapper_attr = getattr(wrapper, name) 351 wrapper_attr = getattr(wrapper, name)
232 wrapped_attr = getattr(wrapped, name) 352 wrapped_attr = getattr(wrapped, name)
233 for key in wrapped_attr: 353 for key in wrapped_attr:
354 if name == "__dict__" and key == "__wrapped__":
355 # __wrapped__ is overwritten by the update code
356 continue
234 self.assertIs(wrapped_attr[key], wrapper_attr[key]) 357 self.assertIs(wrapped_attr[key], wrapper_attr[key])
358 # Check __wrapped__
359 self.assertIs(wrapper.__wrapped__, wrapped)
360
235 361
236 def _default_update(self): 362 def _default_update(self):
237 def f(a:'This is a new annotation'): 363 def f(a:'This is a new annotation'):
238 """This is a test""" 364 """This is a test"""
239 pass 365 pass
240 f.attr = 'This is also a test' 366 f.attr = 'This is also a test'
367 f.__wrapped__ = "This is a bald faced lie"
241 def wrapper(b:'This is the prior annotation'): 368 def wrapper(b:'This is the prior annotation'):
242 pass 369 pass
243 functools.update_wrapper(wrapper, f) 370 functools.update_wrapper(wrapper, f)
244 return wrapper, f 371 return wrapper, f
245 372
246 def test_default_update(self): 373 def test_default_update(self):
247 wrapper, f = self._default_update() 374 wrapper, f = self._default_update()
248 self.check_wrapper(wrapper, f) 375 self.check_wrapper(wrapper, f)
249 self.assertIs(wrapper.__wrapped__, f) 376 self.assertIs(wrapper.__wrapped__, f)
250 self.assertEqual(wrapper.__name__, 'f') 377 self.assertEqual(wrapper.__name__, 'f')
(...skipping 54 matching lines...) Expand 10 before | Expand all | Expand 10 after
305 self.assertNotIn('attr', wrapper.__dict__) 432 self.assertNotIn('attr', wrapper.__dict__)
306 self.assertEqual(wrapper.dict_attr, {}) 433 self.assertEqual(wrapper.dict_attr, {})
307 # Wrapper must have expected attributes for updating 434 # Wrapper must have expected attributes for updating
308 del wrapper.dict_attr 435 del wrapper.dict_attr
309 with self.assertRaises(AttributeError): 436 with self.assertRaises(AttributeError):
310 functools.update_wrapper(wrapper, f, assign, update) 437 functools.update_wrapper(wrapper, f, assign, update)
311 wrapper.dict_attr = 1 438 wrapper.dict_attr = 1
312 with self.assertRaises(AttributeError): 439 with self.assertRaises(AttributeError):
313 functools.update_wrapper(wrapper, f, assign, update) 440 functools.update_wrapper(wrapper, f, assign, update)
314 441
442 @support.requires_docstrings
315 @unittest.skipIf(sys.flags.optimize >= 2, 443 @unittest.skipIf(sys.flags.optimize >= 2,
316 "Docstrings are omitted with -O2 and above") 444 "Docstrings are omitted with -O2 and above")
317 def test_builtin_update(self): 445 def test_builtin_update(self):
318 # Test for bug #1576241 446 # Test for bug #1576241
319 def wrapper(): 447 def wrapper():
320 pass 448 pass
321 functools.update_wrapper(wrapper, max) 449 functools.update_wrapper(wrapper, max)
322 self.assertEqual(wrapper.__name__, 'max') 450 self.assertEqual(wrapper.__name__, 'max')
323 self.assertTrue(wrapper.__doc__.startswith('max(')) 451 self.assertTrue(wrapper.__doc__.startswith('max('))
324 self.assertEqual(wrapper.__annotations__, {}) 452 self.assertEqual(wrapper.__annotations__, {})
325 453
454
326 class TestWraps(TestUpdateWrapper): 455 class TestWraps(TestUpdateWrapper):
327 456
328 def _default_update(self): 457 def _default_update(self):
329 def f(): 458 def f():
330 """This is a test""" 459 """This is a test"""
331 pass 460 pass
332 f.attr = 'This is also a test' 461 f.attr = 'This is also a test'
462 f.__wrapped__ = "This is still a bald faced lie"
333 @functools.wraps(f) 463 @functools.wraps(f)
334 def wrapper(): 464 def wrapper():
335 pass 465 pass
336 self.check_wrapper(wrapper, f)
337 return wrapper, f 466 return wrapper, f
338 467
339 def test_default_update(self): 468 def test_default_update(self):
340 wrapper, f = self._default_update() 469 wrapper, f = self._default_update()
470 self.check_wrapper(wrapper, f)
341 self.assertEqual(wrapper.__name__, 'f') 471 self.assertEqual(wrapper.__name__, 'f')
342 self.assertEqual(wrapper.__qualname__, f.__qualname__) 472 self.assertEqual(wrapper.__qualname__, f.__qualname__)
343 self.assertEqual(wrapper.attr, 'This is also a test') 473 self.assertEqual(wrapper.attr, 'This is also a test')
344 474
345 @unittest.skipIf(sys.flags.optimize >= 2, 475 @unittest.skipIf(sys.flags.optimize >= 2,
346 "Docstrings are omitted with -O2 and above") 476 "Docstrings are omitted with -O2 and above")
347 def test_default_update_doc(self): 477 def test_default_update_doc(self):
348 wrapper, _ = self._default_update() 478 wrapper, _ = self._default_update()
349 self.assertEqual(wrapper.__doc__, 'This is a test') 479 self.assertEqual(wrapper.__doc__, 'This is a test')
350 480
(...skipping 24 matching lines...) Expand all
375 @functools.wraps(f, assign, update) 505 @functools.wraps(f, assign, update)
376 @add_dict_attr 506 @add_dict_attr
377 def wrapper(): 507 def wrapper():
378 pass 508 pass
379 self.check_wrapper(wrapper, f, assign, update) 509 self.check_wrapper(wrapper, f, assign, update)
380 self.assertEqual(wrapper.__name__, 'wrapper') 510 self.assertEqual(wrapper.__name__, 'wrapper')
381 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 511 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
382 self.assertEqual(wrapper.__doc__, None) 512 self.assertEqual(wrapper.__doc__, None)
383 self.assertEqual(wrapper.attr, 'This is a different test') 513 self.assertEqual(wrapper.attr, 'This is a different test')
384 self.assertEqual(wrapper.dict_attr, f.dict_attr) 514 self.assertEqual(wrapper.dict_attr, f.dict_attr)
515
385 516
386 class TestReduce(unittest.TestCase): 517 class TestReduce(unittest.TestCase):
387 func = functools.reduce 518 func = functools.reduce
388 519
389 def test_reduce(self): 520 def test_reduce(self):
390 class Squares: 521 class Squares:
391 def __init__(self, max): 522 def __init__(self, max):
392 self.max = max 523 self.max = max
393 self.sofar = [] 524 self.sofar = []
394 525
(...skipping 61 matching lines...) Expand 10 before | Expand all | Expand 10 after
456 self.assertEqual(self.func(add, SequenceClass(5)), 10) 587 self.assertEqual(self.func(add, SequenceClass(5)), 10)
457 self.assertEqual(self.func(add, SequenceClass(5), 42), 52) 588 self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
458 self.assertRaises(TypeError, self.func, add, SequenceClass(0)) 589 self.assertRaises(TypeError, self.func, add, SequenceClass(0))
459 self.assertEqual(self.func(add, SequenceClass(0), 42), 42) 590 self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
460 self.assertEqual(self.func(add, SequenceClass(1)), 0) 591 self.assertEqual(self.func(add, SequenceClass(1)), 0)
461 self.assertEqual(self.func(add, SequenceClass(1), 42), 42) 592 self.assertEqual(self.func(add, SequenceClass(1), 42), 42)
462 593
463 d = {"one": 1, "two": 2, "three": 3} 594 d = {"one": 1, "two": 2, "three": 3}
464 self.assertEqual(self.func(add, d), "".join(d.keys())) 595 self.assertEqual(self.func(add, d), "".join(d.keys()))
465 596
466 class TestCmpToKey(object): 597
598 class TestCmpToKey:
467 599
468 def test_cmp_to_key(self): 600 def test_cmp_to_key(self):
469 def cmp1(x, y): 601 def cmp1(x, y):
470 return (x > y) - (x < y) 602 return (x > y) - (x < y)
471 key = self.cmp_to_key(cmp1) 603 key = self.cmp_to_key(cmp1)
472 self.assertEqual(key(3), key(3)) 604 self.assertEqual(key(3), key(3))
473 self.assertGreater(key(3), key(1)) 605 self.assertGreater(key(3), key(1))
474 self.assertGreaterEqual(key(3), key(3)) 606 self.assertGreaterEqual(key(3), key(3))
475 607
476 def cmp2(x, y): 608 def cmp2(x, y):
(...skipping 10 matching lines...) Expand all
487 key = self.cmp_to_key(mycmp=cmp1) 619 key = self.cmp_to_key(mycmp=cmp1)
488 self.assertEqual(key(obj=3), key(obj=3)) 620 self.assertEqual(key(obj=3), key(obj=3))
489 self.assertGreater(key(obj=3), key(obj=1)) 621 self.assertGreater(key(obj=3), key(obj=1))
490 with self.assertRaises((TypeError, AttributeError)): 622 with self.assertRaises((TypeError, AttributeError)):
491 key(3) > 1 # rhs is not a K object 623 key(3) > 1 # rhs is not a K object
492 with self.assertRaises((TypeError, AttributeError)): 624 with self.assertRaises((TypeError, AttributeError)):
493 1 < key(3) # lhs is not a K object 625 1 < key(3) # lhs is not a K object
494 with self.assertRaises(TypeError): 626 with self.assertRaises(TypeError):
495 key = self.cmp_to_key() # too few args 627 key = self.cmp_to_key() # too few args
496 with self.assertRaises(TypeError): 628 with self.assertRaises(TypeError):
497 key = self.module.cmp_to_key(cmp1, None) # too many args 629 key = self.cmp_to_key(cmp1, None) # too many args
498 key = self.cmp_to_key(cmp1) 630 key = self.cmp_to_key(cmp1)
499 with self.assertRaises(TypeError): 631 with self.assertRaises(TypeError):
500 key() # too few args 632 key() # too few args
501 with self.assertRaises(TypeError): 633 with self.assertRaises(TypeError):
502 key(None, None) # too many args 634 key(None, None) # too many args
503 635
504 def test_bad_cmp(self): 636 def test_bad_cmp(self):
505 def cmp1(x, y): 637 def cmp1(x, y):
506 raise ZeroDivisionError 638 raise ZeroDivisionError
507 key = self.cmp_to_key(cmp1) 639 key = self.cmp_to_key(cmp1)
(...skipping 30 matching lines...) Expand all
538 [0, 1, 1, 2, 3, 4, 5, 7, 10]) 670 [0, 1, 1, 2, 3, 4, 5, 7, 10])
539 671
540 def test_hash(self): 672 def test_hash(self):
541 def mycmp(x, y): 673 def mycmp(x, y):
542 return y - x 674 return y - x
543 key = self.cmp_to_key(mycmp) 675 key = self.cmp_to_key(mycmp)
544 k = key(10) 676 k = key(10)
545 self.assertRaises(TypeError, hash, k) 677 self.assertRaises(TypeError, hash, k)
546 self.assertNotIsInstance(k, collections.Hashable) 678 self.assertNotIsInstance(k, collections.Hashable)
547 679
548 class TestCmpToKeyC(BaseTestC, TestCmpToKey): 680
549 cmp_to_key = c_functools.cmp_to_key 681 @unittest.skipUnless(c_functools, 'requires the C _functools module')
550 682 class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
551 class TestCmpToKeyPy(BaseTestPy, TestCmpToKey): 683 if c_functools:
684 cmp_to_key = c_functools.cmp_to_key
685
686
687 class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
552 cmp_to_key = staticmethod(py_functools.cmp_to_key) 688 cmp_to_key = staticmethod(py_functools.cmp_to_key)
689
553 690
554 class TestTotalOrdering(unittest.TestCase): 691 class TestTotalOrdering(unittest.TestCase):
555 692
556 def test_total_ordering_lt(self): 693 def test_total_ordering_lt(self):
557 @functools.total_ordering 694 @functools.total_ordering
558 class A: 695 class A:
559 def __init__(self, value): 696 def __init__(self, value):
560 self.value = value 697 self.value = value
561 def __lt__(self, other): 698 def __lt__(self, other):
562 return self.value < other.value 699 return self.value < other.value
563 def __eq__(self, other): 700 def __eq__(self, other):
564 return self.value == other.value 701 return self.value == other.value
565 self.assertTrue(A(1) < A(2)) 702 self.assertTrue(A(1) < A(2))
566 self.assertTrue(A(2) > A(1)) 703 self.assertTrue(A(2) > A(1))
567 self.assertTrue(A(1) <= A(2)) 704 self.assertTrue(A(1) <= A(2))
568 self.assertTrue(A(2) >= A(1)) 705 self.assertTrue(A(2) >= A(1))
569 self.assertTrue(A(2) <= A(2)) 706 self.assertTrue(A(2) <= A(2))
570 self.assertTrue(A(2) >= A(2)) 707 self.assertTrue(A(2) >= A(2))
708 self.assertFalse(A(1) > A(2))
571 709
572 def test_total_ordering_le(self): 710 def test_total_ordering_le(self):
573 @functools.total_ordering 711 @functools.total_ordering
574 class A: 712 class A:
575 def __init__(self, value): 713 def __init__(self, value):
576 self.value = value 714 self.value = value
577 def __le__(self, other): 715 def __le__(self, other):
578 return self.value <= other.value 716 return self.value <= other.value
579 def __eq__(self, other): 717 def __eq__(self, other):
580 return self.value == other.value 718 return self.value == other.value
581 self.assertTrue(A(1) < A(2)) 719 self.assertTrue(A(1) < A(2))
582 self.assertTrue(A(2) > A(1)) 720 self.assertTrue(A(2) > A(1))
583 self.assertTrue(A(1) <= A(2)) 721 self.assertTrue(A(1) <= A(2))
584 self.assertTrue(A(2) >= A(1)) 722 self.assertTrue(A(2) >= A(1))
585 self.assertTrue(A(2) <= A(2)) 723 self.assertTrue(A(2) <= A(2))
586 self.assertTrue(A(2) >= A(2)) 724 self.assertTrue(A(2) >= A(2))
725 self.assertFalse(A(1) >= A(2))
587 726
588 def test_total_ordering_gt(self): 727 def test_total_ordering_gt(self):
589 @functools.total_ordering 728 @functools.total_ordering
590 class A: 729 class A:
591 def __init__(self, value): 730 def __init__(self, value):
592 self.value = value 731 self.value = value
593 def __gt__(self, other): 732 def __gt__(self, other):
594 return self.value > other.value 733 return self.value > other.value
595 def __eq__(self, other): 734 def __eq__(self, other):
596 return self.value == other.value 735 return self.value == other.value
597 self.assertTrue(A(1) < A(2)) 736 self.assertTrue(A(1) < A(2))
598 self.assertTrue(A(2) > A(1)) 737 self.assertTrue(A(2) > A(1))
599 self.assertTrue(A(1) <= A(2)) 738 self.assertTrue(A(1) <= A(2))
600 self.assertTrue(A(2) >= A(1)) 739 self.assertTrue(A(2) >= A(1))
601 self.assertTrue(A(2) <= A(2)) 740 self.assertTrue(A(2) <= A(2))
602 self.assertTrue(A(2) >= A(2)) 741 self.assertTrue(A(2) >= A(2))
742 self.assertFalse(A(2) < A(1))
603 743
def test_total_ordering_ge(self):
    # A class that defines only __ge__ and __eq__ must get <, <= and >
    # filled in by functools.total_ordering.
    @functools.total_ordering
    class A:
        def __init__(self, value):
            self.value = value
        def __ge__(self, other):
            return self.value >= other.value
        def __eq__(self, other):
            return self.value == other.value
    lo, hi = A(1), A(2)
    self.assertTrue(lo < hi)
    self.assertTrue(hi > lo)
    self.assertTrue(lo <= hi)
    self.assertTrue(hi >= lo)
    self.assertTrue(hi <= A(2))
    self.assertTrue(hi >= A(2))
    # Negative case: the derived __le__ must not answer True blindly.
    self.assertFalse(hi <= lo)
619 760
def test_total_ordering_no_overwrite(self):
    # An int subclass already has every rich comparison; decorating it
    # must leave those inherited methods working unchanged.
    @functools.total_ordering
    class IntSub(int):
        pass
    small, big = IntSub(1), IntSub(2)
    self.assertTrue(small < big)
    self.assertTrue(big > small)
    self.assertTrue(small <= big)
    self.assertTrue(big >= small)
    self.assertTrue(big <= IntSub(2))
    self.assertTrue(big >= IntSub(2))
631 772
def test_no_operations_defined(self):
    # Decorating a class that defines none of the ordering methods is
    # rejected with ValueError.
    with self.assertRaises(ValueError):
        @functools.total_ordering
        class Unordered:
            pass
637 778
def test_type_error_when_not_implemented(self):
    # bug 10042; ensure stack overflow does not occur
    # when decorated types return NotImplemented.
    # Each class below defines __eq__ plus exactly one ordering method, and
    # that method declines (returns NotImplemented) for any operand that is
    # not an instance of its own class.
    @functools.total_ordering
    class ImplementsLessThan:
        def __init__(self, value):
            self.value = value
        def __eq__(self, other):
            return (isinstance(other, ImplementsLessThan)
                    and self.value == other.value)
        def __lt__(self, other):
            if not isinstance(other, ImplementsLessThan):
                return NotImplemented
            return self.value < other.value

    @functools.total_ordering
    class ImplementsGreaterThan:
        def __init__(self, value):
            self.value = value
        def __eq__(self, other):
            return (isinstance(other, ImplementsGreaterThan)
                    and self.value == other.value)
        def __gt__(self, other):
            if not isinstance(other, ImplementsGreaterThan):
                return NotImplemented
            return self.value > other.value

    @functools.total_ordering
    class ImplementsLessThanEqualTo:
        def __init__(self, value):
            self.value = value
        def __eq__(self, other):
            return (isinstance(other, ImplementsLessThanEqualTo)
                    and self.value == other.value)
        def __le__(self, other):
            if not isinstance(other, ImplementsLessThanEqualTo):
                return NotImplemented
            return self.value <= other.value

    @functools.total_ordering
    class ImplementsGreaterThanEqualTo:
        def __init__(self, value):
            self.value = value
        def __eq__(self, other):
            return (isinstance(other, ImplementsGreaterThanEqualTo)
                    and self.value == other.value)
        def __ge__(self, other):
            if not isinstance(other, ImplementsGreaterThanEqualTo):
                return NotImplemented
            return self.value >= other.value

    @functools.total_ordering
    class ComparatorNotImplemented:
        def __init__(self, value):
            self.value = value
        def __eq__(self, other):
            return (isinstance(other, ComparatorNotImplemented)
                    and self.value == other.value)
        def __lt__(self, other):
            return NotImplemented

    # Mixed-type comparisons that no operand can answer: every one of them
    # must surface as TypeError.
    unanswerable = [
        ("LT < 1", lambda: ImplementsLessThan(-1) < 1),
        ("LT < LE",
         lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
        ("LT < GT",
         lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
        ("LE <= LT",
         lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
        ("LE <= GE",
         lambda: (ImplementsLessThanEqualTo(3)
                  <= ImplementsGreaterThanEqualTo(3))),
        ("GT > GE",
         lambda: ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)),
        ("GT > LT",
         lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
        ("GE >= GT",
         lambda: ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)),
        ("GE >= LE",
         lambda: (ImplementsGreaterThanEqualTo(7)
                  >= ImplementsLessThanEqualTo(7))),
    ]
    for label, compare in unanswerable:
        with self.subTest(label), self.assertRaises(TypeError):
            compare()

    # Equal operands whose only ordering method declines: even though
    # __eq__ reports equality, the derived >= and <= must raise TypeError.
    equal_cases = [
        ("GE when equal", 8, lambda x, y: x >= y),
        ("LE when equal", 9, lambda x, y: x <= y),
    ]
    for label, value, compare in equal_cases:
        with self.subTest(label):
            a = ComparatorNotImplemented(value)
            b = ComparatorNotImplemented(value)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                compare(a, b)
653 885
654 class TestLRU(unittest.TestCase): 886 class TestLRU(unittest.TestCase):
655 887
656 def test_lru(self): 888 def test_lru(self):
657 def orig(x, y): 889 def orig(x, y):
658 return 3 * x + y 890 return 3 * x + y
659 f = functools.lru_cache(maxsize=20)(orig) 891 f = functools.lru_cache(maxsize=20)(orig)
660 hits, misses, maxsize, currsize = f.cache_info() 892 hits, misses, maxsize, currsize = f.cache_info()
661 self.assertEqual(maxsize, 20) 893 self.assertEqual(maxsize, 20)
662 self.assertEqual(currsize, 0) 894 self.assertEqual(currsize, 0)
(...skipping 148 matching lines...) Expand 10 before | Expand all | Expand 10 after
811 return n 1043 return n
812 return fib(n=n-1) + fib(n=n-2) 1044 return fib(n=n-1) + fib(n=n-2)
813 self.assertEqual([fib(n=number) for number in range(16)], 1045 self.assertEqual([fib(n=number) for number in range(16)],
814 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1046 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
815 self.assertEqual(fib.cache_info(), 1047 self.assertEqual(fib.cache_info(),
816 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1048 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
817 fib.cache_clear() 1049 fib.cache_clear()
818 self.assertEqual(fib.cache_info(), 1050 self.assertEqual(fib.cache_info(),
819 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1051 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
820 1052
def test_need_for_rlock(self):
    # This will deadlock on an LRU cache that uses a regular lock:
    # DoubleEq.__eq__ calls back into the cached function while the cache
    # is comparing keys.

    @functools.lru_cache(maxsize=10)
    def test_func(x):
        'Used to demonstrate a reentrant lru_cache call within a single thread'
        return x

    class DoubleEq:
        'Demonstrate a reentrant lru_cache call within a single thread'
        def __init__(self, x):
            self.x = x
        def __hash__(self):
            return self.x
        def __eq__(self, other):
            if self.x == 2:
                test_func(DoubleEq(1))
            return self.x == other.x

    for warm in (1, 2):
        test_func(DoubleEq(warm))  # Load the cache
    # The lookup hits the cached DoubleEq(2) key and triggers a re-entrant
    # __eq__ call; the cached value must still come back correctly.
    cached = test_func(DoubleEq(2))
    self.assertEqual(cached, DoubleEq(2))
1076
1077
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch.

    Covers registration (plain and decorator form), dispatch along the
    class MRO and along ABCs, ambiguity detection, and the dispatch cache.
    Also exercises the private helpers functools._compose_mro and
    functools._c3_mro directly.
    """

    def test_simple_overloads(self):
        # An implementation registered for int only affects int arguments.
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        # Dispatch walks the argument type's MRO: D(C, B) reaches the B
        # implementation before A's.
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        # register(cls) used as a decorator returns the function unchanged,
        # and dispatch() reports the chosen implementation.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: in the assert above this is not g.
        # @singledispatch returns the wrapper.

    def test_wrapping_attributes(self):
        # The wrapper keeps the decorated function's __name__ and __doc__.
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Docstrings are stripped when running with -OO.
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Dispatch works on exception classes implemented in C (_decimal).
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering, so
        # every permutation of the candidate ABCs must give the same MRO.
        c = collections
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
                                 c.Iterable, c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
        for haystack in permutations(bases):
            m = mro(c.ChainMap, haystack)
            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Sized, c.Iterable, c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            # Previously this passed a fixed list, so the permutation loop
            # never varied the input it claims to check.
            m = mro(c.defaultdict, haystack)
            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                                 object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(c.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            # Use the permutation, not the unpermuted `bases` list.
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
                                 c.defaultdict, dict, c.MutableMapping,
                                 c.Mapping, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(c.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                                 c.Sized, c.Iterable, c.Container, object])

    def test_register_abc(self):
        # Register progressively more specific ABCs and then concrete
        # types; each argument must dispatch to the most specific match.
        c = collections
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        # _c3_mro inserts relevant ABCs (explicitly registered or implied by
        # methods such as __len__/__call__) into the MRO and omits the rest.
        c = collections
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_mro_conflicts(self):
        # Interaction of explicit bases, registered ABCs and inferred ABCs,
        # including the ambiguous cases that must raise RuntimeError.
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        # Replace functools.WeakKeyDictionary with a tracing dict before
        # decorating, so every read, write and clear of g's dispatch cache
        # can be observed.
        from collections import UserDict
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()
        # Restore the real class even if an assertion below fails; otherwise
        # the tracing stub would leak into every later singledispatch use.
        self.addCleanup(setattr, functools, 'WeakKeyDictionary',
                        functools.WeakKeyDictionary)
        td = TracingDict()
        functools.WeakKeyDictionary = lambda: td
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
1551
1552
def test_main(verbose=None):
    """Run every test class of this module through support.run_unittest.

    When *verbose* is true and ``sys.gettotalrefcount`` is available, the
    suite is run five more times, with a gc.collect() after each pass.  The
    list of total reference counts, one per pass, is then printed so that
    the counts can be compared across runs.
    """
    test_classes = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestPartialMethod,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in range(5):
        support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)


if __name__ == '__main__':
    test_main(verbose=True)
LEFTRIGHT

RSS Feeds Recent Issues | This issue
This is Rietveld 894c83f36cb7+