Faster room joins: avoid blocking when pulling events with missing prevs (#13355)

Avoid blocking on full state in `_resolve_state_at_missing_prevs` and
return a new flag indicating whether the resolved state is partial.
Thread that flag around so that it makes it into the event context.

Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
This commit is contained in:
Sean Quah 2022-07-26 12:39:23 +01:00 committed by GitHub
parent 8b603299bf
commit 335ebb21cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 124 additions and 33 deletions

1
changelog.d/13355.misc Normal file
View file

@ -0,0 +1 @@
Faster room joins: avoid blocking when pulling events with partially missing prev events.

View file

@ -278,7 +278,9 @@ class FederationEventHandler:
) )
try: try:
await self._process_received_pdu(origin, pdu, state_ids=None) await self._process_received_pdu(
origin, pdu, state_ids=None, partial_state=None
)
except PartialStateConflictError: except PartialStateConflictError:
# The room was un-partial stated while we were processing the PDU. # The room was un-partial stated while we were processing the PDU.
# Try once more, with full state this time. # Try once more, with full state this time.
@ -286,7 +288,9 @@ class FederationEventHandler:
"Room %s was un-partial stated while processing the PDU, trying again.", "Room %s was un-partial stated while processing the PDU, trying again.",
room_id, room_id,
) )
await self._process_received_pdu(origin, pdu, state_ids=None) await self._process_received_pdu(
origin, pdu, state_ids=None, partial_state=None
)
async def on_send_membership_event( async def on_send_membership_event(
self, origin: str, event: EventBase self, origin: str, event: EventBase
@ -534,14 +538,36 @@ class FederationEventHandler:
# #
# This is the same operation as we do when we receive a regular event # This is the same operation as we do when we receive a regular event
# over federation. # over federation.
state_ids = await self._resolve_state_at_missing_prevs(destination, event) state_ids, partial_state = await self._resolve_state_at_missing_prevs(
destination, event
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
) )
if context.partial_state:
# There are three possible cases for (state_ids, partial_state):
# * `state_ids` and `partial_state` are both `None` if we had all the
# prev_events. The prev_events may or may not have partial state and
# we won't know until we compute the event context.
# * `state_ids` is not `None` and `partial_state` is `False` if we were
# missing some prev_events (but we have full state for any we did
# have). We calculated the full state after the prev_events.
# * `state_ids` is not `None` and `partial_state` is `True` if we were
# missing some, but not all, prev_events. At least one of the
# prev_events we did have had partial state, so we calculated a partial
# state after the prev_events.
context = None
if state_ids is not None and partial_state:
# the state after the prev events is still partial. We can't de-partial
# state the event, so don't bother building the event context.
pass
else:
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)
if context is None or context.partial_state:
# this can happen if some or all of the event's prev_events still have # this can happen if some or all of the event's prev_events still have
# partial state - ie, an event has an earlier stream_ordering than one # partial state - ie, an event has an earlier stream_ordering than one
# or more of its prev_events, so we de-partial-state it before its # or more of its prev_events, so we de-partial-state it before its
@ -806,14 +832,39 @@ class FederationEventHandler:
return return
try: try:
state_ids = await self._resolve_state_at_missing_prevs(origin, event) try:
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does state_ids, partial_state = await self._resolve_state_at_missing_prevs(
# not return partial state origin, event
# https://github.com/matrix-org/synapse/issues/13002 )
await self._process_received_pdu(
origin,
event,
state_ids=state_ids,
partial_state=partial_state,
backfilled=backfilled,
)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the event.
# Try once more, with full state this time.
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
origin, event
)
await self._process_received_pdu( # We ought to have full state now, barring some unlikely race where we left and
origin, event, state_ids=state_ids, backfilled=backfilled # rejoned the room in the background.
) if state_ids is not None and partial_state:
raise AssertionError(
f"Event {event.event_id} still has a partial resolved state "
f"after room {event.room_id} was un-partial stated"
)
await self._process_received_pdu(
origin,
event,
state_ids=state_ids,
partial_state=partial_state,
backfilled=backfilled,
)
except FederationError as e: except FederationError as e:
if e.code == 403: if e.code == 403:
logger.warning("Pulled event %s failed history check.", event_id) logger.warning("Pulled event %s failed history check.", event_id)
@ -822,7 +873,7 @@ class FederationEventHandler:
async def _resolve_state_at_missing_prevs( async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase self, dest: str, event: EventBase
) -> Optional[StateMap[str]]: ) -> Tuple[Optional[StateMap[str]], Optional[bool]]:
"""Calculate the state at an event with missing prev_events. """Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and This is used when we have pulled a batch of events from a remote server, and
@ -849,8 +900,10 @@ class FederationEventHandler:
event: an event to check for missing prevs. event: an event to check for missing prevs.
Returns: Returns:
if we already had all the prev events, `None`. Otherwise, returns if we already had all the prev events, `None, None`. Otherwise, returns a
the event ids of the state at `event`. tuple containing:
* the event ids of the state at `event`.
* a boolean indicating whether the state may be partial.
Raises: Raises:
FederationError if we fail to get the state from the remote server after any FederationError if we fail to get the state from the remote server after any
@ -864,7 +917,7 @@ class FederationEventHandler:
missing_prevs = prevs - seen missing_prevs = prevs - seen
if not missing_prevs: if not missing_prevs:
return None return None, None
logger.info( logger.info(
"Event %s is missing prev_events %s: calculating state for a " "Event %s is missing prev_events %s: calculating state for a "
@ -876,9 +929,15 @@ class FederationEventHandler:
# resolve them to find the correct state at the current event. # resolve them to find the correct state at the current event.
try: try:
# Determine whether we may be about to retrieve partial state
# Events may be un-partial stated right after we compute the partial state
# flag, but that's okay, as long as the flag errs on the conservative side.
partial_state_flags = await self._store.get_partial_state_events(seen)
partial_state = any(partial_state_flags.values())
# Get the state of the events we know about # Get the state of the events we know about
ours = await self._state_storage_controller.get_state_groups_ids( ours = await self._state_storage_controller.get_state_groups_ids(
room_id, seen room_id, seen, await_full_state=False
) )
# state_maps is a list of mappings from (type, state_key) to event_id # state_maps is a list of mappings from (type, state_key) to event_id
@ -924,7 +983,7 @@ class FederationEventHandler:
"We can't get valid state history.", "We can't get valid state history.",
affected=event_id, affected=event_id,
) )
return state_map return state_map, partial_state
async def _get_state_ids_after_missing_prev_event( async def _get_state_ids_after_missing_prev_event(
self, self,
@ -1094,6 +1153,7 @@ class FederationEventHandler:
origin: str, origin: str,
event: EventBase, event: EventBase,
state_ids: Optional[StateMap[str]], state_ids: Optional[StateMap[str]],
partial_state: Optional[bool],
backfilled: bool = False, backfilled: bool = False,
) -> None: ) -> None:
"""Called when we have a new non-outlier event. """Called when we have a new non-outlier event.
@ -1117,14 +1177,21 @@ class FederationEventHandler:
state_ids: Normally None, but if we are handling a gap in the graph state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the (ie, we are missing one or more prev_events), the resolved state at the
event. Must not be partial state. event
partial_state:
`True` if `state_ids` is partial and omits non-critical membership
events.
`False` if `state_ids` is the full state.
`None` if `state_ids` is not provided. In this case, the flag will be
calculated based on `event`'s prev events.
backfilled: True if this is part of a historical batch of events (inhibits backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.) notification to clients, and validation of device keys.)
PartialStateConflictError: if the room was un-partial stated in between PartialStateConflictError: if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should retry computing the state at the event and persisting it. The caller should retry
exactly once in this case. Will never be raised if `state_ids` is provided. exactly once in this case.
""" """
logger.debug("Processing event: %s", event) logger.debug("Processing event: %s", event)
assert not event.internal_metadata.outlier assert not event.internal_metadata.outlier
@ -1132,6 +1199,7 @@ class FederationEventHandler:
context = await self._state_handler.compute_event_context( context = await self._state_handler.compute_event_context(
event, event,
state_ids_before_event=state_ids, state_ids_before_event=state_ids,
partial_state=partial_state,
) )
try: try:
await self._check_event_auth(origin, event, context) await self._check_event_auth(origin, event, context)

