Make @cached cache deferreds rather than the deferreds' values

This commit is contained in:
Erik Johnston 2015-08-06 13:33:34 +01:00
parent 39e21ea51c
commit 7eea3e356f
3 changed files with 22 additions and 19 deletions

View file

@ -15,6 +15,7 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.lrucache import LruCache
@ -173,33 +174,27 @@ class CacheDescriptor(object):
) )
@functools.wraps(self.orig) @functools.wraps(self.orig)
@defer.inlineCallbacks
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
try: try:
cached_result = cache.get(*keyargs) cached_result = cache.get(*keyargs)
if DEBUG_CACHES: return cached_result.observe()
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
except KeyError: except KeyError:
# Get the sequence number of the cache before reading from the # Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated # database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369) # while the SELECT is executing (SYN-369)
sequence = cache.sequence sequence = cache.sequence
ret = yield self.function_to_call(obj, *args, **kwargs) ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
ret = ObservableDeferred(ret, consumeErrors=False)
cache.update(sequence, *(keyargs + [ret])) cache.update(sequence, *(keyargs + [ret]))
defer.returnValue(ret) return ret.observe()
wrapped.invalidate = cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all

View file

@ -51,7 +51,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_observers", set()) object.__setattr__(self, "_observers", set())
def callback(r): def callback(r):
self._result = (True, r) object.__setattr__(self, "_result", (True, r))
while self._observers: while self._observers:
try: try:
self._observers.pop().callback(r) self._observers.pop().callback(r)
@ -60,7 +60,7 @@ class ObservableDeferred(object):
return r return r
def errback(f): def errback(f):
self._result = (False, f) object.__setattr__(self, "_result", (False, f))
while self._observers: while self._observers:
try: try:
self._observers.pop().errback(f) self._observers.pop().errback(f)
@ -97,3 +97,8 @@ class ObservableDeferred(object):
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self._deferred, name, value) setattr(self._deferred, name, value)
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)

View file

@ -17,6 +17,8 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.util.async import ObservableDeferred
from synapse.storage._base import Cache, cached from synapse.storage._base import Cache, cached
@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertTrue(callcount[0] >= 14, self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0])) msg="Expected callcount >= 14, got %d" % (callcount[0]))
@defer.inlineCallbacks
def test_prefill(self): def test_prefill(self):
callcount = [0] callcount = [0]
d = defer.succeed(123)
class A(object): class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return d
a = A() a = A()
a.func.prefill("foo", 123) a.func.prefill("foo", ObservableDeferred(d))
self.assertEquals((yield a.func("foo")), 123) self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)