Refactor _resolve_state_at_missing_prevs to return an EventContext (#13404)

Previously, `_resolve_state_at_missing_prevs` returned the resolved
state before an event and a partial state flag. These were unwieldy to
carry around would only ever be used to build an event context. Build
the event context directly instead.

Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
Sean Quah 2022-08-01 13:53:56 +01:00 committed by GitHub
parent 05aeeb3a80
commit 224d792dd7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 68 additions and 86 deletions

1
changelog.d/13404.misc Normal file
View file

@ -0,0 +1 @@
Refactor `_resolve_state_at_missing_prevs` to compute an `EventContext` instead.

View file

@ -23,7 +23,6 @@ from typing import (
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
@ -278,9 +277,8 @@ class FederationEventHandler:
)
try:
await self._process_received_pdu(
origin, pdu, state_ids=None, partial_state=None
)
context = await self._state_handler.compute_event_context(pdu)
await self._process_received_pdu(origin, pdu, context)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the PDU.
# Try once more, with full state this time.
@ -288,9 +286,8 @@ class FederationEventHandler:
"Room %s was un-partial stated while processing the PDU, trying again.",
room_id,
)
await self._process_received_pdu(
origin, pdu, state_ids=None, partial_state=None
)
context = await self._state_handler.compute_event_context(pdu)
await self._process_received_pdu(origin, pdu, context)
async def on_send_membership_event(
self, origin: str, event: EventBase
@ -320,6 +317,7 @@ class FederationEventHandler:
The event and context of the event after inserting it into the room graph.
Raises:
RuntimeError if any prev_events are missing
SynapseError if the event is not accepted into the room
PartialStateConflictError if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should
@ -380,7 +378,7 @@ class FederationEventHandler:
# need to.
await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
await self._check_for_soft_fail(event, None, origin=origin)
await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context)
return event, context
@ -538,36 +536,10 @@ class FederationEventHandler:
#
# This is the same operation as we do when we receive a regular event
# over federation.
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
context = await self._compute_event_context_with_maybe_missing_prevs(
destination, event
)
# There are three possible cases for (state_ids, partial_state):
# * `state_ids` and `partial_state` are both `None` if we had all the
# prev_events. The prev_events may or may not have partial state and
# we won't know until we compute the event context.
# * `state_ids` is not `None` and `partial_state` is `False` if we were
# missing some prev_events (but we have full state for any we did
# have). We calculated the full state after the prev_events.
# * `state_ids` is not `None` and `partial_state` is `True` if we were
# missing some, but not all, prev_events. At least one of the
# prev_events we did have had partial state, so we calculated a partial
# state after the prev_events.
context = None
if state_ids is not None and partial_state:
# the state after the prev events is still partial. We can't de-partial
# state the event, so don't bother building the event context.
pass
else:
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)
if context is None or context.partial_state:
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
# partial state. We were careful to only pick events from the db without
# partial-state prev events, so that implies that a prev event has
@ -840,26 +812,25 @@ class FederationEventHandler:
try:
try:
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
context = await self._compute_event_context_with_maybe_missing_prevs(
origin, event
)
await self._process_received_pdu(
origin,
event,
state_ids=state_ids,
partial_state=partial_state,
context,
backfilled=backfilled,
)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the event.
# Try once more, with full state this time.
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
context = await self._compute_event_context_with_maybe_missing_prevs(
origin, event
)
# We ought to have full state now, barring some unlikely race where we left and
# rejoned the room in the background.
if state_ids is not None and partial_state:
if context.partial_state:
raise AssertionError(
f"Event {event.event_id} still has a partial resolved state "
f"after room {event.room_id} was un-partial stated"
@ -868,8 +839,7 @@ class FederationEventHandler:
await self._process_received_pdu(
origin,
event,
state_ids=state_ids,
partial_state=partial_state,
context,
backfilled=backfilled,
)
except FederationError as e:
@ -878,15 +848,18 @@ class FederationEventHandler:
else:
raise
async def _resolve_state_at_missing_prevs(
async def _compute_event_context_with_maybe_missing_prevs(
self, dest: str, event: EventBase
) -> Tuple[Optional[StateMap[str]], Optional[bool]]:
"""Calculate the state at an event with missing prev_events.
) -> EventContext:
"""Build an EventContext structure for a non-outlier event whose prev_events may
be missing.
This is used when we have pulled a batch of events from a remote server, and
still don't have all the prev_events.
This is used when we have pulled a batch of events from a remote server, and may
not have all the prev_events.
If we already have all the prev_events for `event`, this method does nothing.
To build an EventContext, we need to calculate the state before the event. If we
already have all the prev_events for `event`, we can simply use the state after
the prev_events to calculate the state before `event`.
Otherwise, the missing prevs become new backwards extremities, and we fall back
to asking the remote server for the state after each missing `prev_event`,
@ -907,10 +880,7 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
if we already had all the prev events, `None, None`. Otherwise, returns a
tuple containing:
* the event ids of the state at `event`.
* a boolean indicating whether the state may be partial.
The event context.
Raises:
FederationError if we fail to get the state from the remote server after any
@ -924,7 +894,7 @@ class FederationEventHandler:
missing_prevs = prevs - seen
if not missing_prevs:
return None, None
return await self._state_handler.compute_event_context(event)
logger.info(
"Event %s is missing prev_events %s: calculating state for a "
@ -990,7 +960,9 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
return state_map, partial_state
return await self._state_handler.compute_event_context(
event, state_ids_before_event=state_map, partial_state=partial_state
)
async def _get_state_ids_after_missing_prev_event(
self,
@ -1159,8 +1131,7 @@ class FederationEventHandler:
self,
origin: str,
event: EventBase,
state_ids: Optional[StateMap[str]],
partial_state: Optional[bool],
context: EventContext,
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@ -1182,32 +1153,18 @@ class FederationEventHandler:
event: event to be persisted
state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
partial_state:
`True` if `state_ids` is partial and omits non-critical membership
events.
`False` if `state_ids` is the full state.
`None` if `state_ids` is not provided. In this case, the flag will be
calculated based on `event`'s prev events.
context: The `EventContext` to persist the event with.
backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)
PartialStateConflictError: if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should retry
exactly once in this case.
computing the state at the event and persisting it. The caller should
recompute `context` and retry exactly once when this happens.
"""
logger.debug("Processing event: %s", event)
assert not event.internal_metadata.outlier
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)
try:
await self._check_event_auth(origin, event, context)
except AuthError as e:
@ -1219,7 +1176,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
await self._check_for_soft_fail(event, state_ids, origin=origin)
await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled)
@ -1782,7 +1739,7 @@ class FederationEventHandler:
async def _check_for_soft_fail(
self,
event: EventBase,
state_ids: Optional[StateMap[str]],
context: EventContext,
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@ -1793,7 +1750,7 @@ class FederationEventHandler:
Args:
event
state_ids: The state at the event if we don't have all the event's prev events
context: The `EventContext` which we are about to persist the event with.
origin: The host the event originates from.
"""
if await self._store.is_partial_state_room(event.room_id):
@ -1819,11 +1776,15 @@ class FederationEventHandler:
auth_types = auth_types_for_event(room_version_obj, event)
# Calculate the "current state".
if state_ids is not None:
# If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for
# a while and have an incorrect view of the current state,
seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
has_missing_prevs = bool(prev_event_ids - seen_event_ids)
if has_missing_prevs:
# We don't have all the prev_events of this event, which means we have a
# gap in the graph, and the new event is going to become a new backwards
# extremity.
#
# In this case we want to be a little careful as we might have been
# down for a while and have an incorrect view of the current state,
# however we still want to do checks as gaps are easy to
# maliciously manufacture.
#
@ -1836,6 +1797,7 @@ class FederationEventHandler:
event.room_id, extrem_ids
)
state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_ids = await context.get_prev_state_ids()
state_sets.append(state_ids)
current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(

View file

@ -278,6 +278,10 @@ class StateHandler:
flag will be calculated based on `event`'s prev events.
Returns:
The event context.
Raises:
RuntimeError if `state_ids_before_event` is not provided and one or more
prev events are missing or outliers.
"""
assert not event.internal_metadata.is_outlier()
@ -432,6 +436,10 @@ class StateHandler:
Returns:
The resolved state
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie. they are outliers or unknown)
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)

View file

@ -338,6 +338,10 @@ class StateStorageController:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie. they are outliers or unknown)
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)

View file

@ -280,14 +280,21 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# we poke this directly into _process_received_pdu, to avoid the
# federation handler wanting to backfill the fake event.
state_handler = self.hs.get_state_handler()
context = self.get_success(
state_handler.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in current_state
},
partial_state=False,
)
)
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME,
event,
state_ids={
(e.type, e.state_key): e.event_id for e in current_state
},
partial_state=False,
context,
)
)