Merge branch 'erikj/less_state_on_missing' into erikj/push_hack

This commit is contained in:
Erik Johnston 2022-05-21 14:13:10 +01:00
commit da10dfc311
15 changed files with 198 additions and 110 deletions

View file

@ -0,0 +1 @@
Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands.

View file

@ -1 +0,0 @@
Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic.

1
changelog.d/12703.misc Normal file
View file

@ -0,0 +1 @@
Convert namespace class `Codes` into a string enum.

View file

@ -0,0 +1 @@
Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands.

1
changelog.d/12828.misc Normal file
View file

@ -0,0 +1 @@
Pull out less state when handling gaps in room DAG.

View file

@ -17,6 +17,7 @@
import logging
import typing
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union
@ -30,7 +31,11 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class Codes:
class Codes(str, Enum):
"""
All known error codes, as an enum of strings.
"""
UNRECOGNIZED = "M_UNRECOGNIZED"
UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN"
@ -265,7 +270,9 @@ class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(
self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED
self,
msg: str = "Unrecognized request",
errcode: str = Codes.UNRECOGNIZED,
):
super().__init__(400, msg, errcode)

View file

@ -463,7 +463,9 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in state
},
partial_state=partial_state,
)
@ -501,7 +503,7 @@ class FederationEventHandler:
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event=state,
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
@ -765,7 +767,7 @@ class FederationEventHandler:
async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
) -> Optional[Iterable[EventBase]]:
) -> Optional[StateMap[str]]:
"""Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and
@ -792,8 +794,8 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
if we already had all the prev events, `None`. Otherwise, returns a list of
the events in the state at `event`.
if we already had all the prev events, `None`. Otherwise, returns
the state at `event`.
"""
room_id = event.room_id
event_id = event.event_id
@ -837,13 +839,7 @@ class FederationEventHandler:
dest, room_id, p
)
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
state_maps.append(remote_state_map)
for x in remote_state:
event_map[x.event_id] = x
state_maps.append(remote_state)
room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
@ -854,19 +850,6 @@ class FederationEventHandler:
state_res_store=StateResolutionStore(self._store),
)
# We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now.
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self._store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.as_is,
)
event_map.update(evs)
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing prev_events",
@ -878,14 +861,14 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
return state
return state_map
async def _get_state_after_missing_prev_event(
self,
destination: str,
room_id: str,
event_id: str,
) -> List[EventBase]:
) -> StateMap[str]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@ -894,7 +877,7 @@ class FederationEventHandler:
event_id: The id of the event we want the state at.
Returns:
A list of events in the state, including the event itself
The state *after* the given event.
"""
(
state_event_ids,
@ -913,15 +896,13 @@ class FederationEventHandler:
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
fetched_events = await self._store.get_events(
desired_events, allow_rejected=True
)
have_events = await self._store.have_seen_events(room_id, desired_events)
missing_desired_events = desired_events - fetched_events.keys()
missing_desired_events = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
len(missing_desired_events),
len(fetched_events),
len(have_events),
)
# We probably won't need most of the auth events, so let's just check which
@ -932,7 +913,7 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
missing_auth_events = set(auth_event_ids) - have_events
missing_auth_events.difference_update(
await self._store.have_seen_events(room_id, missing_auth_events)
)
@ -958,47 +939,54 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
await self._store.get_events(missing_desired_events, allow_rejected=True)
)
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
bad_events = [
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
]
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
bad_event_id,
bad_room_id,
room_id,
)
state_map = {}
del fetched_events[bad_event_id]
for state_event_id, metadata in event_metadata.items():
if metadata.room_id != room_id:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
state_event_id,
metadata.room_id,
room_id,
)
continue
if metadata.state_key is None:
logger.warning(
"Remote server gave us non-state event in state: %s", state_event_id
)
continue
state_map[(metadata.event_type, metadata.state_key)] = state_event_id
# if we couldn't get the prev event in question, that's a problem.
remote_event = fetched_events.get(event_id)
remote_event = await self._store.get_event(
event_id,
allow_none=True,
allow_rejected=True,
redact_behaviour=EventRedactBehaviour.as_is,
)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
# missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
failed_to_fetch = desired_events - fetched_events.keys()
failed_to_fetch = desired_events - event_metadata.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
@ -1006,14 +994,12 @@ class FederationEventHandler:
failed_to_fetch,
)
remote_state = [
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
]
if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event)
state_map[
(remote_event.type, remote_event.state_key)
] = remote_event.event_id
return remote_state
return state_map
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
@ -1040,7 +1026,7 @@ class FederationEventHandler:
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
state: Optional[StateMap[str]],
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@ -1074,7 +1060,7 @@ class FederationEventHandler:
try:
context = await self._state_handler.compute_event_context(
event, old_state=state
event, state_ids_before_event=state
)
context = await self._check_event_auth(
origin,
@ -1565,7 +1551,7 @@ class FederationEventHandler:
async def _check_for_soft_fail(
self,
event: EventBase,
state: Optional[Iterable[EventBase]],
state: Optional[StateMap[str]],
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@ -1605,17 +1591,21 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
state_sets_d = await self._state_store.get_state_groups(
state_sets_d = await self._state_store.get_state_groups_ids(
event.room_id, extrem_ids
)
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_sets.append(state)
current_states = await self._state_handler.resolve_events(
room_version, state_sets, event
current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
state_sets,
event_map={},
state_res_store=StateResolutionStore(self._store),
)
)
current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
}
else:
current_state_ids = await self._store.get_filtered_current_state_ids(
event.room_id, StateFilter.from_types(auth_types)

View file

@ -1023,8 +1023,21 @@ class EventCreationHandler:
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
metadata = await self.store.get_metadata_for_events(state_event_ids)
state_map = {}
for event_id, data in metadata.items():
if data.state_key is None:
raise Exception(
"Trying to set non-state event as state: %s", event_id
)
state_map[(data.event_type, data.state_key)] = event_id
context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map,
)
else:
context = await self.state.compute_event_context(event)

View file

@ -58,6 +58,15 @@ class Command(metaclass=abc.ABCMeta):
# by default, we just use the command name.
return self.NAME
def redis_channel_name(self, prefix: str) -> str:
"""
Returns the Redis channel name upon which to publish this command.
Args:
prefix: The prefix for the channel.
"""
return prefix
SC = TypeVar("SC", bound="_SimpleCommand")
@ -395,6 +404,9 @@ class UserIpCommand(Command):
f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})"
)
def redis_channel_name(self, prefix: str) -> str:
return f"{prefix}/USER_IP"
class RemoteServerUpCommand(_SimpleCommand):
"""Sent when a worker has detected that a remote server is no longer

