diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 84f29e3867..1d0f0058a2 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -18,7 +18,7 @@ from ._base import BaseHandler from synapse.streams.config import PaginationConfig from synapse.api.constants import Membership, EventTypes from synapse.util import unwrapFirstError -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.metrics import Measure from twisted.internet import defer diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 6c32e8f7b3..90ec50bb50 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -264,26 +264,20 @@ class StateStore(SQLBaseStore): ) @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=1) + num_args=1, inlineCallbacks=True) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - def f(txn): - results = {} - for event_id in event_ids: - results[event_id] = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={ - "event_id": event_id, - }, - retcol="state_group", - allow_none=True, - ) + rows = yield self._simple_select_many_batch( + table="event_to_state_groups", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=("event_id", "state_group",), + desc="_get_state_group_for_events", + ) - return results - - return self.runInteraction("_get_state_group_for_events", f) + defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) def _get_some_state_from_cache(self, group, types): """Checks if group is in cache. See `_get_state_for_groups`