mirror of
https://github.com/element-hq/synapse
synced 2024-09-17 18:55:10 +00:00
Merge branch 'erikj/less_state_on_missing' into erikj/push_hack
This commit is contained in:
commit
da10dfc311
15 changed files with 198 additions and 110 deletions
1
changelog.d/12672.feature
Normal file
1
changelog.d/12672.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands.
|
|
@ -1 +0,0 @@
|
|||
Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic.
|
1
changelog.d/12703.misc
Normal file
1
changelog.d/12703.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert namespace class `Codes` into a string enum.
|
1
changelog.d/12809.feature
Normal file
1
changelog.d/12809.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands.
|
1
changelog.d/12828.misc
Normal file
1
changelog.d/12828.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Pull out less state when handling gaps in room DAG.
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
import logging
|
||||
import typing
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
@ -30,7 +31,11 @@ if typing.TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Codes:
|
||||
class Codes(str, Enum):
|
||||
"""
|
||||
All known error codes, as an enum of strings.
|
||||
"""
|
||||
|
||||
UNRECOGNIZED = "M_UNRECOGNIZED"
|
||||
UNAUTHORIZED = "M_UNAUTHORIZED"
|
||||
FORBIDDEN = "M_FORBIDDEN"
|
||||
|
@ -265,7 +270,9 @@ class UnrecognizedRequestError(SynapseError):
|
|||
"""An error indicating we don't understand the request you're trying to make"""
|
||||
|
||||
def __init__(
|
||||
self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED
|
||||
self,
|
||||
msg: str = "Unrecognized request",
|
||||
errcode: str = Codes.UNRECOGNIZED,
|
||||
):
|
||||
super().__init__(400, msg, errcode)
|
||||
|
||||
|
|
|
@ -463,7 +463,9 @@ class FederationEventHandler:
|
|||
with nested_logging_context(suffix=event.event_id):
|
||||
context = await self._state_handler.compute_event_context(
|
||||
event,
|
||||
old_state=state,
|
||||
state_ids_before_event={
|
||||
(e.type, e.state_key): e.event_id for e in state
|
||||
},
|
||||
partial_state=partial_state,
|
||||
)
|
||||
|
||||
|
@ -501,7 +503,7 @@ class FederationEventHandler:
|
|||
# build a new state group for it if need be
|
||||
context = await self._state_handler.compute_event_context(
|
||||
event,
|
||||
old_state=state,
|
||||
state_ids_before_event=state,
|
||||
)
|
||||
if context.partial_state:
|
||||
# this can happen if some or all of the event's prev_events still have
|
||||
|
@ -765,7 +767,7 @@ class FederationEventHandler:
|
|||
|
||||
async def _resolve_state_at_missing_prevs(
|
||||
self, dest: str, event: EventBase
|
||||
) -> Optional[Iterable[EventBase]]:
|
||||
) -> Optional[StateMap[str]]:
|
||||
"""Calculate the state at an event with missing prev_events.
|
||||
|
||||
This is used when we have pulled a batch of events from a remote server, and
|
||||
|
@ -792,8 +794,8 @@ class FederationEventHandler:
|
|||
event: an event to check for missing prevs.
|
||||
|
||||
Returns:
|
||||
if we already had all the prev events, `None`. Otherwise, returns a list of
|
||||
the events in the state at `event`.
|
||||
if we already had all the prev events, `None`. Otherwise, returns
|
||||
the state at `event`.
|
||||
"""
|
||||
room_id = event.room_id
|
||||
event_id = event.event_id
|
||||
|
@ -837,13 +839,7 @@ class FederationEventHandler:
|
|||
dest, room_id, p
|
||||
)
|
||||
|
||||
remote_state_map = {
|
||||
(x.type, x.state_key): x.event_id for x in remote_state
|
||||
}
|
||||
state_maps.append(remote_state_map)
|
||||
|
||||
for x in remote_state:
|
||||
event_map[x.event_id] = x
|
||||
state_maps.append(remote_state)
|
||||
|
||||
room_version = await self._store.get_room_version_id(room_id)
|
||||
state_map = await self._state_resolution_handler.resolve_events_with_store(
|
||||
|
@ -854,19 +850,6 @@ class FederationEventHandler:
|
|||
state_res_store=StateResolutionStore(self._store),
|
||||
)
|
||||
|
||||
# We need to give _process_received_pdu the actual state events
|
||||
# rather than event ids, so generate that now.
|
||||
|
||||
# First though we need to fetch all the events that are in
|
||||
# state_map, so we can build up the state below.
|
||||
evs = await self._store.get_events(
|
||||
list(state_map.values()),
|
||||
get_prev_content=False,
|
||||
redact_behaviour=EventRedactBehaviour.as_is,
|
||||
)
|
||||
event_map.update(evs)
|
||||
|
||||
state = [event_map[e] for e in state_map.values()]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error attempting to resolve state at missing prev_events",
|
||||
|
@ -878,14 +861,14 @@ class FederationEventHandler:
|
|||
"We can't get valid state history.",
|
||||
affected=event_id,
|
||||
)
|
||||
return state
|
||||
return state_map
|
||||
|
||||
async def _get_state_after_missing_prev_event(
|
||||
self,
|
||||
destination: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> List[EventBase]:
|
||||
) -> StateMap[str]:
|
||||
"""Requests all of the room state at a given event from a remote homeserver.
|
||||
|
||||
Args:
|
||||
|
@ -894,7 +877,7 @@ class FederationEventHandler:
|
|||
event_id: The id of the event we want the state at.
|
||||
|
||||
Returns:
|
||||
A list of events in the state, including the event itself
|
||||
The state *after* the given event.
|
||||
"""
|
||||
(
|
||||
state_event_ids,
|
||||
|
@ -913,15 +896,13 @@ class FederationEventHandler:
|
|||
desired_events = set(state_event_ids)
|
||||
desired_events.add(event_id)
|
||||
logger.debug("Fetching %i events from cache/store", len(desired_events))
|
||||
fetched_events = await self._store.get_events(
|
||||
desired_events, allow_rejected=True
|
||||
)
|
||||
have_events = await self._store.have_seen_events(room_id, desired_events)
|
||||
|
||||
missing_desired_events = desired_events - fetched_events.keys()
|
||||
missing_desired_events = desired_events - have_events
|
||||
logger.debug(
|
||||
"We are missing %i events (got %i)",
|
||||
len(missing_desired_events),
|
||||
len(fetched_events),
|
||||
len(have_events),
|
||||
)
|
||||
|
||||
# We probably won't need most of the auth events, so let's just check which
|
||||
|
@ -932,7 +913,7 @@ class FederationEventHandler:
|
|||
# already have a bunch of the state events. It would be nice if the
|
||||
# federation api gave us a way of finding out which we actually need.
|
||||
|
||||
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
|
||||
missing_auth_events = set(auth_event_ids) - have_events
|
||||
missing_auth_events.difference_update(
|
||||
await self._store.have_seen_events(room_id, missing_auth_events)
|
||||
)
|
||||
|
@ -958,47 +939,54 @@ class FederationEventHandler:
|
|||
destination=destination, room_id=room_id, event_ids=missing_events
|
||||
)
|
||||
|
||||
# we need to make sure we re-load from the database to get the rejected
|
||||
# state correct.
|
||||
fetched_events.update(
|
||||
await self._store.get_events(missing_desired_events, allow_rejected=True)
|
||||
)
|
||||
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
|
||||
|
||||
# check for events which were in the wrong room.
|
||||
#
|
||||
# this can happen if a remote server claims that the state or
|
||||
# auth_events at an event in room A are actually events in room B
|
||||
|
||||
bad_events = [
|
||||
(event_id, event.room_id)
|
||||
for event_id, event in fetched_events.items()
|
||||
if event.room_id != room_id
|
||||
]
|
||||
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
|
||||
|
||||
for bad_event_id, bad_room_id in bad_events:
|
||||
# This is a bogus situation, but since we may only discover it a long time
|
||||
# after it happened, we try our best to carry on, by just omitting the
|
||||
# bad events from the returned state set.
|
||||
logger.warning(
|
||||
"Remote server %s claims event %s in room %s is an auth/state "
|
||||
"event in room %s",
|
||||
destination,
|
||||
bad_event_id,
|
||||
bad_room_id,
|
||||
room_id,
|
||||
)
|
||||
state_map = {}
|
||||
|
||||
del fetched_events[bad_event_id]
|
||||
for state_event_id, metadata in event_metadata.items():
|
||||
if metadata.room_id != room_id:
|
||||
# This is a bogus situation, but since we may only discover it a long time
|
||||
# after it happened, we try our best to carry on, by just omitting the
|
||||
# bad events from the returned state set.
|
||||
logger.warning(
|
||||
"Remote server %s claims event %s in room %s is an auth/state "
|
||||
"event in room %s",
|
||||
destination,
|
||||
state_event_id,
|
||||
metadata.room_id,
|
||||
room_id,
|
||||
)
|
||||
continue
|
||||
|
||||
if metadata.state_key is None:
|
||||
logger.warning(
|
||||
"Remote server gave us non-state event in state: %s", state_event_id
|
||||
)
|
||||
continue
|
||||
|
||||
state_map[(metadata.event_type, metadata.state_key)] = state_event_id
|
||||
|
||||
# if we couldn't get the prev event in question, that's a problem.
|
||||
remote_event = fetched_events.get(event_id)
|
||||
remote_event = await self._store.get_event(
|
||||
event_id,
|
||||
allow_none=True,
|
||||
allow_rejected=True,
|
||||
redact_behaviour=EventRedactBehaviour.as_is,
|
||||
)
|
||||
if not remote_event:
|
||||
raise Exception("Unable to get missing prev_event %s" % (event_id,))
|
||||
|
||||
# missing state at that event is a warning, not a blocker
|
||||
# XXX: this doesn't sound right? it means that we'll end up with incomplete
|
||||
# state.
|
||||
failed_to_fetch = desired_events - fetched_events.keys()
|
||||
failed_to_fetch = desired_events - event_metadata.keys()
|
||||
if failed_to_fetch:
|
||||
logger.warning(
|
||||
"Failed to fetch missing state events for %s %s",
|
||||
|
@ -1006,14 +994,12 @@ class FederationEventHandler:
|
|||
failed_to_fetch,
|
||||
)
|
||||
|
||||
remote_state = [
|
||||
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
|
||||
]
|
||||
|
||||
if remote_event.is_state() and remote_event.rejected_reason is None:
|
||||
remote_state.append(remote_event)
|
||||
state_map[
|
||||
(remote_event.type, remote_event.state_key)
|
||||
] = remote_event.event_id
|
||||
|
||||
return remote_state
|
||||
return state_map
|
||||
|
||||
async def _get_state_and_persist(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
|
@ -1040,7 +1026,7 @@ class FederationEventHandler:
|
|||
self,
|
||||
origin: str,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]],
|
||||
state: Optional[StateMap[str]],
|
||||
backfilled: bool = False,
|
||||
) -> None:
|
||||
"""Called when we have a new non-outlier event.
|
||||
|
@ -1074,7 +1060,7 @@ class FederationEventHandler:
|
|||
|
||||
try:
|
||||
context = await self._state_handler.compute_event_context(
|
||||
event, old_state=state
|
||||
event, state_ids_before_event=state
|
||||
)
|
||||
context = await self._check_event_auth(
|
||||
origin,
|
||||
|
@ -1565,7 +1551,7 @@ class FederationEventHandler:
|
|||
async def _check_for_soft_fail(
|
||||
self,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]],
|
||||
state: Optional[StateMap[str]],
|
||||
origin: str,
|
||||
) -> None:
|
||||
"""Checks if we should soft fail the event; if so, marks the event as
|
||||
|
@ -1605,17 +1591,21 @@ class FederationEventHandler:
|
|||
# given state at the event. This should correctly handle cases
|
||||
# like bans, especially with state res v2.
|
||||
|
||||
state_sets_d = await self._state_store.get_state_groups(
|
||||
state_sets_d = await self._state_store.get_state_groups_ids(
|
||||
event.room_id, extrem_ids
|
||||
)
|
||||
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
|
||||
state_sets: List[StateMap[str]] = list(state_sets_d.values())
|
||||
state_sets.append(state)
|
||||
current_states = await self._state_handler.resolve_events(
|
||||
room_version, state_sets, event
|
||||
|
||||
current_state_ids = (
|
||||
await self._state_resolution_handler.resolve_events_with_store(
|
||||
event.room_id,
|
||||
room_version,
|
||||
state_sets,
|
||||
event_map={},
|
||||
state_res_store=StateResolutionStore(self._store),
|
||||
)
|
||||
)
|
||||
current_state_ids: StateMap[str] = {
|
||||
k: e.event_id for k, e in current_states.items()
|
||||
}
|
||||
else:
|
||||
current_state_ids = await self._store.get_filtered_current_state_ids(
|
||||
event.room_id, StateFilter.from_types(auth_types)
|
||||
|
|
|
@ -1023,8 +1023,21 @@ class EventCreationHandler:
|
|||
#
|
||||
# TODO(faster_joins): figure out how this works, and make sure that the
|
||||
# old state is complete.
|
||||
old_state = await self.store.get_events_as_list(state_event_ids)
|
||||
context = await self.state.compute_event_context(event, old_state=old_state)
|
||||
metadata = await self.store.get_metadata_for_events(state_event_ids)
|
||||
|
||||
state_map = {}
|
||||
for event_id, data in metadata.items():
|
||||
if data.state_key is None:
|
||||
raise Exception(
|
||||
"Trying to set non-state event as state: %s", event_id
|
||||
)
|
||||
|
||||
state_map[(data.event_type, data.state_key)] = event_id
|
||||
|
||||
context = await self.state.compute_event_context(
|
||||
event,
|
||||
state_ids_before_event=state_map,
|
||||
)
|
||||
else:
|
||||
context = await self.state.compute_event_context(event)
|
||||
|
||||
|
|
|
@ -58,6 +58,15 @@ class Command(metaclass=abc.ABCMeta):
|
|||
# by default, we just use the command name.
|
||||
return self.NAME
|
||||
|
||||
def redis_channel_name(self, prefix: str) -> str:
|
||||
"""
|
||||
Returns the Redis channel name upon which to publish this command.
|
||||
|
||||
Args:
|
||||
prefix: The prefix for the channel.
|
||||
"""
|
||||
return prefix
|
||||
|
||||
|
||||
SC = TypeVar("SC", bound="_SimpleCommand")
|
||||
|
||||
|
@ -395,6 +404,9 @@ class UserIpCommand(Command):
|
|||
f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})"
|
||||
)
|
||||
|
||||
def redis_channel_name(self, prefix: str) -> str:
|
||||
return f"{prefix}/USER_IP"
|
||||
|
||||
|
||||
class RemoteServerUpCommand(_SimpleCommand):
|
||||
"""Sent when a worker has detected that a remote server is no longer
|
||||
|
|
|
@ -221,10 +221,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
|||
# remote instances.
|
||||
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
|
||||
|
||||
channel_name = cmd.redis_channel_name(self.synapse_stream_prefix)
|
||||
|
||||
await make_deferred_yieldable(
|
||||
self.synapse_outbound_redis_connection.publish(
|
||||
self.synapse_stream_prefix, encoded_string
|
||||
)
|
||||
self.synapse_outbound_redis_connection.publish(channel_name, encoded_string)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -194,7 +194,7 @@ class StateHandler:
|
|||
async def compute_event_context(
|
||||
self,
|
||||
event: EventBase,
|
||||
old_state: Optional[Iterable[EventBase]] = None,
|
||||
state_ids_before_event: Optional[StateMap[str]] = None,
|
||||
partial_state: bool = False,
|
||||
) -> EventContext:
|
||||
"""Build an EventContext structure for a non-outlier event.
|
||||
|
@ -206,12 +206,12 @@ class StateHandler:
|
|||
|
||||
Args:
|
||||
event:
|
||||
old_state: The state at the event if it can't be
|
||||
state_ids_before_event: The state at the event if it can't be
|
||||
calculated from existing events. This is normally only specified
|
||||
when receiving an event from federation where we don't have the
|
||||
prev events for, e.g. when backfilling.
|
||||
partial_state: True if `old_state` is partial and omits non-critical
|
||||
membership events
|
||||
partial_state: True if `state_ids_before_event` is partial and omits
|
||||
non-critical membership events
|
||||
Returns:
|
||||
The event context.
|
||||
"""
|
||||
|
@ -221,11 +221,7 @@ class StateHandler:
|
|||
#
|
||||
# first of all, figure out the state before the event
|
||||
#
|
||||
if old_state:
|
||||
# if we're given the state before the event, then we use that
|
||||
state_ids_before_event: StateMap[str] = {
|
||||
(s.type, s.state_key): s.event_id for s in old_state
|
||||
}
|
||||
if state_ids_before_event:
|
||||
state_group_before_event = None
|
||||
state_group_before_event_prev_group = None
|
||||
deltas_to_state_group_before_event = None
|
||||
|
|
|
@ -16,6 +16,8 @@ import collections.abc
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
|
@ -26,6 +28,7 @@ from synapse.storage.database import (
|
|||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
|
@ -43,6 +46,15 @@ logger = logging.getLogger(__name__)
|
|||
MAX_STATE_DELTA_HOPS = 100
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class EventMetadata:
|
||||
"""Returned by `get_metadata_for_events`"""
|
||||
|
||||
room_id: str
|
||||
event_type: str
|
||||
state_key: Optional[str]
|
||||
|
||||
|
||||
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
|
||||
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
|
||||
if not v:
|
||||
|
@ -133,6 +145,36 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return room_version
|
||||
|
||||
async def get_metadata_for_events(
|
||||
self, event_ids: Collection[str]
|
||||
) -> Dict[str, EventMetadata]:
|
||||
"""Get some metadata (room_id, type, state_key) for the given events."""
|
||||
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "e.event_id", event_ids
|
||||
)
|
||||
|
||||
sql = f"""
|
||||
SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
|
||||
LEFT JOIN state_events USING (event_id)
|
||||
WHERE {clause}
|
||||
"""
|
||||
|
||||
def get_metadata_for_events_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[str, EventMetadata]:
|
||||
txn.execute(sql, args)
|
||||
return {
|
||||
event_id: EventMetadata(
|
||||
room_id=room_id, event_type=event_type, state_key=state_key
|
||||
)
|
||||
for event_id, room_id, event_type, state_key in txn
|
||||
}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_metadata_for_events", get_metadata_for_events_txn
|
||||
)
|
||||
|
||||
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
|
||||
"""Get the predecessor of an upgraded room if it exists.
|
||||
Otherwise return None.
|
||||
|
|
|
@ -276,7 +276,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
|||
# federation handler wanting to backfill the fake event.
|
||||
self.get_success(
|
||||
federation_event_handler._process_received_pdu(
|
||||
self.OTHER_SERVER_NAME, event, state=current_state
|
||||
self.OTHER_SERVER_NAME,
|
||||
event,
|
||||
state={(e.type, e.state_key): e.event_id for e in current_state},
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
def persist_event(self, event, state=None):
|
||||
"""Persist the event, with optional state"""
|
||||
context = self.get_success(
|
||||
self.state.compute_event_context(event, old_state=state)
|
||||
self.state.compute_event_context(event, state_ids_before_event=state)
|
||||
)
|
||||
self.get_success(self.persistence.persist_event(event, context))
|
||||
|
||||
|
@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id])
|
||||
|
@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
# setting. The state resolution across the old and new event will then
|
||||
# include it, and so the resolved state won't match the new state.
|
||||
state_before_gap = dict(
|
||||
self.get_success(self.store.get_current_state(self.room_id))
|
||||
self.get_success(self.store.get_current_state_ids(self.room_id))
|
||||
)
|
||||
state_before_gap.pop(("m.room.history_visibility", ""))
|
||||
|
||||
context = self.get_success(
|
||||
self.state.compute_event_context(
|
||||
remote_event_2, old_state=state_before_gap.values()
|
||||
remote_event_2,
|
||||
state_ids_before_event=state_before_gap,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id])
|
||||
|
@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
||||
|
@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id])
|
||||
|
@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
|
||||
|
@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
|
||||
|
|
|
@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
|
|||
]
|
||||
|
||||
context = yield defer.ensureDeferred(
|
||||
self.state.compute_event_context(event, old_state=old_state)
|
||||
self.state.compute_event_context(
|
||||
event,
|
||||
state_ids_before_event={
|
||||
(e.type, e.state_key): e.event_id for e in old_state
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
|
@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
|
|||
]
|
||||
|
||||
context = yield defer.ensureDeferred(
|
||||
self.state.compute_event_context(event, old_state=old_state)
|
||||
self.state.compute_event_context(
|
||||
event,
|
||||
state_ids_before_event={
|
||||
(e.type, e.state_key): e.event_id for e in old_state
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
|
|
Loading…
Reference in a new issue