View file

@ -1135,6 +1135,10 @@ class EventCreationHandler:
context = await self.state.compute_event_context( context = await self.state.compute_event_context(
event, event,
state_ids_before_event=state_map_for_event, state_ids_before_event=state_map_for_event,
# TODO(faster_joins): check how MSC2716 works and whether we can have
# partial state here
# https://github.com/matrix-org/synapse/issues/13003
partial_state=False,
) )
else: else:
context = await self.state.compute_event_context(event) context = await self.state.compute_event_context(event)

View file

@ -255,7 +255,7 @@ class StateHandler:
self, self,
event: EventBase, event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None, state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False, partial_state: Optional[bool] = None,
) -> EventContext: ) -> EventContext:
"""Build an EventContext structure for a non-outlier event. """Build an EventContext structure for a non-outlier event.
@ -270,8 +270,12 @@ class StateHandler:
it can't be calculated from existing events. This is normally it can't be calculated from existing events. This is normally
only specified when receiving an event from federation where we only specified when receiving an event from federation where we
don't have the prev events, e.g. when backfilling. don't have the prev events, e.g. when backfilling.
partial_state: True if `state_ids_before_event` is partial and omits partial_state:
non-critical membership events `True` if `state_ids_before_event` is partial and omits non-critical
membership events.
`False` if `state_ids_before_event` is the full state.
`None` when `state_ids_before_event` is not provided. In this case, the
flag will be calculated based on `event`'s prev events.
Returns: Returns:
The event context. The event context.
""" """
@ -298,12 +302,14 @@ class StateHandler:
) )
) )
# the partial_state flag must be provided
assert partial_state is not None
else: else:
# otherwise, we'll need to resolve the state across the prev_events. # otherwise, we'll need to resolve the state across the prev_events.
# partial_state should not be set explicitly in this case: # partial_state should not be set explicitly in this case:
# we work it out dynamically # we work it out dynamically
assert not partial_state assert partial_state is None
# if any of the prev-events have partial state, so do we. # if any of the prev-events have partial state, so do we.
# (This is slightly racy - the prev-events might get fixed up before we use # (This is slightly racy - the prev-events might get fixed up before we use
@ -313,13 +319,13 @@ class StateHandler:
incomplete_prev_events = await self.store.get_partial_state_events( incomplete_prev_events = await self.store.get_partial_state_events(
prev_event_ids prev_event_ids
) )
if any(incomplete_prev_events.values()): partial_state = any(incomplete_prev_events.values())
if partial_state:
logger.debug( logger.debug(
"New/incoming event %s refers to prev_events %s with partial state", "New/incoming event %s refers to prev_events %s with partial state",
event.event_id, event.event_id,
[k for (k, v) in incomplete_prev_events.items() if v], [k for (k, v) in incomplete_prev_events.items() if v],
) )
partial_state = True
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for # we've already taken into account partial state, so no need to wait for

