diff -r 9b084337f13e Doc/library/functools.rst --- a/Doc/library/functools.rst Sun Oct 27 12:38:59 2013 +0200 +++ b/Doc/library/functools.rst Tue Oct 29 01:59:48 2013 +0200 @@ -194,6 +194,23 @@ 18 +.. class:: partialmethod(func, *args, **keywords) + + Return a new :class:`partialmethod` object which behaves like :class:`partial` except that it supports methods. When *func* is a function and retrieved via attribute lookup, a bound method is returned (upon call, the *self* argument will be prepended before *args* and *keywords*). + + *func* can also be a descriptor such as `classmethod`, `staticmethod`, `abstractmethod` or another instance of :class:`partialmethod`. In that case the call is delegated to the descriptor. + + Example:: + + class Cell(object): + def set_state(self, state): + self._state = state + set_alive = partialmethod(set_state, True) + set_dead = partialmethod(set_state, False) + + .. versionadded:: 3.4 + + .. function:: reduce(function, iterable[, initializer]) Apply *function* of two arguments cumulatively to the items of *sequence*, from @@ -431,4 +448,3 @@ are not created automatically. Also, :class:`partial` objects defined in classes behave like static methods and do not transform into bound methods during instance attribute look-up. - diff -r 9b084337f13e Lib/functools.py --- a/Lib/functools.py Sun Oct 27 12:38:59 2013 +0200 +++ b/Lib/functools.py Tue Oct 29 01:59:48 2013 +0200 @@ -638,3 +638,68 @@ wrapper._clear_cache = dispatch_cache.clear update_wrapper(wrapper, func) return wrapper + +class partialmethod(object): + """A 'partial' implemenation for methods. + + In contrast to partial, partialmethod is a descriptor (implements __get__) and can bind + arguments to methods and other descriptors such as classsmethod and staticmethod. + """ + + def __init__(self, func, *args, **keywords): + if not callable(func) and not hasattr(func, "__get__"): + raise TypeError("the first argument must be a callable or a descriptor") + + # func could be a descriptor like classmethod which isn't callable, + # so we can't inherit from partial (it verifies func is callable) + if isinstance(func, partialmethod): + # flattening is mandatory in order to place cls/self before all other arguments + # it's also more efficient since only one function will be called + self.func = func.func + self.args = func.args + args + self.keywords = {} + self.keywords.update(func.keywords) + self.keywords.update(keywords) + else: + self.func = func + self.args = args + self.keywords = keywords + + def __repr__(self): + args = ", ".join(map(repr, self.args)) + keywords = ", ".join("{}={!r}".format(k, v) for k, v in self.keywords.items()) + format_string = "{module}.{cls}({func}, {args}, {keywords})" + return format_string.format(module=self.__class__.__module__, + cls=self.__class__.__name__, + func=self.func, + args=args, + keywords=keywords) + + def _make_unbound_method(self): + def _method(*args, **keywords): + call_keywords = self.keywords.copy() + call_keywords.update(keywords) + cls_or_self, *rest = args + call_args = (cls_or_self,) + self.args + tuple(rest) + return self.func(*call_args, **call_keywords) + return _method + + def __get__(self, obj, cls): + get = getattr(self.func, "__get__", None) + if get is None: + if obj is None: + result = self._make_unbound_method() + else: + result = partial(self._make_unbound_method(), obj) # returns a bound method + else: + callable = get(obj, cls) + if callable is self.func: + result = self._make_unbound_method() + else: + result = partial(callable, *self.args, **self.keywords) + result.__isabstractmethod__ = self.__isabstractmethod__ + return result + + @property + def __isabstractmethod__(self): + return getattr(self.func, "__isabstractmethod__", False) diff -r 9b084337f13e Lib/test/test_functools.py --- a/Lib/test/test_functools.py Sun Oct 27 12:38:59 2013 +0200 +++ b/Lib/test/test_functools.py Tue Oct 29 01:59:48 2013 +0200 @@ -1,3 +1,4 @@ +import abc import collections from itertools import permutations import pickle @@ -1427,6 +1428,101 @@ self.assertEqual(len(td), 0) functools.WeakKeyDictionary = _orig_wkd +class TestPartialMethod(unittest.TestCase): + + class A(object): + nothing = functools.partialmethod(capture) + positional = functools.partialmethod(capture, 1) + keywords = functools.partialmethod(capture, a=2) + both = functools.partialmethod(capture, 3, b=4) + + nested = functools.partialmethod(positional, 5) + + over_partial = functools.partialmethod(functools.partial(capture, c=6), 7) + + static = functools.partialmethod(staticmethod(capture), 8) + cls = functools.partialmethod(classmethod(capture), d=9) + + a = A() + + def test_arg_combinations(self): + self.assertEqual(self.a.nothing(), ((self.a,), {})) + self.assertEqual(self.a.nothing(5), ((self.a, 5), {})) + self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6})) + self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6})) + + self.assertEqual(self.a.positional(), ((self.a, 1), {})) + self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {})) + self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6})) + self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6})) + + self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2})) + self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2})) + self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6})) + self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6})) + + self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4})) + self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4})) + self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6})) + self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) + + self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) + + def test_nested(self): + self.assertEqual(self.a.nested(), ((self.a, 1, 5), {})) + self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {})) + self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7})) + self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7})) + + self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7})) + + def test_over_partial(self): + self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6})) + self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6})) + self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8})) + self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) + + self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) + + def test_descriptors(self): + for obj in [self.A, self.a]: + self.assertEqual(obj.static(), ((8,), {})) + self.assertEqual(obj.static(5), ((8, 5), {})) + self.assertEqual(obj.static(d=8), ((8,), {'d': 8})) + self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8})) + + self.assertEqual(obj.cls(), ((self.A,), {'d': 9})) + self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9})) + self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9})) + self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9})) + + def test_overriding_keywords(self): + self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3})) + + def test_invalid_args(self): + with self.assertRaises(TypeError): + class B(object): + method = functools.partialmethod(None, 1) + + def test_repr(self): + self.assertEqual(repr(vars(self.A)['both']), + 'functools.partialmethod({}, 3, b=4)'.format(capture)) + + def test_abstract(self): + class Abstract(object): + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def add(self, x, y): + pass + + add5 = functools.partialmethod(add, 5) + + self.assertTrue(Abstract.add.__isabstractmethod__) + self.assertTrue(Abstract.add5.__isabstractmethod__) + + for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]: + self.assertFalse(getattr(func, '__isabstractmethod__', False)) def test_main(verbose=None): test_classes = ( @@ -1441,6 +1537,7 @@ TestReduce, TestLRU, TestSingleDispatch, + TestPartialMethod, ) support.run_unittest(*test_classes)