View file

@ -221,10 +221,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
# remote instances.
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
channel_name = cmd.redis_channel_name(self.synapse_stream_prefix)
await make_deferred_yieldable(
self.synapse_outbound_redis_connection.publish(
self.synapse_stream_prefix, encoded_string
)
self.synapse_outbound_redis_connection.publish(channel_name, encoded_string)
)

View file

@ -194,7 +194,7 @@ class StateHandler:
async def compute_event_context(
self,
event: EventBase,
old_state: Optional[Iterable[EventBase]] = None,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
@ -206,12 +206,12 @@ class StateHandler:
Args:
event:
old_state: The state at the event if it can't be
state_ids_before_event: The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
partial_state: True if `old_state` is partial and omits non-critical
membership events
partial_state: True if `state_ids_before_event` is partial and omits
non-critical membership events
Returns:
The event context.
"""
@ -221,11 +221,7 @@ class StateHandler:
#
# first of all, figure out the state before the event
#
if old_state:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
(s.type, s.state_key): s.event_id for s in old_state
}
if state_ids_before_event:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None

View file

@ -16,6 +16,8 @@ import collections.abc
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
import attr
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@ -26,6 +28,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
@ -43,6 +46,15 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventMetadata:
"""Returned by `get_metadata_for_events`"""
room_id: str
event_type: str
state_key: Optional[str]
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not v:
@ -133,6 +145,36 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
async def get_metadata_for_events(
self, event_ids: Collection[str]
) -> Dict[str, EventMetadata]:
"""Get some metadata (room_id, type, state_key) for the given events."""
clause, args = make_in_list_sql_clause(
self.database_engine, "e.event_id", event_ids
)
sql = f"""
SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
LEFT JOIN state_events USING (event_id)
WHERE {clause}
"""
def get_metadata_for_events_txn(
txn: LoggingTransaction,
) -> Dict[str, EventMetadata]:
txn.execute(sql, args)
return {
event_id: EventMetadata(
room_id=room_id, event_type=event_type, state_key=state_key
)
for event_id, room_id, event_type, state_key in txn
}
return await self.db_pool.runInteraction(
"get_metadata_for_events", get_metadata_for_events_txn
)
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.

View file

@ -276,7 +276,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# federation handler wanting to backfill the fake event.
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, event, state=current_state
self.OTHER_SERVER_NAME,
event,
state={(e.type, e.state_key): e.event_id for e in current_state},
)
)

View file

@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(event, old_state=state)
self.state.compute_event_context(event, state_ids_before_event=state)
)
self.get_success(self.persistence.persist_event(event, context))
@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
self.get_success(self.store.get_current_state(self.room_id))
self.get_success(self.store.get_current_state_ids(self.room_id))
)
state_before_gap.pop(("m.room.history_visibility", ""))
context = self.get_success(
self.state.compute_event_context(
remote_event_2, old_state=state_before_gap.values()
remote_event_2,
state_ids_before_event=state_before_gap,
)
)
@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id])

View file

@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())