diff -r d6ad9d7468f7 Lib/functools.py --- a/Lib/functools.py Sat Jun 08 16:52:29 2013 +0100 +++ b/Lib/functools.py Thu Jun 20 21:10:25 2013 +0200 @@ -365,27 +365,107 @@ ### singledispatch() - single-dispatch generic function decorator ################################################################################ +def _c3_merge(sequences): + """Merges MROs in `sequences` to a single MRO using the C3 algorithm. + + Adapted from http://www.python.org/download/releases/2.3/mro/. + + """ + result = [] + while True: + sequences = [s for s in sequences if s] # purge empty sequences + if not sequences: + return result + for s1 in sequences: # find merge candidates among seq heads + candidate = s1[0] + for s2 in sequences: + if candidate in s2[1:]: + candidate = None + break # reject the current head, it appears later + else: + break + if not candidate: + raise RuntimeError("Inconsistent hierarchy") + result.append(candidate) + # remove the chosen candidate + for seq in sequences: + if seq[0] == candidate: + del seq[0] + +def _c3_mro(cls, abcs=None): + """Computes the method resolution order using extended C3 linearization. + + If no `abcs` are given, the algorithm works exactly like the built-in C3 + linearization used for method resolution. + + If given, `abcs` is a list of abstract base classes that should + be inserted to the resulting MRO. The algorithm inserts ABCs directly where + their functionality is introduced, e.g. `issubclass(cls, abc)` returns + `True` for the class itself but returns False for all its direct + subclasses. Unrelated ABCs are ignored and don't end up in the result. + + """ + abcs = list(abcs) if abcs else [] + bases = [] + for base in abcs: + if issubclass(cls, base) and not any( + issubclass(b, base) for b in cls.__bases__ + ): + # If `cls` is the class that introduces behaviour described by + # an ABC `base`, insert the said ABC to its MRO. + bases.append(base) + for base in bases: + abcs.remove(base) + abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in bases] + concrete_c3_mros = [_c3_mro(base, abcs=abcs) for base in cls.__bases__] + abstract_bases = [bases] + concrete_bases = [list(cls.__bases__)] + return _c3_merge( + [[cls]] + + abstract_c3_mros + concrete_c3_mros + + abstract_bases + concrete_bases + ) + def _compose_mro(cls, haystack): """Calculates the MRO for a given class `cls`, including relevant abstract - base classes from `haystack`. + base classes (with their respective bases) from `haystack`. Uses a modified + C3 linearization algorithm. """ bases = set(cls.__mro__) - mro = list(cls.__mro__) + def is_related(needle): + """Remove entries which are already present in the __mro__ or + unrelated.""" + return (needle not in bases and hasattr(needle, '__mro__') + and issubclass(cls, needle)) + haystack = [n for n in haystack if is_related(n)] + def is_strict_base(needle): + """Remove entries which are strict bases of other entries (they will + end up in the MRO anyway.""" + for other in haystack: + if needle != other and needle in other.__mro__: + return True + return False + haystack = [n for n in haystack if not is_strict_base(n)] + # Subclasses of the ABCs in `haystack` which are also implemented by + # `cls` can be used to stabilize ABC ordering. + hayset = set(haystack) + mro = [] for needle in haystack: - if (needle in bases or not hasattr(needle, '__mro__') - or not issubclass(cls, needle)): - continue # either present in the __mro__ already or unrelated - for index, base in enumerate(mro): - if not issubclass(base, needle): - break - if base in bases and not issubclass(needle, base): - # Conflict resolution: put classes present in __mro__ and their - # subclasses first. See test_mro_conflicts() in test_functools.py - # for examples. - index += 1 - mro.insert(index, needle) - return mro + found = [] + for sub in needle.__subclasses__(): + if sub not in bases and issubclass(cls, sub): + found.append([s for s in sub.__mro__ if s in hayset]) + if not found: + mro.append(needle) + continue + # Favor subclasses with the biggest number of useful bases + found.sort(key=len, reverse=True) + for sub in found: + for entry in sub: + if entry not in mro: + mro.append(entry) + return _c3_mro(cls, abcs=mro) def _find_impl(cls, registry): """Returns the best matching implementation for the given class `cls` in diff -r d6ad9d7468f7 Lib/test/test_functools.py --- a/Lib/test/test_functools.py Sat Jun 08 16:52:29 2013 +0100 +++ b/Lib/test/test_functools.py Thu Jun 20 21:10:25 2013 +0200 @@ -929,22 +929,53 @@ self.assertEqual(g(rnd), ("Number got rounded",)) def test_compose_mro(self): + # None of the examples below depend on haystack ordering. c = collections mro = functools._compose_mro bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] for haystack in permutations(bases): m = mro(dict, haystack) - self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object]) + self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized, + c.Iterable, c.Container, object]) bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict] for haystack in permutations(bases): m = mro(c.ChainMap, haystack) self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping, c.Sized, c.Iterable, c.Container, object]) - # Note: The MRO order below depends on haystack ordering. - m = mro(c.defaultdict, [c.Sized, c.Container, str]) - self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object]) - m = mro(c.defaultdict, [c.Container, c.Sized, str]) - self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object]) + # Note: For the MRO below, if there's a generic function with + # implementations registered for both Sized and Container, + # passing a defaultdict to it results in an ambiguous dispatch + # which will cause a RuntimeError (see test_mro_conflicts). + bases = [c.Container, c.Sized, str] + for haystack in permutations(bases): + m = mro(c.defaultdict, [c.Sized, c.Container, str]) + self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, + object]) + # Note: In the example below MutableSequence is registered directly + # on D. In other words, it preceeds MutableMapping which means + # single dispatch will always choose MutableSequence here. + class D(c.defaultdict): + pass + c.MutableSequence.register(D) + bases = [c.MutableSequence, c.MutableMapping] + for haystack in permutations(bases): + m = mro(D, bases) + self.assertEqual(m, [D, c.MutableSequence, c.Sequence, + c.defaultdict, dict, c.MutableMapping, + c.Mapping, c.Sized, c.Iterable, c.Container, + object]) + # The MRO below should not depend on haystack + # ordering. Container and Callable are registered on different base + # classes and a generic function supporting both should always pick + # the Callable implementation if a C instance is passed. + class C(c.defaultdict): + def __call__(self): + pass + bases = [c.Sized, c.Callable, c.Container, c.Mapping] + for haystack in permutations(bases): + m = mro(C, haystack) + self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping, + c.Sized, c.Iterable, c.Container, object]) def test_register_abc(self): c = collections @@ -1040,17 +1071,41 @@ self.assertEqual(g(f), "frozen-set") self.assertEqual(g(t), "tuple") + def test_c3_abc(self): + c = collections + class A(object): + pass + class B(A): + def __len__(self): + return 0 # implies Sized + @c.Container.register + class C(object): + pass + class D(object): + pass # unrelated + class X(D, C, B): + def __call__(self): + pass # implies Callable + for abcs in permutations([c.Sized, c.Callable, c.Container]): + self.assertEqual( + functools._c3_mro(X, abcs=abcs), + [X, c.Callable, D, C, c.Container, B, c.Sized, A, object], + ) + # unrelated ABCs don't appear in the resulting MRO + self.assertEqual( + functools._c3_mro(X, abcs=[c.Mapping, c.Sized, c.Callable, + c.Container, c.Iterable]), + [X, c.Callable, D, C, c.Container, B, c.Sized, A, object], + ) + def test_mro_conflicts(self): c = collections - @functools.singledispatch def g(arg): return "base" - class O(c.Sized): def __len__(self): return 0 - o = O() self.assertEqual(g(o), "base") g.register(c.Iterable, lambda arg: "iterable") @@ -1058,39 +1113,77 @@ g.register(c.Sized, lambda arg: "sized") g.register(c.Set, lambda arg: "set") self.assertEqual(g(o), "sized") - c.Iterable.register(O) - self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ c.Container.register(O) - self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ - + with self.assertRaises(RuntimeError) as re1: + g(o) + self.assertEqual( + str(re1.exception), + ("Ambiguous dispatch: " + "or "), + ) + c.Set.register(O) + self.assertEqual(g(o), "set") # because c.Set is a subclass of + # c.Sized and c.Container class P: pass - p = P() self.assertEqual(g(p), "base") c.Iterable.register(P) self.assertEqual(g(p), "iterable") c.Container.register(P) - with self.assertRaises(RuntimeError) as re: + with self.assertRaises(RuntimeError) as re1: g(p) - self.assertEqual( - str(re), - ("Ambiguous dispatch: " - "or "), - ) - + self.assertEqual( + str(re1.exception), + ("Ambiguous dispatch: " + "or "), + ) class Q(c.Sized): def __len__(self): return 0 - q = Q() self.assertEqual(g(q), "sized") c.Iterable.register(Q) - self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ + with self.assertRaises(RuntimeError) as re2: + g(q) + self.assertEqual( + str(re2.exception), + ("Ambiguous dispatch: " + "or "), + ) c.Set.register(Q) self.assertEqual(g(q), "set") # because c.Set is a subclass of - # c.Sized which is explicitly in - # __mro__ + # c.Sized and c.Iterable + @functools.singledispatch + def h(arg): + return "base" + @h.register(c.Sized) + def _(arg): + return "sized" + @h.register(c.Container) + def _(arg): + return "container" + with self.assertRaises(RuntimeError) as re3: + h(c.defaultdict(lambda: 0)) + self.assertEqual( + str(re3.exception), + ("Ambiguous dispatch: " + "or "), + ) + class R(c.defaultdict): + pass + c.MutableSequence.register(R) + @functools.singledispatch + def i(arg): + return "base" + @i.register(c.MutableMapping) + def _(arg): + return "mapping" + @i.register(c.MutableSequence) + def _(arg): + return "sequence" + r = R() + self.assertEqual(i(r), "sequence") def test_cache_invalidation(self): from collections import UserDict