diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 2a601e37e3..f74b81b7a2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -27,6 +27,7 @@ from twisted.internet import defer from collections import namedtuple, OrderedDict import functools +import inspect import sys import time import threading @@ -141,13 +142,28 @@ class CacheDescriptor(object): which can be used to insert values into the cache specifically, without calling the calculation function. """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=True): + def __init__(self, orig, max_entries=1000, num_args=1, lru=True, + inlineCallbacks=False): self.orig = orig + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + self.max_entries = max_entries self.num_args = num_args self.lru = lru + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + def __get__(self, obj, objtype=None): cache = Cache( name=self.orig.__name__, @@ -158,11 +174,13 @@ class CacheDescriptor(object): @functools.wraps(self.orig) @defer.inlineCallbacks - def wrapped(*keyargs): + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] try: - cached_result = cache.get(*keyargs[:self.num_args]) + cached_result = cache.get(*keyargs) if DEBUG_CACHES: - actual_result = yield self.orig(obj, *keyargs) + actual_result = yield self.function_to_call(obj, *args, **kwargs) if actual_result != cached_result: logger.error( "Stale cache entry %s%r: cached: %r, actual %r", @@ -177,9 +195,9 @@ class CacheDescriptor(object): # while the SELECT is executing (SYN-369) sequence = cache.sequence - ret = yield self.orig(obj, *keyargs) + ret = yield self.function_to_call(obj, *args, **kwargs) - cache.update(sequence, *keyargs[:self.num_args] + (ret,)) + cache.update(sequence, *(keyargs + [ret])) defer.returnValue(ret) @@ -201,6 +219,16 @@ def cached(max_entries=1000, num_args=1, lru=True): ) +def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru, + inlineCallbacks=True, + ) + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 940a5f7e08..e3f98f0cde 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore, cached +from _base import SQLBaseStore, cachedInlineCallbacks from twisted.internet import defer @@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore): desc="store_server_certificate", ) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_all_server_verify_keys(self, server_name): rows = yield self._simple_select_list( table="server_signature_keys", diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 4cac118d17..a220f3632e 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cachedInlineCallbacks from twisted.internet import defer import logging @@ -23,8 +23,7 @@ logger = logging.getLogger(__name__) class PushRuleStore(SQLBaseStore): - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_push_rules_for_user(self, user_name): rows = yield self._simple_select_list( table=PushRuleTable.table_name, @@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(rows) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_push_rules_enabled_for_user(self, user_name): results = yield self._simple_select_list( table=PushRuleEnableTable.table_name, diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 7a6af98d98..b79d6683ca 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cachedInlineCallbacks from twisted.internet import defer @@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore): def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_max_token(self) - @cached - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_graph_receipts_for_room(self, room_id): """Get receipts for sending to remote servers. """ diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 4612a8aa83..dd5bc2c8fb 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cachedInlineCallbacks import collections import logging @@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore): } ) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_room_name_and_aliases(self, room_id): def f(txn): sql = ( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 47bec65497..55c6d52890 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cached, cachedInlineCallbacks from twisted.internet import defer @@ -189,8 +189,7 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cached(num_args=3) - @defer.inlineCallbacks + @cachedInlineCallbacks(num_args=3) def get_current_state_for_key(self, room_id, event_type, state_key): def f(txn): sql = (