Ensure invalidation list does not grow unboundedly

This commit is contained in:
Erik Johnston 2016-08-19 15:58:52 +01:00
parent c0d7d9d642
commit 45fd2c8942
4 changed files with 104 additions and 20 deletions

View file

@ -25,6 +25,7 @@ from synapse.util.logcontext import (
from . import DEBUG_CACHES, register_cache from . import DEBUG_CACHES, register_cache
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple
import os import os
import functools import functools
@ -210,16 +211,17 @@ class CacheDescriptor(object):
# whenever we are invalidated # whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
# Add our own `cache_context` to argument list if the wrapped function # Add temp cache_context so inspect.getcallargs doesn't explode
# has asked for one
self_context = _CacheContext(cache, None)
if self.add_cache_context: if self.add_cache_context:
kwargs["cache_context"] = self_context kwargs["cache_context"] = None
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
self_context.key = cache_key # Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext(cache, cache_key)
try: try:
cached_result_d = cache.get(cache_key, callback=invalidate_callback) cached_result_d = cache.get(cache_key, callback=invalidate_callback)
@ -414,13 +416,7 @@ class CacheListDescriptor(object):
return wrapped return wrapped
class _CacheContext(object): class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
__slots__ = ["cache", "key"]
def __init__(self, cache, key):
self.cache = cache
self.key = key
def invalidate(self): def invalidate(self):
self.cache.invalidate(self.key) self.cache.invalidate(self.key)

View file

@ -32,7 +32,7 @@ def enumerate_leaves(node, depth):
class _Node(object): class _Node(object):
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
def __init__(self, prev_node, next_node, key, value, callbacks=[]): def __init__(self, prev_node, next_node, key, value, callbacks=set()):
self.prev_node = prev_node self.prev_node = prev_node
self.next_node = next_node self.next_node = next_node
self.key = key self.key = key
@ -66,7 +66,7 @@ class LruCache(object):
return inner return inner
def add_node(key, value, callbacks=[]): def add_node(key, value, callbacks=set()):
prev_node = list_root prev_node = list_root
next_node = prev_node.next_node next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value, callbacks) node = _Node(prev_node, next_node, key, value, callbacks)
@ -94,7 +94,7 @@ class LruCache(object):
for cb in node.callbacks: for cb in node.callbacks:
cb() cb()
node.callbacks = [] node.callbacks.clear()
@synchronized @synchronized
def cache_get(key, default=None, callback=None): def cache_get(key, default=None, callback=None):
@ -102,7 +102,7 @@ class LruCache(object):
if node is not None: if node is not None:
move_node_to_front(node) move_node_to_front(node)
if callback: if callback:
node.callbacks.append(callback) node.callbacks.add(callback)
return node.value return node.value
else: else:
return default return default
@ -114,18 +114,18 @@ class LruCache(object):
if value != node.value: if value != node.value:
for cb in node.callbacks: for cb in node.callbacks:
cb() cb()
node.callbacks = [] node.callbacks.clear()
if callback: if callback:
node.callbacks.append(callback) node.callbacks.add(callback)
move_node_to_front(node) move_node_to_front(node)
node.value = value node.value = value
else: else:
if callback: if callback:
callbacks = [callback] callbacks = set([callback])
else: else:
callbacks = [] callbacks = set()
add_node(key, value, callbacks) add_node(key, value, callbacks)
if len(cache) > max_size: if len(cache) > max_size:
todelete = list_root.prev_node todelete = list_root.prev_node

View file

@ -17,6 +17,8 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached from synapse.util.caches.descriptors import Cache, cached
@ -265,3 +267,49 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 4) self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3) self.assertEquals(callcount2[0], 3)
@defer.inlineCallbacks
def test_double_get(self):
callcount = [0]
callcount2 = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
yield a.func2("foo")
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)

View file

@ -50,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.setdefault("key", 2), 1) self.assertEquals(cache.setdefault("key", 2), 1)
self.assertEquals(cache.get("key"), 1) self.assertEquals(cache.get("key"), 1)
cache["key"] = 2 # Make sure overriding works.
self.assertEquals(cache.get("key"), 2)
def test_pop(self): def test_pop(self):
cache = LruCache(1) cache = LruCache(1)
@ -84,6 +86,44 @@ class LruCacheTestCase(unittest.TestCase):
class LruCacheCallbacksTestCase(unittest.TestCase): class LruCacheCallbacksTestCase(unittest.TestCase):
def test_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", "value")
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_multi_get(self):
m = Mock()
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.get("key", callback=m)
self.assertFalse(m.called)
cache.set("key", "value2")
self.assertEquals(m.call_count, 1)
cache.set("key", "value")
self.assertEquals(m.call_count, 1)
def test_set(self): def test_set(self):
m = Mock() m = Mock()
cache = LruCache(1) cache = LruCache(1)