Refactor back to not pulling out full events

This commit is contained in:
Eric Eastwood 2024-06-25 23:40:55 -05:00
parent daa7e3691a
commit cccbd15e7e
2 changed files with 125 additions and 109 deletions

View file

@ -28,7 +28,6 @@ from synapse.events import EventBase
from synapse.events.utils import strip_event
from synapse.handlers.relations import BundledAggregations
from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
from synapse.storage.roommember import RoomsForUser
from synapse.types import (
JsonDict,
PersistedEventPosition,
@ -48,27 +47,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def convert_event_to_rooms_for_user(event: EventBase) -> RoomsForUser:
    """
    Build a `RoomsForUser` record out of the given event.

    Args:
        event: The event to read the room ID, sender, membership, event ID,
            stream position and room version from.

    Returns:
        A `RoomsForUser` populated from the fields of `event`.
    """
    metadata = event.internal_metadata

    # Persisted events are expected to always carry both of these fields
    stream_ordering = metadata.stream_ordering
    assert stream_ordering is not None
    instance_name = metadata.instance_name
    assert instance_name is not None

    return RoomsForUser(
        room_id=event.room_id,
        sender=event.sender,
        membership=event.membership,
        event_id=event.event_id,
        event_pos=PersistedEventPosition(
            instance_name=instance_name,
            stream=stream_ordering,
        ),
        room_version_id=event.room_version.identifier,
    )
def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> bool:
"""
Returns True if the membership event should be included in the sync response,
@ -108,6 +86,25 @@ class RoomSyncConfig:
required_state: Set[Tuple[str, str]]
@attr.s(auto_attribs=True, frozen=True, slots=True)
class _RoomMembershipForUser:
    """
    Frozen record of a user's membership in a single room.

    Attributes:
        event_id: ID of the membership event.
        event_pos: Stream position of the membership event.
        membership: Membership state the user has in the room.
        sender: Whoever sent the membership event.
        newly_joined: True when the user newly joined the room within the given
            token range.
    """

    event_id: str
    event_pos: PersistedEventPosition
    membership: str
    sender: str
    newly_joined: bool
class SlidingSyncHandler:
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
@ -302,7 +299,7 @@ class SlidingSyncHandler:
user=sync_config.user,
room_id=room_id,
room_sync_config=room_sync_config,
rooms_for_user_membership_at_to_token=sync_room_map[room_id],
rooms_membership_for_user_at_to_token=sync_room_map[room_id],
from_token=from_token,
to_token=to_token,
)
@ -321,7 +318,7 @@ class SlidingSyncHandler:
user: UserID,
to_token: StreamToken,
from_token: Optional[StreamToken] = None,
) -> Dict[str, RoomsForUser]:
) -> Dict[str, _RoomMembershipForUser]:
"""
Fetch room IDs that should be listed for this user in the sync response (the
full room list that will be filtered, sorted, and sliced).
@ -373,7 +370,13 @@ class SlidingSyncHandler:
# Note: The `room_for_user` we're assigning here will need to be fixed up
# (below) because they are potentially from the current snapshot time
# instead of from the time of the `to_token`.
room_for_user.room_id: room_for_user
room_for_user.room_id: _RoomMembershipForUser(
event_id=room_for_user.event_id,
event_pos=room_for_user.event_pos,
membership=room_for_user.membership,
sender=room_for_user.sender,
newly_joined=False,
)
for room_for_user in room_for_user_list
}
@ -440,7 +443,7 @@ class SlidingSyncHandler:
for membership_change in current_state_delta_membership_changes_after_to_token:
# Only set if we haven't already set it
first_membership_change_by_room_id_after_to_token.setdefault(
membership_change.event.room_id, membership_change
membership_change.room_id, membership_change
)
# 1) Fixup
@ -448,27 +451,59 @@ class SlidingSyncHandler:
# Since we fetched a snapshot of the user's room list at some point in time after
# the from/to tokens, we need to revert/rewind some membership changes to match
# the point in time of the `to_token`.
prev_event_ids_in_from_to_range: List[str] = []
for (
room_id,
first_membership_change_after_to_token,
) in first_membership_change_by_room_id_after_to_token.items():
# 1a) Remove rooms that the user joined after the `to_token`
if first_membership_change_after_to_token.prev_event is None:
if first_membership_change_after_to_token.prev_event_id is None:
sync_room_id_set.pop(room_id, None)
# 1b) 1c) From the first membership event after the `to_token`, step backward to the
# previous membership that would apply to the from/to range.
else:
sync_room_id_set[room_id] = convert_event_to_rooms_for_user(
first_membership_change_after_to_token.prev_event
prev_event_ids_in_from_to_range.append(
first_membership_change_after_to_token.prev_event_id
)
# 1) Fixup (more)
#
# 1b) 1c) Fetch the previous membership events that apply to the from/to range
# and fixup our working list.
prev_events_in_from_to_range = await self.store.get_events(
prev_event_ids_in_from_to_range
)
for prev_event_in_from_to_range in prev_events_in_from_to_range.values():
# These fields should be present for all persisted events
assert (
prev_event_in_from_to_range.internal_metadata.instance_name is not None
)
assert (
prev_event_in_from_to_range.internal_metadata.stream_ordering
is not None
)
# 1b) 1c) Update the membership with what we found
sync_room_id_set[prev_event_in_from_to_range.room_id] = (
_RoomMembershipForUser(
event_id=prev_event_in_from_to_range.event_id,
event_pos=PersistedEventPosition(
instance_name=prev_event_in_from_to_range.internal_metadata.instance_name,
stream=prev_event_in_from_to_range.internal_metadata.stream_ordering,
),
membership=prev_event_in_from_to_range.membership,
sender=prev_event_in_from_to_range.sender,
newly_joined=False,
)
)
filtered_sync_room_id_set = {
room_id: room_for_user
for room_id, room_for_user in sync_room_id_set.items()
room_id: room_membership_for_user
for room_id, room_membership_for_user in sync_room_id_set.items()
if filter_membership_for_sync(
membership=room_for_user.membership,
membership=room_membership_for_user.membership,
user_id=user_id,
sender=room_for_user.sender,
sender=room_membership_for_user.sender,
)
}
@ -498,35 +533,38 @@ class SlidingSyncHandler:
membership_change
) in current_state_delta_membership_changes_in_from_to_range:
last_membership_change_by_room_id_in_from_to_range[
membership_change.event.room_id
membership_change.room_id
] = membership_change
# 2) Fixup
for (
last_membership_change_in_from_to_range
) in last_membership_change_by_room_id_in_from_to_range.values():
room_id = last_membership_change_in_from_to_range.event.room_id
room_id = last_membership_change_in_from_to_range.room_id
# 2) Add back newly_left rooms (> `from_token` and <= `to_token`). We
# include newly_left rooms because the last event that the user should see
# is their own leave event
if (
last_membership_change_in_from_to_range.event.membership
== Membership.LEAVE
):
filtered_sync_room_id_set[room_id] = convert_event_to_rooms_for_user(
last_membership_change_in_from_to_range.event
if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
filtered_sync_room_id_set[room_id] = _RoomMembershipForUser(
event_id=last_membership_change_in_from_to_range.event_id,
event_pos=last_membership_change_in_from_to_range.event_pos,
membership=last_membership_change_in_from_to_range.membership,
sender=last_membership_change_in_from_to_range.sender,
newly_joined=False,
)
# TODO: Figure out `newly_joined`
return filtered_sync_room_id_set
async def filter_rooms(
self,
user: UserID,
sync_room_map: Dict[str, RoomsForUser],
sync_room_map: Dict[str, _RoomMembershipForUser],
filters: SlidingSyncConfig.SlidingSyncList.Filters,
to_token: StreamToken,
) -> Dict[str, RoomsForUser]:
) -> Dict[str, _RoomMembershipForUser]:
"""
Filter rooms based on the sync request.
@ -627,9 +665,9 @@ class SlidingSyncHandler:
async def sort_rooms(
self,
sync_room_map: Dict[str, RoomsForUser],
sync_room_map: Dict[str, _RoomMembershipForUser],
to_token: StreamToken,
) -> List[Tuple[str, RoomsForUser]]:
) -> List[Tuple[str, _RoomMembershipForUser]]:
"""
Sort by `stream_ordering` of the last event that the user should see in the
room. `stream_ordering` is unique so we get a stable sort.
@ -682,7 +720,7 @@ class SlidingSyncHandler:
user: UserID,
room_id: str,
room_sync_config: RoomSyncConfig,
rooms_for_user_membership_at_to_token: RoomsForUser,
rooms_membership_for_user_at_to_token: _RoomMembershipForUser,
from_token: Optional[StreamToken],
to_token: StreamToken,
) -> SlidingSyncResult.RoomResult:
@ -696,7 +734,7 @@ class SlidingSyncHandler:
room_id: The room ID to fetch data for
room_sync_config: Config for what data we should fetch for a room in the
sync response.
rooms_for_user_membership_at_to_token: Membership information for the user
rooms_membership_for_user_at_to_token: Membership information for the user
in the room at the time of `to_token`.
from_token: The point in the stream to sync from.
to_token: The point in the stream to sync up to.
@ -716,7 +754,7 @@ class SlidingSyncHandler:
if (
room_sync_config.timeline_limit > 0
# No timeline for invite/knock rooms (just `stripped_state`)
and rooms_for_user_membership_at_to_token.membership
and rooms_membership_for_user_at_to_token.membership
not in (Membership.INVITE, Membership.KNOCK)
):
limited = False
@ -726,27 +764,15 @@ class SlidingSyncHandler:
# position once we've fetched the events to point to the earliest event fetched.
prev_batch_token = to_token
newly_joined = False
if (
# We can only determine new-ness if we have a `from_token` to define our range
from_token is not None
and rooms_for_user_membership_at_to_token.membership == Membership.JOIN
):
newly_joined = (
rooms_for_user_membership_at_to_token.event_pos.persisted_after(
from_token.room_key
)
)
# We're going to paginate backwards from the `to_token`
from_bound = to_token.room_key
# People shouldn't see past their leave/ban event
if rooms_for_user_membership_at_to_token.membership in (
if rooms_membership_for_user_at_to_token.membership in (
Membership.LEAVE,
Membership.BAN,
):
from_bound = (
rooms_for_user_membership_at_to_token.event_pos.to_room_stream_token()
rooms_membership_for_user_at_to_token.event_pos.to_room_stream_token()
)
# Determine whether we should limit the timeline to the token range.
@ -760,7 +786,8 @@ class SlidingSyncHandler:
# connection before
to_bound = (
from_token.room_key
if from_token is not None and not newly_joined
if from_token is not None
and not rooms_membership_for_user_at_to_token.newly_joined
else None
)
@ -797,7 +824,7 @@ class SlidingSyncHandler:
self.storage_controllers,
user.to_string(),
timeline_events,
is_peeking=rooms_for_user_membership_at_to_token.membership
is_peeking=rooms_membership_for_user_at_to_token.membership
!= Membership.JOIN,
filter_send_to_client=True,
)
@ -852,12 +879,12 @@ class SlidingSyncHandler:
# Figure out any stripped state events for invite/knocks. This allows the
# potential joiner to identify the room.
stripped_state: List[JsonDict] = []
if rooms_for_user_membership_at_to_token.membership in (
if rooms_membership_for_user_at_to_token.membership in (
Membership.INVITE,
Membership.KNOCK,
):
invite_or_knock_event = await self.store.get_event(
rooms_for_user_membership_at_to_token.event_id
rooms_membership_for_user_at_to_token.event_id
)
stripped_state = []

View file

@ -112,32 +112,25 @@ class _EventsAround:
end: RoomStreamToken
@attr.s(auto_attribs=True, frozen=True, slots=True)
class _CurrentStateDeltaMembershipReturn:
    """
    Frozen record identifying a membership change by event ID only.

    Attributes:
        event_id: ID of the "current" membership event in this room.
        prev_event_id: ID of the earlier membership event in this room that the
            "current" one replaced. `None` when there was no earlier membership
            event.
        room_id: ID of the room the membership event belongs to.
    """

    event_id: str
    prev_event_id: Optional[str]
    room_id: str
@attr.s(slots=True, frozen=True, auto_attribs=True)
class CurrentStateDeltaMembership:
    """
    Describes a single membership change in a room: the "current" membership event
    and the previous membership event that it replaced (if any).

    Attributes:
    event: The "current" membership event in this room.
    prev_event: The previous membership event in this room that was replaced by
    the "current" one. May be `None` if there was no previous membership event.
    event_id: The "current" membership event ID in this room.
    event_pos: The position of the "current" membership event in the event stream.
    prev_event_id: The previous membership event in this room that was replaced by
    the "current" one. May be `None` if there was no previous membership event.
    room_id: The room ID of the membership event.
    membership: The membership state of the user in the room
    sender: The person who sent the membership event
    """
    event: EventBase
    prev_event: Optional[EventBase]
    event_id: str
    event_pos: PersistedEventPosition
    prev_event_id: Optional[str]
    room_id: str
    membership: str
    sender: str
def generate_pagination_where_clause(
@ -808,7 +801,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if not has_changed:
return []
def f(txn: LoggingTransaction) -> List[_CurrentStateDeltaMembershipReturn]:
def f(txn: LoggingTransaction) -> List[CurrentStateDeltaMembership]:
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@ -833,7 +826,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
s.room_id,
s.instance_name,
s.stream_id,
e.topological_ordering
e.topological_ordering,
m.membership,
e.sender
FROM current_state_delta_stream AS s
INNER JOIN events AS e ON e.stream_ordering = s.stream_id
INNER JOIN room_memberships AS m ON m.event_id = e.event_id
@ -844,7 +839,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, args)
membership_changes: List[_CurrentStateDeltaMembershipReturn] = []
membership_changes: List[CurrentStateDeltaMembership] = []
for (
event_id,
prev_event_id,
@ -852,6 +847,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
instance_name,
stream_ordering,
topological_ordering,
membership,
sender,
) in txn:
assert event_id is not None
# `prev_event_id` can be `None`
@ -859,6 +856,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
assert instance_name is not None
assert stream_ordering is not None
assert topological_ordering is not None
assert membership is not None
assert sender is not None
if _filter_results(
from_key,
@ -868,43 +867,33 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
stream_ordering,
):
membership_changes.append(
_CurrentStateDeltaMembershipReturn(
CurrentStateDeltaMembership(
event_id=event_id,
event_pos=PersistedEventPosition(
instance_name=instance_name,
stream=stream_ordering,
),
prev_event_id=prev_event_id,
room_id=room_id,
membership=membership,
sender=sender,
)
)
return membership_changes
raw_membership_changes = await self.db_pool.runInteraction(
membership_changes = await self.db_pool.runInteraction(
"get_current_state_delta_membership_changes_for_user", f
)
# Fetch all events in one go
event_ids = []
for m in raw_membership_changes:
event_ids.append(m.event_id)
if m.prev_event_id is not None:
event_ids.append(m.prev_event_id)
events = await self.get_events(event_ids, get_prev_content=False)
room_ids_to_exclude: AbstractSet[str] = set()
if excluded_room_ids is not None:
room_ids_to_exclude = set(excluded_room_ids)
return [
CurrentStateDeltaMembership(
event=events[raw_membership_change.event_id],
prev_event=(
events[raw_membership_change.prev_event_id]
if raw_membership_change.prev_event_id
else None
),
)
for raw_membership_change in raw_membership_changes
if raw_membership_change.room_id not in room_ids_to_exclude
membership_change
for membership_change in membership_changes
if membership_change.room_id not in room_ids_to_exclude
]
@cancellable