Merge pull request #3383 from matrix-org/rav/hacky_speedup_state_group_cache

Improve StateGroupCache efficiency for wildcard lookups
This commit is contained in:
Richard van der Hoff 2018-06-22 14:01:55 +01:00 committed by GitHub
commit bb018d0b5b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 32 deletions

View file

@ -526,10 +526,23 @@ class StateGroupWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None): def _get_state_for_groups(self, groups, types=None):
"""Given list of groups returns dict of group -> list of state events """Gets the state at each of a list of state groups, optionally
with matching types. `types` is a list of `(type, state_key)`, where filtering by type/state_key
a `state_key` of None matches all state_keys. If `types` is None then
all events are returned. Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
types (None|iterable[(str, None|str)]):
indicates the state type/keys required. If None, the whole
state is fetched and returned.
Otherwise, each entry should be a `(type, state_key)` tuple to
include in the response. A `state_key` of None is a wildcard
meaning that we require all state with that type.
Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]]
a dictionary mapping from state group to state dictionary.
""" """
if types: if types:
types = frozenset(types) types = frozenset(types)
@ -538,7 +551,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if types is not None: if types is not None:
for group in set(groups): for group in set(groups):
state_dict_ids, _, got_all = self._get_some_state_from_cache( state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types group, types,
) )
results[group] = state_dict_ids results[group] = state_dict_ids
@ -559,22 +572,40 @@ class StateGroupWorkerStore(SQLBaseStore):
# Okay, so we have some missing_types, lets fetch them. # Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence cache_seq_num = self._state_group_cache.sequence
# the DictionaryCache knows if it has *all* the state, but
# does not know if it has all of the keys of a particular type,
# which makes wildcard lookups expensive unless we have a complete
# cache. Hence, if we are doing a wildcard lookup, populate the
# cache fully so that we can do an efficient lookup next time.
if types and any(k is None for (t, k) in types):
types_to_fetch = None
else:
types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups( group_to_state_dict = yield self._get_state_groups_from_groups(
missing_groups, types missing_groups, types_to_fetch,
) )
# Now we want to update the cache with all the things we fetched
# from the database.
for group, group_state_dict in iteritems(group_to_state_dict): for group, group_state_dict in iteritems(group_to_state_dict):
state_dict = results[group] state_dict = results[group]
# update the result, filtering by `types`.
if types:
for k, v in iteritems(group_state_dict):
(typ, _) = k
if k in types or (typ, None) in types:
state_dict[k] = v
else:
state_dict.update(group_state_dict) state_dict.update(group_state_dict)
# update the cache with all the things we fetched from the
# database.
self._state_group_cache.update( self._state_group_cache.update(
cache_seq_num, cache_seq_num,
key=group, key=group,
value=state_dict, value=group_state_dict,
full=(types is None), fetched_keys=types_to_fetch,
known_absent=types,
) )
defer.returnValue(results) defer.returnValue(results)
@ -681,7 +712,6 @@ class StateGroupWorkerStore(SQLBaseStore):
self._state_group_cache.sequence, self._state_group_cache.sequence,
key=state_group, key=state_group,
value=dict(current_state_ids), value=dict(current_state_ids),
full=True,
) )
return state_group return state_group

View file

@ -107,29 +107,28 @@ class DictionaryCache(object):
self.sequence += 1 self.sequence += 1
self.cache.clear() self.cache.clear()
def update(self, sequence, key, value, full=False, known_absent=None): def update(self, sequence, key, value, fetched_keys=None):
"""Updates the entry in the cache """Updates the entry in the cache
Args: Args:
sequence sequence
key key (K)
value (dict): The value to update the cache with. value (dict[X,Y]): The value to update the cache with.
full (bool): Whether the given value is the full dict, or just a fetched_keys (None|set[X]): All of the dictionary keys which were
partial subset there of. If not full then any existing entries fetched from the database.
for the key will be updated.
known_absent (set): Set of keys that we know don't exist in the full If None, this is the complete value for key K. Otherwise, it
dict. is used to infer a list of keys which we know don't exist in
the full dict.
""" """
self.check_thread() self.check_thread()
if self.sequence == sequence: if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the # Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369) # number that the cache had before the SELECT was started (SYN-369)
if known_absent is None: if fetched_keys is None:
known_absent = set() self._insert(key, value, set())
if full:
self._insert(key, value, known_absent)
else: else:
self._update_or_insert(key, value, known_absent) self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(self, key, value, known_absent): def _update_or_insert(self, key, value, known_absent):
# We pop and reinsert as we need to tell the cache the size may have # We pop and reinsert as we need to tell the cache the size may have

View file

@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase):
seq = self.cache.sequence seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"} test_value = {"test": "test_simple_cache_hit_full"}
self.cache.update(seq, key, test_value, full=True) self.cache.update(seq, key, test_value)
c = self.cache.get(key) c = self.cache.get(key)
self.assertEqual(test_value, c.value) self.assertEqual(test_value, c.value)
@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = { test_value = {
"test": "test_simple_cache_hit_partial" "test": "test_simple_cache_hit_partial"
} }
self.cache.update(seq, key, test_value, full=True) self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test"]) c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value) self.assertEqual(test_value, c.value)
@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = { test_value = {
"test": "test_simple_cache_miss_partial" "test": "test_simple_cache_miss_partial"
} }
self.cache.update(seq, key, test_value, full=True) self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"]) c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value) self.assertEqual({}, c.value)
@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase):
"test2": "test_simple_cache_hit_miss_partial2", "test2": "test_simple_cache_hit_miss_partial2",
"test3": "test_simple_cache_hit_miss_partial3", "test3": "test_simple_cache_hit_miss_partial3",
} }
self.cache.update(seq, key, test_value, full=True) self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"]) c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase):
test_value_1 = { test_value_1 = {
"test": "test_simple_cache_hit_miss_partial", "test": "test_simple_cache_hit_miss_partial",
} }
self.cache.update(seq, key, test_value_1, full=False) self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
seq = self.cache.sequence seq = self.cache.sequence
test_value_2 = { test_value_2 = {
"test2": "test_simple_cache_hit_miss_partial2", "test2": "test_simple_cache_hit_miss_partial2",
} }
self.cache.update(seq, key, test_value_2, full=False) self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
c = self.cache.get(key) c = self.cache.get(key)
self.assertEqual( self.assertEqual(