Merge pull request #178 from matrix-org/erikj/cache_state_groups

Add cache to get_state_groups.
This commit is contained in:
Erik Johnston 2015-06-03 17:20:33 +01:00
commit 13ed3b9985
4 changed files with 93 additions and 58 deletions

View file

@ -127,7 +127,7 @@ class Cache(object):
self.cache.clear()
def cached(max_entries=1000, num_args=1, lru=False):
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in
@ -141,25 +141,32 @@ def cached(max_entries=1000, num_args=1, lru=False):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def wrap(orig):
def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
self.orig = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
def __get__(self, obj, objtype=None):
cache = Cache(
name=orig.__name__,
max_entries=max_entries,
keylen=num_args,
lru=lru,
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
@functools.wraps(orig)
@functools.wraps(self.orig)
@defer.inlineCallbacks
def wrapped(self, *keyargs):
def wrapped(*keyargs):
try:
cached_result = cache.get(*keyargs)
cached_result = cache.get(*keyargs[:self.num_args])
if DEBUG_CACHES:
actual_result = yield orig(self, *keyargs)
actual_result = yield self.orig(obj, *keyargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
orig.__name__, keyargs,
self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
@ -170,18 +177,28 @@ def cached(max_entries=1000, num_args=1, lru=False):
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = yield orig(self, *keyargs)
ret = yield self.orig(obj, *keyargs)
cache.update(sequence, *keyargs + (ret,))
cache.update(sequence, *keyargs[:self.num_args] + (ret,))
defer.returnValue(ret)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
return wrap
def cached(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
class LoggingTransaction(object):

View file

@ -81,19 +81,23 @@ class StateStore(SQLBaseStore):
f,
)
@defer.inlineCallbacks
def c(vals):
vals[:] = yield self._get_events(vals, get_prev_content=False)
yield defer.gatherResults(
state_list = yield defer.gatherResults(
[
c(vals)
for vals in states.values()
self._fetch_events_for_group(group, vals)
for group, vals in states.items()
],
consumeErrors=True,
)
defer.returnValue(states)
defer.returnValue(dict(state_list))
@cached(num_args=1)
def _fetch_events_for_group(self, state_group, events):
return self._get_events(
events, get_prev_content=False
).addCallback(
lambda evs: (state_group, evs)
)
def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None:

View file

@ -96,73 +96,84 @@ class CacheDecoratorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_passthrough(self):
@cached()
def func(self, key):
return key
class A(object):
@cached()
def func(self, key):
return key
self.assertEquals((yield func(self, "foo")), "foo")
self.assertEquals((yield func(self, "bar")), "bar")
a = A()
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks
def test_hit(self):
callcount = [0]
@cached()
def func(self, key):
callcount[0] += 1
return key
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
yield func(self, "foo")
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals((yield func(self, "foo")), "foo")
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks
def test_invalidate(self):
callcount = [0]
@cached()
def func(self, key):
callcount[0] += 1
return key
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
yield func(self, "foo")
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
func.invalidate("foo")
a.func.invalidate("foo")
yield func(self, "foo")
yield a.func("foo")
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
@cached()
def func(self, key):
return key
class A(object):
@cached()
def func(self, key):
return key
func.invalidate("what")
A().func.invalidate("what")
@defer.inlineCallbacks
def test_max_entries(self):
callcount = [0]
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
return key
class A(object):
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
return key
for k in range(0,12):
yield func(self, k)
a = A()
for k in range(0, 12):
yield a.func(k)
self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
for k in range(0,12):
yield func(self, k)
yield a.func(k)
self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0]))
@ -171,12 +182,15 @@ class CacheDecoratorTestCase(unittest.TestCase):
def test_prefill(self):
callcount = [0]
@cached()
def func(self, key):
callcount[0] += 1
return key
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
func.prefill("foo", 123)
a = A()
self.assertEquals((yield func(self, "foo")), 123)
a.func.prefill("foo", 123)
self.assertEquals((yield a.func("foo")), 123)
self.assertEquals(callcount[0], 0)

View file

@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
(yield self.store.get_user_by_id(self.user_id))
)
result = yield self.store.get_user_by_token(self.tokens[1])
result = yield self.store.get_user_by_token(self.tokens[0])
self.assertDictContainsSubset(
{