First pass on sort_rooms and refactor to include room membership alongside the sync rooms

This commit is contained in:
Eric Eastwood 2024-06-12 15:46:37 -05:00
parent afb6627b6f
commit af60f7b508
2 changed files with 114 additions and 45 deletions

View file

@ -24,7 +24,14 @@ from immutabledict import immutabledict
from synapse.api.constants import Membership
from synapse.events import EventBase
from synapse.types import Requester, RoomStreamToken, StreamToken, UserID
from synapse.storage.roommember import RoomsForUser
from synapse.types import (
PersistedEventPosition,
Requester,
RoomStreamToken,
StreamToken,
UserID,
)
from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult
if TYPE_CHECKING:
@ -33,6 +40,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def convert_event_to_rooms_for_user(event: EventBase) -> RoomsForUser:
    """
    Build a `RoomsForUser` entry from the given event.

    The room ID, sender, membership, event ID and room version identifier are
    copied from the event. The event position is assembled from the instance name
    and stream ordering found in the event's internal metadata.

    Args:
        event: The event to convert.

    Returns:
        A `RoomsForUser` describing the event.

    Raises:
        AssertionError: If the event's internal metadata has no stream ordering
            or no instance name.
    """
    metadata = event.internal_metadata

    # Both of these fields should be present for all persisted events
    stream_ordering = metadata.stream_ordering
    assert stream_ordering is not None
    instance_name = metadata.instance_name
    assert instance_name is not None

    return RoomsForUser(
        room_id=event.room_id,
        sender=event.sender,
        membership=event.membership,
        event_id=event.event_id,
        event_pos=PersistedEventPosition(instance_name, stream_ordering),
        room_version_id=event.room_version.identifier,
    )
