diff -r 8f22e03f5f07 Lib/functools.py --- a/Lib/functools.py Tue Jun 25 08:11:22 2013 -0400 +++ b/Lib/functools.py Sun Jun 30 15:28:32 2013 +0200 @@ -365,36 +365,124 @@ ### singledispatch() - single-dispatch generic function decorator ################################################################################ -def _compose_mro(cls, haystack): - """Calculates the MRO for a given class `cls`, including relevant abstract - base classes from `haystack`. +def _c3_merge(sequences): + """Merges MROs in *sequences* to a single MRO using the C3 algorithm. + + Adapted from http://www.python.org/download/releases/2.3/mro/. + + """ + result = [] + while True: + sequences = [s for s in sequences if s] # purge empty sequences + if not sequences: + return result + for s1 in sequences: # find merge candidates among seq heads + candidate = s1[0] + for s2 in sequences: + if candidate in s2[1:]: + candidate = None + break # reject the current head, it appears later + else: + break + if not candidate: + raise RuntimeError("Inconsistent hierarchy") + result.append(candidate) + # remove the chosen candidate + for seq in sequences: + if seq[0] == candidate: + del seq[0] + +def _c3_mro(cls, abcs=None): + """Computes the method resolution order using extended C3 linearization. + + If no *abcs* are given, the algorithm works exactly like the built-in C3 + linearization used for method resolution. + + If given, *abcs* is a list of abstract base classes that should be inserted + to the resulting MRO. The algorithm inserts ABCs directly where their + functionality is introduced, e.g. issubclass(cls, abc) returns True for the + class itself but returns False for all its direct subclasses. Implicitly + registered ABCs for a given class are inserted directly after the last ABC + explicitly listed in the MRO of said class. Unrelated ABCs are ignored and + don't end up in the result. + + """ + abcs = list(abcs) if abcs else [] + explicit_bases = list(cls.__bases__) + abstract_bases = [] + other_bases = [] + # Move all explicit bases after the last explicit ABC to other_bases. + while explicit_bases: + if hasattr(explicit_bases[-1], '__abstractmethods__'): + break + other_bases.insert(0, explicit_bases.pop()) + for base in abcs: + if issubclass(cls, base) and not any( + issubclass(b, base) for b in cls.__bases__ + ): + # If `cls` is the class that introduces behaviour described by + # an ABC `base`, insert the said ABC to its MRO. + abstract_bases.append(base) + for base in abstract_bases: + abcs.remove(base) + explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases] + abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases] + other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases] + return _c3_merge( + [[cls]] + + explicit_c3_mros + abstract_c3_mros + other_c3_mros + + [explicit_bases] + [abstract_bases] + [other_bases] + ) + +def _compose_mro(cls, types): + """Calculates the method resolution order for a given class *cls*. + + Includes relevant abstract base classes (with their respective bases) from + the *types* iterable. Uses a modified C3 linearization algorithm. """ bases = set(cls.__mro__) - mro = list(cls.__mro__) - for needle in haystack: - if (needle in bases or not hasattr(needle, '__mro__') - or not issubclass(cls, needle)): - continue # either present in the __mro__ already or unrelated - for index, base in enumerate(mro): - if not issubclass(base, needle): - break - if base in bases and not issubclass(needle, base): - # Conflict resolution: put classes present in __mro__ and their - # subclasses first. See test_mro_conflicts() in test_functools.py - # for examples. - index += 1 - mro.insert(index, needle) - return mro + # Remove entries which are already present in the __mro__ or unrelated. + def is_related(typ): + return (typ not in bases and hasattr(typ, '__mro__') + and issubclass(cls, typ)) + types = [n for n in types if is_related(n)] + # Remove entries which are strict bases of other entries (they will end up + # in the MRO anyway. + def is_strict_base(typ): + for other in types: + if typ != other and typ in other.__mro__: + return True + return False + types = [n for n in types if not is_strict_base(n)] + # Subclasses of the ABCs in `types` which are also implemented by + # `cls` can be used to stabilize ABC ordering. + type_set = set(types) + mro = [] + for typ in types: + found = [] + for sub in typ.__subclasses__(): + if sub not in bases and issubclass(cls, sub): + found.append([s for s in sub.__mro__ if s in type_set]) + if not found: + mro.append(typ) + continue + # Favor subclasses with the biggest number of useful bases + found.sort(key=len, reverse=True) + for sub in found: + for subcls in sub: + if subcls not in mro: + mro.append(subcls) + return _c3_mro(cls, abcs=mro) def _find_impl(cls, registry): - """Returns the best matching implementation for the given class `cls` in - `registry`. Where there is no registered implementation for a specific - type, its method resolution order is used to find a more generic - implementation. + """Returns the best matching implementation from *registry* for type *cls*. - Note: if `registry` does not contain an implementation for the base - `object` type, this function may return None. + Where there is no registered implementation for a specific type, its method + resolution order is used to find a more generic implementation. + + Note: if *registry* does not contain an implementation for the base + *object* type, this function may return None. """ mro = _compose_mro(cls, registry.keys()) @@ -402,7 +490,7 @@ for t in mro: if match is not None: # If `match` is an ABC but there is another unrelated, equally - # matching ABC. Refuse the temptation to guess. + # matching ABC, refuse the temptation to guess. if (t in registry and not issubclass(match, t) and match not in cls.__mro__): raise RuntimeError("Ambiguous dispatch: {} or {}".format( @@ -418,19 +506,19 @@ Transforms a function into a generic function, which can have different behaviours depending upon the type of its first argument. The decorated function acts as the default implementation, and additional - implementations can be registered using the 'register()' attribute of - the generic function. + implementations can be registered using the register() attribute of the + generic function. """ registry = {} dispatch_cache = WeakKeyDictionary() cache_token = None - def dispatch(typ): - """generic_func.dispatch(type) -> + def dispatch(cls): + """generic_func.dispatch(cls) -> Runs the dispatch algorithm to return the best available implementation - for the given `type` registered on `generic_func`. + for the given *cls* registered on *generic_func*. """ nonlocal cache_token @@ -440,26 +528,26 @@ dispatch_cache.clear() cache_token = current_token try: - impl = dispatch_cache[typ] + impl = dispatch_cache[cls] except KeyError: try: - impl = registry[typ] + impl = registry[cls] except KeyError: - impl = _find_impl(typ, registry) - dispatch_cache[typ] = impl + impl = _find_impl(cls, registry) + dispatch_cache[cls] = impl return impl - def register(typ, func=None): - """generic_func.register(type, func) -> func + def register(cls, func=None): + """generic_func.register(cls, func) -> func - Registers a new implementation for the given `type` on a `generic_func`. + Registers a new implementation for the given *cls* on a *generic_func*. """ nonlocal cache_token if func is None: - return lambda f: register(typ, f) - registry[typ] = func - if cache_token is None and hasattr(typ, '__abstractmethods__'): + return lambda f: register(cls, f) + registry[cls] = func + if cache_token is None and hasattr(cls, '__abstractmethods__'): cache_token = get_cache_token() dispatch_cache.clear() return func diff -r 8f22e03f5f07 Lib/test/test_functools.py --- a/Lib/test/test_functools.py Tue Jun 25 08:11:22 2013 -0400 +++ b/Lib/test/test_functools.py Sun Jun 30 15:28:32 2013 +0200 @@ -929,22 +929,55 @@ self.assertEqual(g(rnd), ("Number got rounded",)) def test_compose_mro(self): + # None of the examples in this test depend on haystack ordering. c = collections mro = functools._compose_mro bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] for haystack in permutations(bases): m = mro(dict, haystack) - self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object]) + self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized, + c.Iterable, c.Container, object]) bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict] for haystack in permutations(bases): m = mro(c.ChainMap, haystack) self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping, c.Sized, c.Iterable, c.Container, object]) - # Note: The MRO order below depends on haystack ordering. - m = mro(c.defaultdict, [c.Sized, c.Container, str]) - self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object]) - m = mro(c.defaultdict, [c.Container, c.Sized, str]) - self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object]) + + # If there's a generic function with implementations registered for + # both Sized and Container, passing a defaultdict to it results in an + # ambiguous dispatch which will cause a RuntimeError (see + # test_mro_conflicts). + bases = [c.Container, c.Sized, str] + for haystack in permutations(bases): + m = mro(c.defaultdict, [c.Sized, c.Container, str]) + self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, + object]) + + # MutableSequence below is registered directly on D. In other words, it + # preceeds MutableMapping which means single dispatch will always + # choose MutableSequence here. + class D(c.defaultdict): + pass + c.MutableSequence.register(D) + bases = [c.MutableSequence, c.MutableMapping] + for haystack in permutations(bases): + m = mro(D, bases) + self.assertEqual(m, [D, c.MutableSequence, c.Sequence, + c.defaultdict, dict, c.MutableMapping, + c.Mapping, c.Sized, c.Iterable, c.Container, + object]) + + # Container and Callable are registered on different base classes and + # a generic function supporting both should always pick the Callable + # implementation if a C instance is passed. + class C(c.defaultdict): + def __call__(self): + pass + bases = [c.Sized, c.Callable, c.Container, c.Mapping] + for haystack in permutations(bases): + m = mro(C, haystack) + self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping, + c.Sized, c.Iterable, c.Container, object]) def test_register_abc(self): c = collections @@ -1040,17 +1073,37 @@ self.assertEqual(g(f), "frozen-set") self.assertEqual(g(t), "tuple") + def test_c3_abc(self): + c = collections + mro = functools._c3_mro + class A(object): + pass + class B(A): + def __len__(self): + return 0 # implies Sized + @c.Container.register + class C(object): + pass + class D(object): + pass # unrelated + class X(D, C, B): + def __call__(self): + pass # implies Callable + expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] + for abcs in permutations([c.Sized, c.Callable, c.Container]): + self.assertEqual(mro(X, abcs=abcs), expected) + # unrelated ABCs don't appear in the resulting MRO + many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] + self.assertEqual(mro(X, abcs=many_abcs), expected) + def test_mro_conflicts(self): c = collections - @functools.singledispatch def g(arg): return "base" - class O(c.Sized): def __len__(self): return 0 - o = O() self.assertEqual(g(o), "base") g.register(c.Iterable, lambda arg: "iterable") @@ -1062,35 +1115,72 @@ self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ c.Container.register(O) self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ - + c.Set.register(O) + self.assertEqual(g(o), "set") # because c.Set is a subclass of + # c.Sized and c.Container class P: pass - p = P() self.assertEqual(g(p), "base") c.Iterable.register(P) self.assertEqual(g(p), "iterable") c.Container.register(P) - with self.assertRaises(RuntimeError) as re: + with self.assertRaises(RuntimeError) as re_one: g(p) - self.assertEqual( - str(re), - ("Ambiguous dispatch: " - "or "), - ) - + self.assertEqual( + str(re_one.exception), + ("Ambiguous dispatch: " + "or "), + ) class Q(c.Sized): def __len__(self): return 0 - q = Q() self.assertEqual(g(q), "sized") c.Iterable.register(Q) self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ c.Set.register(Q) self.assertEqual(g(q), "set") # because c.Set is a subclass of - # c.Sized which is explicitly in - # __mro__ + # c.Sized and c.Iterable + @functools.singledispatch + def h(arg): + return "base" + @h.register(c.Sized) + def _(arg): + return "sized" + @h.register(c.Container) + def _(arg): + return "container" + with self.assertRaises(RuntimeError) as re_two: + h(c.defaultdict(lambda: 0)) + self.assertEqual( + str(re_two.exception), + ("Ambiguous dispatch: " + "or "), + ) + class R(c.defaultdict): + pass + c.MutableSequence.register(R) + @functools.singledispatch + def i(arg): + return "base" + @i.register(c.MutableMapping) + def _(arg): + return "mapping" + @i.register(c.MutableSequence) + def _(arg): + return "sequence" + r = R() + self.assertEqual(i(r), "sequence") + class S: + pass + class T(S, c.Sized): + def __len__(self): + return 0 + t = T() + self.assertEqual(h(t), "sized") + c.Container.register(T) + self.assertEqual(h(t), "sized") # because it's explicitly in the MRO def test_cache_invalidation(self): from collections import UserDict