Merge pull request #178 from matrix-org/erikj/cache_state_groups

Add cache to get_state_groups.
This commit is contained in:
Erik Johnston 2015-06-03 17:20:33 +01:00
commit 13ed3b9985
4 changed files with 93 additions and 58 deletions

View file

@ -127,7 +127,7 @@ class Cache(object):
self.cache.clear() self.cache.clear()
def cached(max_entries=1000, num_args=1, lru=False): class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in The function is presumed to take zero or more arguments, which are used in
@ -141,25 +141,32 @@ def cached(max_entries=1000, num_args=1, lru=False):
which can be used to insert values into the cache specifically, without which can be used to insert values into the cache specifically, without
calling the calculation function. calling the calculation function.
""" """
def wrap(orig): def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
self.orig = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
def __get__(self, obj, objtype=None):
cache = Cache( cache = Cache(
name=orig.__name__, name=self.orig.__name__,
max_entries=max_entries, max_entries=self.max_entries,
keylen=num_args, keylen=self.num_args,
lru=lru, lru=self.lru,
) )
@functools.wraps(orig) @functools.wraps(self.orig)
@defer.inlineCallbacks @defer.inlineCallbacks
def wrapped(self, *keyargs): def wrapped(*keyargs):
try: try:
cached_result = cache.get(*keyargs) cached_result = cache.get(*keyargs[:self.num_args])
if DEBUG_CACHES: if DEBUG_CACHES:
actual_result = yield orig(self, *keyargs) actual_result = yield self.orig(obj, *keyargs)
if actual_result != cached_result: if actual_result != cached_result:
logger.error( logger.error(
"Stale cache entry %s%r: cached: %r, actual %r", "Stale cache entry %s%r: cached: %r, actual %r",
orig.__name__, keyargs, self.orig.__name__, keyargs,
cached_result, actual_result, cached_result, actual_result,
) )
raise ValueError("Stale cache entry") raise ValueError("Stale cache entry")
@ -170,18 +177,28 @@ def cached(max_entries=1000, num_args=1, lru=False):
# while the SELECT is executing (SYN-369) # while the SELECT is executing (SYN-369)
sequence = cache.sequence sequence = cache.sequence
ret = yield orig(self, *keyargs) ret = yield self.orig(obj, *keyargs)
cache.update(sequence, *keyargs + (ret,)) cache.update(sequence, *keyargs[:self.num_args] + (ret,))
defer.returnValue(ret) defer.returnValue(ret)
wrapped.invalidate = cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill wrapped.prefill = cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped return wrapped
return wrap
def cached(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
class LoggingTransaction(object): class LoggingTransaction(object):

View file

@ -81,19 +81,23 @@ class StateStore(SQLBaseStore):
f, f,
) )
@defer.inlineCallbacks state_list = yield defer.gatherResults(
def c(vals):
vals[:] = yield self._get_events(vals, get_prev_content=False)
yield defer.gatherResults(
[ [
c(vals) self._fetch_events_for_group(group, vals)
for vals in states.values() for group, vals in states.items()
], ],
consumeErrors=True, consumeErrors=True,
) )
defer.returnValue(states) defer.returnValue(dict(state_list))
@cached(num_args=1)
def _fetch_events_for_group(self, state_group, events):
return self._get_events(
events, get_prev_content=False
).addCallback(
lambda evs: (state_group, evs)
)
def _store_state_groups_txn(self, txn, event, context): def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None: if context.current_state is None:

View file

@ -96,73 +96,84 @@ class CacheDecoratorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_passthrough(self): def test_passthrough(self):
class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
return key return key
self.assertEquals((yield func(self, "foo")), "foo") a = A()
self.assertEquals((yield func(self, "bar")), "bar")
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_hit(self): def test_hit(self):
callcount = [0] callcount = [0]
class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return key
yield func(self, "foo") a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
self.assertEquals((yield func(self, "foo")), "foo") self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invalidate(self): def test_invalidate(self):
callcount = [0] callcount = [0]
class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return key
yield func(self, "foo") a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
func.invalidate("foo") a.func.invalidate("foo")
yield func(self, "foo") yield a.func("foo")
self.assertEquals(callcount[0], 2) self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self): def test_invalidate_missing(self):
class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
return key return key
func.invalidate("what") A().func.invalidate("what")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_max_entries(self): def test_max_entries(self):
callcount = [0] callcount = [0]
class A(object):
@cached(max_entries=10) @cached(max_entries=10)
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return key
a = A()
for k in range(0, 12): for k in range(0, 12):
yield func(self, k) yield a.func(k)
self.assertEquals(callcount[0], 12) self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate # There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times # all 12 values again, we must get called at least 2 more times
for k in range(0,12): for k in range(0,12):
yield func(self, k) yield a.func(k)
self.assertTrue(callcount[0] >= 14, self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0])) msg="Expected callcount >= 14, got %d" % (callcount[0]))
@ -171,12 +182,15 @@ class CacheDecoratorTestCase(unittest.TestCase):
def test_prefill(self): def test_prefill(self):
callcount = [0] callcount = [0]
class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return key
func.prefill("foo", 123) a = A()
self.assertEquals((yield func(self, "foo")), 123) a.func.prefill("foo", 123)
self.assertEquals((yield a.func("foo")), 123)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)

View file

@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
(yield self.store.get_user_by_id(self.user_id)) (yield self.store.get_user_by_id(self.user_id))
) )
result = yield self.store.get_user_by_token(self.tokens[1]) result = yield self.store.get_user_by_token(self.tokens[0])
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {