diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index e93ff40dc0..8dba61d49f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,6 +25,7 @@ from synapse.util.logcontext import ( from . import DEBUG_CACHES, register_cache from twisted.internet import defer +from collections import namedtuple import os import functools @@ -210,16 +211,17 @@ class CacheDescriptor(object): # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) - # Add our own `cache_context` to argument list if the wrapped function - # has asked for one - self_context = _CacheContext(cache, None) + # Add temp cache_context so inspect.getcallargs doesn't explode if self.add_cache_context: - kwargs["cache_context"] = self_context + kwargs["cache_context"] = None arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) - self_context.key = cache_key + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext(cache, cache_key) try: cached_result_d = cache.get(cache_key, callback=invalidate_callback) @@ -414,13 +416,7 @@ class CacheListDescriptor(object): return wrapped -class _CacheContext(object): - __slots__ = ["cache", "key"] - - def __init__(self, cache, key): - self.cache = cache - self.key = key - +class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): def invalidate(self): self.cache.invalidate(self.key) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index a5a827b4d1..9c4c679175 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -32,7 +32,7 @@ def enumerate_leaves(node, depth): class _Node(object): __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] - def __init__(self, prev_node, next_node, key, value, callbacks=[]): + def __init__(self, prev_node, next_node, key, value, callbacks=set()): self.prev_node = prev_node self.next_node = next_node self.key = key @@ -66,7 +66,7 @@ class LruCache(object): return inner - def add_node(key, value, callbacks=[]): + def add_node(key, value, callbacks=set()): prev_node = list_root next_node = prev_node.next_node node = _Node(prev_node, next_node, key, value, callbacks) @@ -94,7 +94,7 @@ class LruCache(object): for cb in node.callbacks: cb() - node.callbacks = [] + node.callbacks.clear() @synchronized def cache_get(key, default=None, callback=None): @@ -102,7 +102,7 @@ class LruCache(object): if node is not None: move_node_to_front(node) if callback: - node.callbacks.append(callback) + node.callbacks.add(callback) return node.value else: return default @@ -114,18 +114,18 @@ class LruCache(object): if value != node.value: for cb in node.callbacks: cb() - node.callbacks = [] + node.callbacks.clear() if callback: - node.callbacks.append(callback) + node.callbacks.add(callback) move_node_to_front(node) node.value = value else: if callback: - callbacks = [callback] + callbacks = set([callback]) else: - callbacks = [] + callbacks = set() add_node(key, value, callbacks) if len(cache) > max_size: todelete = list_root.prev_node diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 4fc3639de0..ab6095564a 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -17,6 +17,8 @@ from tests import unittest from twisted.internet import defer +from mock import Mock + from synapse.util.async import ObservableDeferred from synapse.util.caches.descriptors import Cache, cached @@ -265,3 +267,49 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 4) self.assertEquals(callcount2[0], 3) + + @defer.inlineCallbacks + def test_double_get(self): + callcount = [0] + callcount2 = [0] + + class A(object): + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + a.func2.cache.cache = Mock(wraps=a.func2.cache.cache) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 1) + + yield a.func2("foo") + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 2) + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 2) + + a.func.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 3) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 3) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index bacec2f465..1eba5b535e 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -50,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.setdefault("key", 2), 1) self.assertEquals(cache.get("key"), 1) + cache["key"] = 2 # Make sure overriding works. + self.assertEquals(cache.get("key"), 2) def test_pop(self): cache = LruCache(1) @@ -84,6 +86,44 @@ class LruCacheTestCase(unittest.TestCase): class LruCacheCallbacksTestCase(unittest.TestCase): + def test_get(self): + m = Mock() + cache = LruCache(1) + + cache.set("key", "value") + self.assertFalse(m.called) + + cache.get("key", callback=m) + self.assertFalse(m.called) + + cache.get("key", "value") + self.assertFalse(m.called) + + cache.set("key", "value2") + self.assertEquals(m.call_count, 1) + + cache.set("key", "value") + self.assertEquals(m.call_count, 1) + + def test_multi_get(self): + m = Mock() + cache = LruCache(1) + + cache.set("key", "value") + self.assertFalse(m.called) + + cache.get("key", callback=m) + self.assertFalse(m.called) + + cache.get("key", callback=m) + self.assertFalse(m.called) + + cache.set("key", "value2") + self.assertEquals(m.call_count, 1) + + cache.set("key", "value") + self.assertEquals(m.call_count, 1) + def test_set(self): m = Mock() cache = LruCache(1)