diff --git a/changelog.d/13477.misc b/changelog.d/13477.misc new file mode 100644 index 0000000000..5d21ae9d7a --- /dev/null +++ b/changelog.d/13477.misc @@ -0,0 +1 @@ +Faster room joins: Avoid blocking lazy-loading `/sync`s during partial joins due to remote memberships. Pull remote memberships from auth events instead of the room state. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3ca01391c9..b4d3f3958c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -16,9 +16,11 @@ import logging from typing import ( TYPE_CHECKING, Any, + Collection, Dict, FrozenSet, List, + Mapping, Optional, Sequence, Set, @@ -517,10 +519,17 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): + # FIXME(faster_joins): We use the partial state here as + # we don't want to block `/sync` on finishing a lazy join. + # Which should be fine once + # https://github.com/matrix-org/synapse/issues/12989 is resolved, + # since we shouldn't reach here anymore? + # Note that we use the current state as a whitelist for filtering + # `recents`, so partial state is only a problem when a membership + # event turns up in `recents` but has not made it into the current + # state. current_state_ids_map = ( - await self._state_storage_controller.get_current_state_ids( - room_id - ) + await self.store.get_partial_current_state_ids(room_id) ) current_state_ids = frozenset(current_state_ids_map.values()) @@ -589,7 +598,13 @@ class SyncHandler: if any(e.is_state() for e in loaded_recents): # FIXME(faster_joins): We use the partial state here as # we don't want to block `/sync` on finishing a lazy join. - # Is this the correct way of doing it? + # Which should be fine once + # https://github.com/matrix-org/synapse/issues/12989 is resolved, + # since we shouldn't reach here anymore? + # Note that we use the current state as a whitelist for filtering + # `loaded_recents`, so partial state is only a problem when a + # membership event turns up in `loaded_recents` but has not made it + # into the current state. current_state_ids_map = ( await self.store.get_partial_current_state_ids(room_id) ) @@ -637,7 +652,10 @@ class SyncHandler: ) async def get_state_after_event( - self, event_id: str, state_filter: Optional[StateFilter] = None + self, + event_id: str, + state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """ Get the room state after the given event @@ -645,9 +663,14 @@ class SyncHandler: Args: event_id: event of interest state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the event and `state_filter` is not satisfied by partial state. + Defaults to `True`. """ state_ids = await self._state_storage_controller.get_state_ids_for_event( - event_id, state_filter=state_filter or StateFilter.all() + event_id, + state_filter=state_filter or StateFilter.all(), + await_full_state=await_full_state, ) # using get_metadata_for_events here (instead of get_event) sidesteps an issue @@ -670,6 +693,7 @@ class SyncHandler: room_id: str, stream_position: StreamToken, state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Get the room state at a particular stream position @@ -677,6 +701,9 @@ class SyncHandler: room_id: room for which to get state stream_position: point at which to get state state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the last event in the room before `stream_position` and + `state_filter` is not satisfied by partial state. Defaults to `True`. """ # FIXME: This gets the state at the latest event before the stream ordering, # which might not be the same as the "current state" of the room at the time @@ -688,7 +715,9 @@ class SyncHandler: if last_event_id: state = await self.get_state_after_event( - last_event_id, state_filter=state_filter or StateFilter.all() + last_event_id, + state_filter=state_filter or StateFilter.all(), + await_full_state=await_full_state, ) else: @@ -891,7 +920,15 @@ class SyncHandler: with Measure(self.clock, "compute_state_delta"): # The memberships needed for events in the timeline. # Only calculated when `lazy_load_members` is on. - members_to_fetch = None + members_to_fetch: Optional[Set[str]] = None + + # A dictionary mapping user IDs to the first event in the timeline sent by + # them. Only calculated when `lazy_load_members` is on. + first_event_by_sender_map: Optional[Dict[str, EventBase]] = None + + # The contribution to the room state from state events in the timeline. + # Only contains the last event for any given state key. + timeline_state: StateMap[str] lazy_load_members = sync_config.filter_collection.lazy_load_members() include_redundant_members = ( @@ -902,10 +939,23 @@ class SyncHandler: # We only request state for the members needed to display the # timeline: - members_to_fetch = { - event.sender # FIXME: we also care about invite targets etc. - for event in batch.events - } + timeline_state = {} + + members_to_fetch = set() + first_event_by_sender_map = {} + for event in batch.events: + # Build the map from user IDs to the first timeline event they sent. + if event.sender not in first_event_by_sender_map: + first_event_by_sender_map[event.sender] = event + + # We need the event's sender, unless their membership was in a + # previous timeline event. + if (EventTypes.Member, event.sender) not in timeline_state: + members_to_fetch.add(event.sender) + # FIXME: we also care about invite targets etc. + + if event.is_state(): + timeline_state[(event.type, event.state_key)] = event.event_id if full_state: # always make sure we LL ourselves so we know we're in the room @@ -915,16 +965,21 @@ class SyncHandler: members_to_fetch.add(sync_config.user.to_string()) state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) - else: - state_filter = StateFilter.all() - # The contribution to the room state from state events in the timeline. - # Only contains the last event for any given state key. - timeline_state = { - (event.type, event.state_key): event.event_id - for event in batch.events - if event.is_state() - } + # We are happy to use partial state to compute the `/sync` response. + # Since partial state may not include the lazy-loaded memberships we + # require, we fix up the state response afterwards with memberships from + # auth events. + await_full_state = False + else: + timeline_state = { + (event.type, event.state_key): event.event_id + for event in batch.events + if event.is_state() + } + + state_filter = StateFilter.all() + await_full_state = True # Now calculate the state to return in the sync response for the room. # This is more or less the change in state between the end of the previous @@ -936,19 +991,26 @@ class SyncHandler: if batch: state_at_timeline_end = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + batch.events[0].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: state_at_timeline_end = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) state_at_timeline_start = state_at_timeline_end @@ -964,14 +1026,19 @@ class SyncHandler: if batch: state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + batch.events[0].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: # We can get here if the user has ignored the senders of all # the recent events. state_at_timeline_start = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) # for now, we disable LL for gappy syncs - see @@ -993,20 +1060,28 @@ class SyncHandler: # is indeed the case. assert since_token is not None state_at_previous_sync = await self.get_state_at( - room_id, stream_position=since_token, state_filter=state_filter + room_id, + stream_position=since_token, + state_filter=state_filter, + await_full_state=await_full_state, ) if batch: state_at_timeline_end = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: # We can get here if the user has ignored the senders of all # the recent events. state_at_timeline_end = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) state_ids = _calculate_state( @@ -1036,8 +1111,23 @@ class SyncHandler: (EventTypes.Member, member) for member in members_to_fetch ), + await_full_state=False, ) + # If we only have partial state for the room, `state_ids` may be missing the + # memberships we wanted. We attempt to find some by digging through the auth + # events of timeline events. + if lazy_load_members and await self.store.is_partial_state_room(room_id): + assert members_to_fetch is not None + assert first_event_by_sender_map is not None + + additional_state_ids = ( + await self._find_missing_partial_state_memberships( + room_id, members_to_fetch, first_event_by_sender_map, state_ids + ) + ) + state_ids = {**state_ids, **additional_state_ids} + # At this point, if `lazy_load_members` is enabled, `state_ids` includes # the memberships of all event senders in the timeline. This is because we # may not have sent the memberships in a previous sync. @@ -1086,6 +1176,99 @@ class SyncHandler: if e.type != EventTypes.Aliases # until MSC2261 or alternative solution } + async def _find_missing_partial_state_memberships( + self, + room_id: str, + members_to_fetch: Collection[str], + events_with_membership_auth: Mapping[str, EventBase], + found_state_ids: StateMap[str], + ) -> StateMap[str]: + """Finds missing memberships from a set of auth events and returns them as a + state map. + + Args: + room_id: The partial state room to find the remaining memberships for. + members_to_fetch: The memberships to find. + events_with_membership_auth: A mapping from user IDs to events whose auth + events are known to contain their membership. + found_state_ids: A dict from (type, state_key) -> state_event_id, containing + memberships that have been previously found. Entries in + `members_to_fetch` that have a membership in `found_state_ids` are + ignored. + + Returns: + A dict from ("m.room.member", state_key) -> state_event_id, containing the + memberships missing from `found_state_ids`. + + Raises: + KeyError: if `events_with_membership_auth` does not have an entry for a + missing membership. Memberships in `found_state_ids` do not need an + entry in `events_with_membership_auth`. + """ + additional_state_ids: MutableStateMap[str] = {} + + # Tracks the missing members for logging purposes. + missing_members = set() + + # Identify memberships missing from `found_state_ids` and pick out the auth + # events in which to look for them. + auth_event_ids: Set[str] = set() + for member in members_to_fetch: + if (EventTypes.Member, member) in found_state_ids: + continue + + missing_members.add(member) + event_with_membership_auth = events_with_membership_auth[member] + auth_event_ids.update(event_with_membership_auth.auth_event_ids()) + + auth_events = await self.store.get_events(auth_event_ids) + + # Run through the missing memberships once more, picking out the memberships + # from the pile of auth events we have just fetched. + for member in members_to_fetch: + if (EventTypes.Member, member) in found_state_ids: + continue + + event_with_membership_auth = events_with_membership_auth[member] + + # Dig through the auth events to find the desired membership. + for auth_event_id in event_with_membership_auth.auth_event_ids(): + # We only store events once we have all their auth events, + # so the auth event must be in the pile we have just + # fetched. + auth_event = auth_events[auth_event_id] + + if ( + auth_event.type == EventTypes.Member + and auth_event.state_key == member + ): + missing_members.remove(member) + additional_state_ids[ + (EventTypes.Member, member) + ] = auth_event.event_id + break + + if missing_members: + # There really shouldn't be any missing memberships now. Either: + # * we couldn't find an auth event, which shouldn't happen because we do + # not persist events with persisting their auth events first, or + # * the set of auth events did not contain a membership we wanted, which + # means our caller didn't compute the events in `members_to_fetch` + # correctly, or we somehow accepted an event whose auth events were + # dodgy. + logger.error( + "Failed to find memberships for %s in partial state room " + "%s in the auth events of %s.", + missing_members, + room_id, + [ + events_with_membership_auth[member].event_id + for member in missing_members + ], + ) + + return additional_state_ids + async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig ) -> NotifCounts: @@ -1730,7 +1913,11 @@ class SyncHandler: continue if room_id in sync_result_builder.joined_room_ids or has_join: - old_state_ids = await self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at( + room_id, + since_token, + state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]), + ) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None if old_mem_ev_id: @@ -1756,7 +1943,13 @@ class SyncHandler: newly_left_rooms.append(room_id) else: if not old_state_ids: - old_state_ids = await self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at( + room_id, + since_token, + state_filter=StateFilter.from_types( + [(EventTypes.Member, user_id)] + ), + ) old_mem_ev_id = old_state_ids.get( (EventTypes.Member, user_id), None ) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 1ad002f57b..f9ffd0e29e 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -234,6 +234,7 @@ class StateStorageController: self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids @@ -242,6 +243,9 @@ class StateStorageController: Args: event_ids: events whose state should be returned state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at these events and `state_filter` is not satisfied by partial state. + Defaults to `True`. Returns: A dict from event_id -> (type, state_key) -> event_id @@ -250,8 +254,12 @@ class StateStorageController: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + if ( + await_full_state + and state_filter + and not state_filter.must_await_full_state(self._is_mine_id) + ): + # Full state is not required if the state filter is restrictive enough. await_full_state = False event_to_groups = await self.get_state_group_for_events( @@ -294,7 +302,10 @@ class StateStorageController: @trace async def get_state_ids_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None + self, + event_id: str, + state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """ Get the state dict corresponding to a particular event @@ -302,6 +313,9 @@ class StateStorageController: Args: event_id: event whose state should be returned state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the event and `state_filter` is not satisfied by partial state. + Defaults to `True`. Returns: A dict from (type, state_key) -> state_event_id @@ -311,7 +325,9 @@ class StateStorageController: outlier or is unknown) """ state_map = await self.get_state_ids_for_events( - [event_id], state_filter or StateFilter.all() + [event_id], + state_filter or StateFilter.all(), + await_full_state=await_full_state, ) return state_map[event_id]