Move simplegeneric into functools, and document it diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst --- a/Doc/library/functools.rst +++ b/Doc/library/functools.rst @@ -57,6 +57,53 @@ 18 +.. function:: simplegeneric(default) + + This is a function decorator which makes a function into a *generic + function*. Generic functions may have multiple overloads, distinguished by + the type of the first argument. + + Adding an overload to a generic function is achieved by using the + :func:`register` attribute of the generic function. The :func:`register` + attribute is also a decorator, taking a type paramater and decorating a + function implementing the overload for that type. The register attribute + can also be used in a functional form, ``genericfn.register(type, + overload)``. This is normally only needed when registering an existing + function or a lambda form as an overload. + + Example:: + + @simplegeneric + def pprint(obj): + print str(obj) + + @pprint.register(list) + def pprint_list(lst): + if len(lst) == 0: + print "[]" + else: + print "[" + str(lst[0]) + ", ...]" + + Where there is no registered overload for the specific type of the first + argument, overloads are checked for using the method resolution order of + the type of the first argument - so, for example, subclasses of list would + be handled by pprint_list in the example above (unless a more specific + overload were defined). + + Note that it is not possible to register an overload for an :term:`abstract + base class` on a generic function. This is at least in part because there + is no defined "inheritance order" for ABC registrations, corresponding to + the MRO for class inheritance. (Generic functions are designed to *avoid* + instance checks, where ABCs are designed to make use of them, so it isn't + entirely surprising that the two mechanisms do not interact well together). + If a generic function needs to respect ABCs, the necessary + :func:`isinstance` check can be included in the default generic function + implementation. + + + .. versionadded:: 2.7 + + .. function:: update_wrapper(wrapper, wrapped[, assigned][, updated]) Update a *wrapper* function to look like the *wrapped* function. The optional diff --git a/Lib/functools.py b/Lib/functools.py --- a/Lib/functools.py +++ b/Lib/functools.py @@ -49,3 +49,48 @@ """ return partial(update_wrapper, wrapped=wrapped, assigned=assigned, updated=updated) + +# simplegeneric is a basic generic function implementation, +# dispatching on the type of the first argument. + +def simplegeneric(func): + """Decorator to make a single-dispatch generic function + + Makes a function into a generic function, which can have different + behaviours depending upon the type of its first argument. The decorated + function acts as the default implementation, and additional + implementations can be registered using the 'register' method of the + generic function. + """ + registry = {} + def wrapper(*args, **kw): + ob = args[0] + try: + cls = ob.__class__ + except AttributeError: + cls = type(ob) + try: + mro = cls.__mro__ + except AttributeError: + try: + class cls(cls, object): + pass + mro = cls.__mro__[1:] + except TypeError: + mro = object, # must be an ExtensionClass or some such :( + for t in mro: + if t in registry: + return registry[t](*args, **kw) + else: + return func(*args, **kw) + + def register(typ, func=None): + if func is None: + return lambda f: register(typ, f) + registry[typ] = func + return func + + update_wrapper(wrapper, func) + wrapper.register = register + return wrapper + diff --git a/Lib/pkgutil.py b/Lib/pkgutil.py --- a/Lib/pkgutil.py +++ b/Lib/pkgutil.py @@ -8,6 +8,7 @@ import imp import os.path from types import ModuleType +from functools import simplegeneric __all__ = [ 'get_importer', 'iter_importers', 'get_loader', 'find_loader', @@ -28,45 +29,6 @@ return marshal.load(stream) -def simplegeneric(func): - """Make a trivial single-dispatch generic function""" - registry = {} - def wrapper(*args, **kw): - ob = args[0] - try: - cls = ob.__class__ - except AttributeError: - cls = type(ob) - try: - mro = cls.__mro__ - except AttributeError: - try: - class cls(cls, object): - pass - mro = cls.__mro__[1:] - except TypeError: - mro = object, # must be an ExtensionClass or some such :( - for t in mro: - if t in registry: - return registry[t](*args, **kw) - else: - return func(*args, **kw) - try: - wrapper.__name__ = func.__name__ - except (TypeError, AttributeError): - pass # Python 2.3 doesn't allow functions to be renamed - - def register(typ, func=None): - if func is None: - return lambda f: register(typ, f) - registry[typ] = func - return func - - wrapper.__dict__ = func.__dict__ - wrapper.__doc__ = func.__doc__ - wrapper.register = register - return wrapper - def walk_packages(path=None, prefix='', onerror=None): """Yields (module_loader, name, ispkg) for all modules recursively diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -308,6 +308,64 @@ self.assertRaises(TypeError, reduce, 42, (42, 42)) +class TestSimpleGeneric(unittest.TestCase): + def test_simple_overloads(self): + @functools.simplegeneric + def g(obj): + return "base" + def int_impl(i): + return "integer" + g.register(int, int_impl) + self.assertEqual(g("str"), "base") + self.assertEqual(g(1), "integer") + self.assertEqual(g([1,2,3]), "base") + + def test_mro(self): + @functools.simplegeneric + def g(obj): + return "base" + class C(object): + pass + class D(C): + pass + def C_impl(c): + return "C" + g.register(C, C_impl) + self.assertEqual(g(C()), "C") + self.assertEqual(g(D()), "C") + + def test_classic_classes(self): + @functools.simplegeneric + def g(obj): + return "base" + class C: + pass + class D(C): + pass + def C_impl(c): + return "C" + g.register(C, C_impl) + self.assertEqual(g(C()), "C") + self.assertEqual(g(D()), "C") + + def test_register_decorator(self): + @functools.simplegeneric + def g(obj): + return "base" + @g.register(int) + def g_int(i): + return "int %s" % (i,) + self.assertEqual(g(""), "base") + self.assertEqual(g(12), "int 12") + + def test_wrapping_attributes(self): + @functools.simplegeneric + def g(obj): + "Simple test" + return "Test" + self.assertEqual(g.__name__, "g") + self.assertEqual(g.__doc__, "Simple test") + def test_main(verbose=None): @@ -319,6 +377,7 @@ TestUpdateWrapper, TestWraps, TestReduce, + TestSimpleGeneric, ) test_support.run_unittest(*test_classes)