def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> bool:
"""
Returns True if the membership event should be included in the sync response,
@ -151,25 +179,25 @@ class SlidingSyncHandler:
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
# Get all of the room IDs that the user should be able to see in the sync
# response
room_id_set = await self.get_sync_room_ids_for_user(
sync_config.user,
from_token=from_token,
to_token=to_token,
)
# Assemble sliding window lists
lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {}
if sync_config.lists:
# Get all of the room IDs that the user should be able to see in the sync
# response
sync_room_map = await self.get_sync_room_ids_for_user(
sync_config.user,
from_token=from_token,
to_token=to_token,
)
for list_key, list_config in sync_config.lists.items():
# TODO: Apply filters
#
# TODO: Exclude partially stated rooms unless the `required_state` has
# `["m.room.member", "$LAZY"]`
filtered_room_ids = room_id_set
filtered_room_map = sync_room_map
# TODO: Apply sorts
sorted_room_ids = await self.sort_rooms(filtered_room_ids, to_token)
sorted_room_ids = await self.sort_rooms(filtered_room_map, to_token)
ops: List[SlidingSyncResult.SlidingWindowList.Operation] = []
if list_config.ranges:
@ -200,7 +228,7 @@ class SlidingSyncHandler:
user: UserID,
to_token: StreamToken,
from_token: Optional[StreamToken] = None,
) -> AbstractSet[str]:
) -> Dict[str, RoomsForUser]:
"""
Fetch room IDs that should be listed for this user in the sync response (the
full room list that will be filtered, sorted, and sliced).
@ -217,6 +245,15 @@ class SlidingSyncHandler:
`forgotten` flag to the `room_memberships` table in Synapse. There isn't a way
to tell when a room was forgotten at the moment so we can't factor it into the
from/to range.
Args:
user: User to fetch rooms for
to_token: The token to fetch rooms up to.
from_token: The point in the stream to sync from.
Returns:
A dictionary of room IDs that should be listed in the sync response along
with membership information in that room at the time of `to_token`.
"""
user_id = user.to_string()
@ -236,11 +273,11 @@ class SlidingSyncHandler:
# If the user has never joined any rooms before, we can just return an empty collection
if not room_for_user_list:
return set()
return {}
# Our working list of rooms that can show up in the sync response
sync_room_id_set = {
room_for_user.room_id
room_for_user.room_id: room_for_user
for room_for_user in room_for_user_list
if filter_membership_for_sync(
membership=room_for_user.membership,
@ -390,7 +427,9 @@ class SlidingSyncHandler:
not was_last_membership_already_included
and should_prev_membership_be_included
):
sync_room_id_set.add(room_id)
sync_room_id_set[room_id] = convert_event_to_rooms_for_user(
last_membership_change_after_to_token
)
# 1b) Remove rooms that the user joined (hasn't left) after the `to_token`
#
# For example, if the last membership event after the `to_token` is a "join"
@ -401,7 +440,7 @@ class SlidingSyncHandler:
was_last_membership_already_included
and not should_prev_membership_be_included
):
sync_room_id_set.discard(room_id)
del sync_room_id_set[room_id]
# 2) -----------------------------------------------------
# We fix-up newly_left rooms after the first fixup because it may have removed
@ -436,13 +475,15 @@ class SlidingSyncHandler:
# include newly_left rooms because the last event that the user should see
# is their own leave event
if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
sync_room_id_set.add(room_id)
sync_room_id_set[room_id] = convert_event_to_rooms_for_user(
last_membership_change_in_from_to_range
)
return sync_room_id_set
async def sort_rooms(
self,
room_id_set: AbstractSet[str],
sync_room_map: Dict[str, RoomsForUser],
to_token: StreamToken,
) -> List[str]:
"""
@ -450,11 +491,39 @@ class SlidingSyncHandler:
a stable sort, we tie-break by room ID.
Args:
room_id_set: Set of room IDs to sort
to_token: We sort based on the events in the room at this token
sync_room_map: Dictionary of room IDs to sort along with membership
information in the room at the time of `to_token`.
to_token: We sort based on the events in the room at this token (<= `to_token`)
"""
# TODO: `get_last_event_in_room_before_stream_ordering()`
# TODO: Handle when people are left/banned from the room and shouldn't see past that point
# Assemble a map of room ID to the `stream_ordering` of the last activity that the
# user should see in the room (<= `to_token`)
last_activity_in_room_map: Dict[str, int] = {}
for room_id, room_for_user in sync_room_map.items():
# If they are fully-joined to the room, let's find the latest activity
# at/before the `to_token`.
if room_for_user.membership == Membership.JOIN:
last_event_result = (
await self.store.get_last_event_in_room_before_stream_ordering(
room_id, to_token.room_key
)
)
return list(room_id_set)
# If the room has no events at/before the `to_token`, this is probably a
# mistake in the code that generates the `sync_room_map` since that should
# only give us rooms that the user had membership in during the token range.
assert last_event_result is not None
_, event_pos = last_event_result
last_activity_in_room_map[room_id] = event_pos.stream
else:
# Otherwise, if the user has left or been banned from the room, they shouldn't
# see past that point (the same goes for invites/knocks).
last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
return sorted(
sync_room_map.keys(),
# Sort by the last activity (stream_ordering) in the room, tie-break on room_id
key=lambda room_id: (last_activity_in_room_map[room_id], room_id),
)

View file

@ -78,7 +78,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
)
self.assertEqual(room_id_results, set())
self.assertEqual(room_id_results.keys(), set())
def test_get_newly_joined_room(self) -> None:
"""
@ -102,7 +102,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
)
self.assertEqual(room_id_results, {room_id})
self.assertEqual(room_id_results.keys(), {room_id})
def test_get_already_joined_room(self) -> None:
"""
@ -123,7 +123,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
)
self.assertEqual(room_id_results, {room_id})
self.assertEqual(room_id_results.keys(), {room_id})
def test_get_invited_banned_knocked_room(self) -> None:
"""
@ -179,7 +179,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Ensure that the invited, ban, and knock rooms show up
self.assertEqual(
room_id_results,
room_id_results.keys(),
{
invited_room_id,
ban_room_id,
@ -225,7 +225,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# The kicked room should show up
self.assertEqual(room_id_results, {kick_room_id})
self.assertEqual(room_id_results.keys(), {kick_room_id})
def test_forgotten_rooms(self) -> None:
"""
@ -307,7 +307,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# We shouldn't see the room because it was forgotten
self.assertEqual(room_id_results, set())
self.assertEqual(room_id_results.keys(), set())
def test_only_newly_left_rooms_show_up(self) -> None:
"""
@ -339,7 +339,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# Only the newly_left room should show up
self.assertEqual(room_id_results, {room_id2})
self.assertEqual(room_id_results.keys(), {room_id2})
def test_no_joins_after_to_token(self) -> None:
"""
@ -367,7 +367,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
)
self.assertEqual(room_id_results, {room_id1})
self.assertEqual(room_id_results.keys(), {room_id1})
def test_join_during_range_and_left_room_after_to_token(self) -> None:
"""
@ -397,7 +397,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# We should still see the room because we were joined during the
# from_token/to_token time period.
self.assertEqual(room_id_results, {room_id1})
self.assertEqual(room_id_results.keys(), {room_id1})
def test_join_before_range_and_left_room_after_to_token(self) -> None:
"""
@ -424,7 +424,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# We should still see the room because we were joined before the `from_token`
self.assertEqual(room_id_results, {room_id1})
self.assertEqual(room_id_results.keys(), {room_id1})
def test_kicked_before_range_and_left_after_to_token(self) -> None:
"""
@ -472,7 +472,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# We should still see the room because we were kicked before the from/to range
self.assertEqual(room_id_results, {kick_room_id})
self.assertEqual(room_id_results.keys(), {kick_room_id})
def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None:
"""
@ -509,7 +509,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# Room should still show up because it's newly_left during the from/to range
self.assertEqual(room_id_results, {room_id1})
self.assertEqual(room_id_results.keys(), {room_id1})
def test_newly_left_during_range_and_join_after_to_token(self) -> None:
"""
@ -545,7 +545,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# Room should still show up because it's newly_left during the from/to range
self.assertEqual(room_id_results, {room_id1})
self.assertEqual(room_id_results.keys(), {room_id1})
def test_no_from_token(self) -> None:
"""
@ -586,7 +586,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# Only rooms we were joined to before the `to_token` should show up
self.assertEqual(room_id_results, {room_id1})
self.assertEqual(room_id_results.keys(), {room_id1})
def test_from_token_ahead_of_to_token(self) -> None:
"""
@ -647,7 +647,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
#
# There won't be any newly_left rooms because the `from_token` is ahead of the
# `to_token` and that range will give no membership changes to check.
self.assertEqual(room_id_results, {room_id1})
self.assertEqual(room_id_results.keys(), {room_id1})
def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
"""
@ -682,7 +682,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# Room shouldn't show up because it was left before the `from_token`
self.assertEqual(room_id_results, set())
self.assertEqual(room_id_results.keys(), set())
def test_leave_before_range_and_join_after_to_token(self) -> None:
"""
@ -716,7 +716,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# Room shouldn't show up because it was left before the `from_token`
self.assertEqual(room_id_results, set())
self.assertEqual(room_id_results.keys(), set())
def test_join_leave_multiple_times_during_range_and_after_to_token(
self,
@ -758,7 +758,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# Room should show up because it was newly_left and joined during the from/to range
self.assertEqual(room_id_results, {room_id1})
self.assertEqual(room_id_results.keys(), {room_id1})
def test_join_leave_multiple_times_before_range_and_after_to_token(
self,
@ -798,7 +798,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# Room should show up because we were joined before the from/to range
self.assertEqual(room_id_results, {room_id1})
self.assertEqual(room_id_results.keys(), {room_id1})
def test_invite_before_range_and_join_leave_after_to_token(
self,
@ -835,7 +835,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# Room should show up because we were invited before the from/to range
self.assertEqual(room_id_results, {room_id1})
self.assertEqual(room_id_results.keys(), {room_id1})
def test_multiple_rooms_are_not_confused(
self,
@ -888,7 +888,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
self.assertEqual(
room_id_results,
room_id_results.keys(),
{
# `room_id1` shouldn't show up because we left before the from/to range
#
@ -1099,7 +1099,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.assertEqual(
room_id_results,
room_id_results.keys(),
{
room_id1,
# room_id2 shouldn't show up because we left before the from/to range