diff --git a/Doc/library/inspect.rst b/Doc/library/inspect.rst --- a/Doc/library/inspect.rst +++ b/Doc/library/inspect.rst @@ -795,6 +795,16 @@ .. versionadded:: 3.3 +.. function:: unwrap(func) + + Get the object wrapped by *func*. It follows the chain of :attr:`__wrapped__` + attributes returning the last object in the chain. + + :exc:`ValueError` is raised if a cycle is encountered. + + .. versionadded:: 3.4 + + .. _inspect-stack: The interpreter stack diff --git a/Lib/inspect.py b/Lib/inspect.py --- a/Lib/inspect.py +++ b/Lib/inspect.py @@ -361,6 +361,25 @@ "Return tuple of base classes (including cls) in method resolution order." return cls.__mro__ +# -------------------------------------------------------- function helpers + +def unwrap(func): + """Get the object wrapped by 'func'. + + Follow the chain of __wrapped__ attributes, + and return the last object in the chain. + If there is a cycle, raise a ValueError. + """ + f = func # remember the original func for error reporting + memo = {id(func)} + while hasattr(func, '__wrapped__'): + func = func.__wrapped__ + if id(func) in memo: + raise ValueError('wrapper loop when unwrapping {!r}'.format(f)) + else: + memo.add(id(func)) + return func + # -------------------------------------------------- source code extraction def indentsize(line): """Return the indent size, in spaces, at the start of a line of text.""" @@ -1352,13 +1371,8 @@ if sig is not None: return sig - try: - # Was this function wrapped by a decorator? - wrapped = obj.__wrapped__ - except AttributeError: - pass - else: - return signature(wrapped) + # Was this function wrapped by a decorator? + obj = unwrap(obj) if isinstance(obj, types.FunctionType): return Signature.from_function(obj) diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py --- a/Lib/test/test_inspect.py +++ b/Lib/test/test_inspect.py @@ -8,6 +8,7 @@ import collections import os import shutil +import functools from os.path import normcase from test.support import run_unittest, TESTFN, DirsOnSysPath @@ -1827,6 +1828,10 @@ self.assertEqual(self.signature(Wrapped), ((('a', ..., ..., "positional_or_keyword"),), ...)) + # wrapper loop: + Wrapped.__wrapped__ = Wrapped + with self.assertRaisesRegex(ValueError, 'wrapper loop'): + self.signature(Wrapped) def test_signature_on_lambdas(self): self.assertEqual(self.signature((lambda a=10: a)), @@ -2268,6 +2273,37 @@ self.assertNotEqual(ba, ba4) +class TestUnwrap(unittest.TestCase): + + def test_unwrap(self): + def func(a, b): + return a + b + wrapper = functools.lru_cache(maxsize=20)(func) + self.assertIs(inspect.unwrap(wrapper), func) + + def test_cycle(self): + def func1(): pass + func1.__wrapped__ = func1 + with self.assertRaisesRegex(ValueError, 'wrapper loop'): + inspect.unwrap(func1) + + def func2(): pass + func2.__wrapped__ = func1 + func1.__wrapped__ = func2 + with self.assertRaisesRegex(ValueError, 'wrapper loop'): + inspect.unwrap(func1) + with self.assertRaisesRegex(ValueError, 'wrapper loop'): + inspect.unwrap(func2) + + def test_unhashable(self): + def func(): pass + func.__wrapped__ = None + class C: + __hash__ = None + __wrapped__ = func + self.assertIsNone(inspect.unwrap(C())) + + def test_main(): run_unittest( TestDecorators, TestRetrievingSourceCode, TestOneliners, TestBuggyCases, @@ -2275,7 +2311,7 @@ TestGetcallargsFunctions, TestGetcallargsMethods, TestGetcallargsUnboundMethods, TestGetattrStatic, TestGetGeneratorState, TestNoEOL, TestSignatureObject, TestSignatureBind, TestParameterObject, - TestBoundArguments, TestGetClosureVars + TestBoundArguments, TestGetClosureVars, TestUnwrap ) if __name__ == "__main__":