diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 6cd3a843df..f0a367aa9b 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -108,7 +108,7 @@ class SlavedEventStore(BaseSlavedStore): get_current_state_ids = ( StateStore.__dict__["get_current_state_ids"] ) - get_state_group_delta = DataStore.get_state_group_delta.__func__ + get_state_group_delta = StateStore.__dict__["get_state_group_delta"] _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"] has_room_changed_since = DataStore.has_room_changed_since.__func__ diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 151223219d..d1e679719b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii from synapse.storage.engines import PostgresEngine from twisted.internet import defer +from collections import namedtuple import logging @@ -29,6 +30,16 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + class StateStore(SQLBaseStore): """ Keeps track of the state at a given event. @@ -98,6 +109,7 @@ class StateStore(SQLBaseStore): _get_current_state_ids_txn, ) + @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between the old and the new. @@ -117,7 +129,7 @@ class StateStore(SQLBaseStore): ) if not prev_group: - return None, None + return _GetStateGroupDelta(None, None) delta_ids = self._simple_select_list_txn( txn, @@ -128,10 +140,10 @@ class StateStore(SQLBaseStore): retcols=("type", "state_key", "event_id",) ) - return prev_group, { + return _GetStateGroupDelta(prev_group, { (row["type"], row["state_key"]): row["event_id"] for row in delta_ids - } + }) return self.runInteraction( "get_state_group_delta", _get_state_group_delta_txn,