diff --git a/Lib/functools.py b/Lib/functools.py index 214523c..92276bd 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -388,7 +388,7 @@ def _make_key(args, kwds, typed, return key[0] return _HashedSeq(key) -def lru_cache(maxsize=128, typed=False): +def lru_cache(maxsize=128, typed=False, cache=None): """Least-recently-used cache decorator. If *maxsize* is set to None, the LRU features are disabled and the cache @@ -420,18 +420,20 @@ def lru_cache(maxsize=128, typed=False): raise TypeError('Expected maxsize to be an integer or None') def decorating_function(user_function): - wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo) + wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo, cache) return update_wrapper(wrapper, user_function) return decorating_function -def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo): +def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo, cache): # Constants shared by all lru cache instances: sentinel = object() # unique object used to signal cache misses make_key = _make_key # build a key from the function arguments PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields - cache = {} + if cache is None: + cache = {} + hits = misses = 0 full = False cache_get = cache.get # bound method to lookup a key or return None diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index cf0b95d..cfb0e4b 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1113,6 +1113,25 @@ class TestLRU: for attr in self.module.WRAPPER_ASSIGNMENTS: self.assertEqual(getattr(g, attr), getattr(f, attr)) + def test_lru_cache_with_specified_cache(self): + cache = {} + + @self.module.lru_cache(maxsize=None, cache=cache) + def f(n): + return n + + data = (42, "test", (55, 66)) + + for i in data: + f(i) + + # modify cache data to make sure it use our cachedict + for i in cache.keys(): + cache[i] *= 2 + + for i in data: + self.assertEqual(f(i), i*2) + @unittest.skipUnless(threading, 'This test requires threading.') def test_lru_cache_threaded(self): n, m = 5, 11 diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c index 035d3d9..b4d042e 100644 --- a/Modules/_functoolsmodule.c +++ b/Modules/_functoolsmodule.c @@ -1,4 +1,3 @@ - #include "Python.h" #include "structmember.h" @@ -913,17 +912,17 @@ bounded_lru_cache_wrapper(lru_cache_object *self, PyObject *args, PyObject *kwds static PyObject * lru_cache_new(PyTypeObject *type, PyObject *args, PyObject *kw) { - PyObject *func, *maxsize_O, *cache_info_type, *cachedict; + PyObject *func, *maxsize_O, *cache_info_type, *cachedict = NULL; int typed; lru_cache_object *obj; Py_ssize_t maxsize; PyObject *(*wrapper)(lru_cache_object *, PyObject *, PyObject *); static char *keywords[] = {"user_function", "maxsize", "typed", - "cache_info_type", NULL}; + "cache_info_type", "cachedict", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kw, "OOpO:lru_cache", keywords, + if (!PyArg_ParseTupleAndKeywords(args, kw, "OOpOO:lru_cache", keywords, &func, &maxsize_O, &typed, - &cache_info_type)) { + &cache_info_type, &cachedict)) { return NULL; } @@ -951,8 +950,13 @@ lru_cache_new(PyTypeObject *type, PyObject *args, PyObject *kw) return NULL; } - if (!(cachedict = PyDict_New())) - return NULL; + if (cachedict == NULL || cachedict == Py_None) { + if (!(cachedict = PyDict_New())) + return NULL; + } else { + Py_INCREF(cachedict); + } + obj = (lru_cache_object *)type->tp_alloc(type, 0); if (obj == NULL) {