Merge pull request #209 from matrix-org/erikj/cached_keyword_args

Add support for using keyword arguments with cached functions
This commit is contained in:
Erik Johnston 2015-08-06 13:52:49 +01:00
commit 8049c9a71e
6 changed files with 45 additions and 23 deletions

View file

@ -27,6 +27,7 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
import functools import functools
import inspect
import sys import sys
import time import time
import threading import threading
@ -141,13 +142,28 @@ class CacheDescriptor(object):
which can be used to insert values into the cache specifically, without which can be used to insert values into the cache specifically, without
calling the calculation function. calling the calculation function.
""" """
def __init__(self, orig, max_entries=1000, num_args=1, lru=True): def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
inlineCallbacks=False):
self.orig = orig self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries self.max_entries = max_entries
self.num_args = num_args self.num_args = num_args
self.lru = lru self.lru = lru
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
cache = Cache( cache = Cache(
name=self.orig.__name__, name=self.orig.__name__,
@ -158,11 +174,13 @@ class CacheDescriptor(object):
@functools.wraps(self.orig) @functools.wraps(self.orig)
@defer.inlineCallbacks @defer.inlineCallbacks
def wrapped(*keyargs): def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
try: try:
cached_result = cache.get(*keyargs[:self.num_args]) cached_result = cache.get(*keyargs)
if DEBUG_CACHES: if DEBUG_CACHES:
actual_result = yield self.orig(obj, *keyargs) actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result: if actual_result != cached_result:
logger.error( logger.error(
"Stale cache entry %s%r: cached: %r, actual %r", "Stale cache entry %s%r: cached: %r, actual %r",
@ -177,9 +195,9 @@ class CacheDescriptor(object):
# while the SELECT is executing (SYN-369) # while the SELECT is executing (SYN-369)
sequence = cache.sequence sequence = cache.sequence
ret = yield self.orig(obj, *keyargs) ret = yield self.function_to_call(obj, *args, **kwargs)
cache.update(sequence, *keyargs[:self.num_args] + (ret,)) cache.update(sequence, *(keyargs + [ret]))
defer.returnValue(ret) defer.returnValue(ret)
@ -201,6 +219,16 @@ def cached(max_entries=1000, num_args=1, lru=True):
) )
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)
class LoggingTransaction(object): class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore, cached from _base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate", desc="store_server_certificate",
) )
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_all_server_verify_keys(self, server_name): def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="server_signature_keys", table="server_signature_keys",

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_push_rules_for_user(self, user_name): def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table=PushRuleTable.table_name, table=PushRuleTable.table_name,
@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name): def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name, table=PushRuleEnableTable.table_name,

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore):
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self) return self._receipts_id_gen.get_max_token(self)
@cached @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_graph_receipts_for_room(self, room_id): def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers. """Get receipts for sending to remote servers.
""" """

View file

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cachedInlineCallbacks
import collections import collections
import logging import logging
@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
} }
) )
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id): def get_room_name_and_aliases(self, room_id):
def f(txn): def f(txn):
sql = ( sql = (

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cached, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -189,8 +189,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
@cached(num_args=3) @cachedInlineCallbacks(num_args=3)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key): def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn): def f(txn):
sql = ( sql = (