View file

@ -82,13 +82,15 @@ class StateStorageController:
return state_group_delta.prev_group, state_group_delta.delta_ids return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids( async def get_state_groups_ids(
self, _room_id: str, event_ids: Collection[str] self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
) -> Dict[int, MutableStateMap[str]]: ) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events """Get the event IDs of all the state for the state groups for the given events
Args: Args:
_room_id: id of the room for these events _room_id: id of the room for these events
event_ids: ids of the events event_ids: ids of the events
await_full_state: if `True`, will block if we do not yet have complete
state at these events.
Returns: Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id) dict of state_group_id -> (dict of (type, state_key) -> event id)
@ -100,7 +102,9 @@ class StateStorageController:
if not event_ids: if not event_ids:
return {} return {}
event_to_groups = await self.get_state_group_for_events(event_ids) event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups) group_to_state = await self.stores.state._get_state_for_groups(groups)

View file

@ -287,6 +287,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
state_ids={ state_ids={
(e.type, e.state_key): e.event_id for e in current_state (e.type, e.state_key): e.event_id for e in current_state
}, },
partial_state=False,
) )
) )

View file

@ -70,7 +70,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None): def persist_event(self, event, state=None):
"""Persist the event, with optional state""" """Persist the event, with optional state"""
context = self.get_success( context = self.get_success(
self.state.compute_event_context(event, state_ids_before_event=state) self.state.compute_event_context(
event,
state_ids_before_event=state,
partial_state=None if state is None else False,
)
) )
self.get_success(self._persistence.persist_event(event, context)) self.get_success(self._persistence.persist_event(event, context))
@ -148,6 +152,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
self.state.compute_event_context( self.state.compute_event_context(
remote_event_2, remote_event_2,
state_ids_before_event=state_before_gap, state_ids_before_event=state_before_gap,
partial_state=False,
) )
) )

View file

@ -462,6 +462,7 @@ class StateTestCase(unittest.TestCase):
state_ids_before_event={ state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state (e.type, e.state_key): e.event_id for e in old_state
}, },
partial_state=False,
) )
) )
@ -492,6 +493,7 @@ class StateTestCase(unittest.TestCase):
state_ids_before_event={ state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state (e.type, e.state_key): e.event_id for e in old_state
}, },
partial_state=False,
) )
) )