First pass on sort_rooms and refactor to include room membership alongside the sync rooms

This commit is contained in:
Eric Eastwood 2024-06-12 15:46:37 -05:00
parent afb6627b6f
commit af60f7b508
2 changed files with 114 additions and 45 deletions

View file

@ -24,7 +24,14 @@ from immutabledict import immutabledict
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import Requester, RoomStreamToken, StreamToken, UserID from synapse.storage.roommember import RoomsForUser
from synapse.types import (
PersistedEventPosition,
Requester,
RoomStreamToken,
StreamToken,
UserID,
)
from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult
if TYPE_CHECKING: if TYPE_CHECKING:
@ -33,6 +40,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def convert_event_to_rooms_for_user(event: EventBase) -> RoomsForUser:
    """
    Build a `RoomsForUser` entry out of an event.

    Args:
        event: The event to convert. It is expected to have been persisted so
            that its `stream_ordering` and `instance_name` are populated.

    Returns:
        A `RoomsForUser` carrying the event's room ID, sender, membership,
        event ID, position (`instance_name` + `stream_ordering`) and room
        version identifier.
    """
    metadata = event.internal_metadata

    # Both of these should be set for any event that has been persisted
    stream_ordering = metadata.stream_ordering
    assert stream_ordering is not None
    instance_name = metadata.instance_name
    assert instance_name is not None

    position = PersistedEventPosition(instance_name, stream_ordering)

    return RoomsForUser(
        room_id=event.room_id,
        sender=event.sender,
        membership=event.membership,
        event_id=event.event_id,
        event_pos=position,
        room_version_id=event.room_version.identifier,
    )
def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> bool: def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> bool:
""" """
Returns True if the membership event should be included in the sync response, Returns True if the membership event should be included in the sync response,
@ -151,25 +179,25 @@ class SlidingSyncHandler:
# See https://github.com/matrix-org/matrix-doc/issues/1144 # See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError() raise NotImplementedError()
# Get all of the room IDs that the user should be able to see in the sync
# response
room_id_set = await self.get_sync_room_ids_for_user(
sync_config.user,
from_token=from_token,
to_token=to_token,
)
# Assemble sliding window lists # Assemble sliding window lists
lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {} lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {}
if sync_config.lists: if sync_config.lists:
# Get all of the room IDs that the user should be able to see in the sync
# response
sync_room_map = await self.get_sync_room_ids_for_user(
sync_config.user,
from_token=from_token,
to_token=to_token,
)
for list_key, list_config in sync_config.lists.items(): for list_key, list_config in sync_config.lists.items():
# TODO: Apply filters # TODO: Apply filters
# #
# TODO: Exclude partially stated rooms unless the `required_state` has # TODO: Exclude partially stated rooms unless the `required_state` has
# `["m.room.member", "$LAZY"]` # `["m.room.member", "$LAZY"]`
filtered_room_ids = room_id_set filtered_room_map = sync_room_map
# TODO: Apply sorts # TODO: Apply sorts
sorted_room_ids = await self.sort_rooms(filtered_room_ids, to_token) sorted_room_ids = await self.sort_rooms(filtered_room_map, to_token)
ops: List[SlidingSyncResult.SlidingWindowList.Operation] = [] ops: List[SlidingSyncResult.SlidingWindowList.Operation] = []
if list_config.ranges: if list_config.ranges:
@ -200,7 +228,7 @@ class SlidingSyncHandler:
user: UserID, user: UserID,
to_token: StreamToken, to_token: StreamToken,
from_token: Optional[StreamToken] = None, from_token: Optional[StreamToken] = None,
) -> AbstractSet[str]: ) -> Dict[str, RoomsForUser]:
""" """
Fetch room IDs that should be listed for this user in the sync response (the Fetch room IDs that should be listed for this user in the sync response (the
full room list that will be filtered, sorted, and sliced). full room list that will be filtered, sorted, and sliced).
@ -217,6 +245,15 @@ class SlidingSyncHandler:
`forgotten` flag to the `room_memberships` table in Synapse. There isn't a way `forgotten` flag to the `room_memberships` table in Synapse. There isn't a way
to tell when a room was forgotten at the moment so we can't factor it into the to tell when a room was forgotten at the moment so we can't factor it into the
from/to range. from/to range.
Args:
user: User to fetch rooms for
to_token: The token to fetch rooms up to.
from_token: The point in the stream to sync from.
Returns:
A dictionary of room IDs that should be listed in the sync response along
with membership information in that room at the time of `to_token`.
""" """
user_id = user.to_string() user_id = user.to_string()
@ -236,11 +273,11 @@ class SlidingSyncHandler:
# If the user has never joined any rooms before, we can just return an empty list # If the user has never joined any rooms before, we can just return an empty list
if not room_for_user_list: if not room_for_user_list:
return set() return {}
# Our working list of rooms that can show up in the sync response # Our working list of rooms that can show up in the sync response
sync_room_id_set = { sync_room_id_set = {
room_for_user.room_id room_for_user.room_id: room_for_user
for room_for_user in room_for_user_list for room_for_user in room_for_user_list
if filter_membership_for_sync( if filter_membership_for_sync(
membership=room_for_user.membership, membership=room_for_user.membership,
@ -390,7 +427,9 @@ class SlidingSyncHandler:
not was_last_membership_already_included not was_last_membership_already_included
and should_prev_membership_be_included and should_prev_membership_be_included
): ):
sync_room_id_set.add(room_id) sync_room_id_set[room_id] = convert_event_to_rooms_for_user(
last_membership_change_after_to_token
)
# 1b) Remove rooms that the user joined (hasn't left) after the `to_token` # 1b) Remove rooms that the user joined (hasn't left) after the `to_token`
# #
# For example, if the last membership event after the `to_token` is a "join" # For example, if the last membership event after the `to_token` is a "join"
@ -401,7 +440,7 @@ class SlidingSyncHandler:
was_last_membership_already_included was_last_membership_already_included
and not should_prev_membership_be_included and not should_prev_membership_be_included
): ):
sync_room_id_set.discard(room_id) del sync_room_id_set[room_id]
# 2) ----------------------------------------------------- # 2) -----------------------------------------------------
# We fix-up newly_left rooms after the first fixup because it may have removed # We fix-up newly_left rooms after the first fixup because it may have removed
@ -436,13 +475,15 @@ class SlidingSyncHandler:
# include newly_left rooms because the last event that the user should see # include newly_left rooms because the last event that the user should see
# is their own leave event # is their own leave event
if last_membership_change_in_from_to_range.membership == Membership.LEAVE: if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
sync_room_id_set.add(room_id) sync_room_id_set[room_id] = convert_event_to_rooms_for_user(
last_membership_change_in_from_to_range
)
return sync_room_id_set return sync_room_id_set
async def sort_rooms( async def sort_rooms(
self, self,
room_id_set: AbstractSet[str], sync_room_map: Dict[str, RoomsForUser],
to_token: StreamToken, to_token: StreamToken,
) -> List[str]: ) -> List[str]:
""" """
@ -450,11 +491,39 @@ class SlidingSyncHandler:
a stable sort, we tie-break by room ID. a stable sort, we tie-break by room ID.
Args: Args:
room_id_set: Set of room IDs to sort sync_room_map: Dictionary of room IDs to sort along with membership
to_token: We sort based on the events in the room at this token information in the room at the time of `to_token`.
to_token: We sort based on the events in the room at this token (<= `to_token`)
""" """
# TODO: `get_last_event_in_room_before_stream_ordering()`
# TODO: Handle when people are left/banned from the room and shouldn't see past that point # Assemble a map of room ID to the `stream_ordering` of the last activity that the
# user should see in the room (<= `to_token`)
last_activity_in_room_map: Dict[str, int] = {}
for room_id, room_for_user in sync_room_map.items():
# If they are fully-joined to the room, let's find the latest activity
# at/before the `to_token`.
if room_for_user.membership == Membership.JOIN:
last_event_result = (
await self.store.get_last_event_in_room_before_stream_ordering(
room_id, to_token.room_key
)
)
return list(room_id_set) # If the room has no events at/before the `to_token`, this is probably a
# mistake in the code that generates the `sync_room_map` since that should
# only give us rooms that the user had membership in during the token range.
assert last_event_result is not None
_, event_pos = last_event_result
last_activity_in_room_map[room_id] = event_pos.stream
else:
# Otherwise, if the user has left or been banned from the room, they shouldn't see
# past that point. (same for invites/knocks)
last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
return sorted(
sync_room_map.keys(),
# Sort by the last activity (stream_ordering) in the room, tie-break on room_id
key=lambda room_id: (last_activity_in_room_map[room_id], room_id),
)

View file

@ -78,7 +78,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
) )
self.assertEqual(room_id_results, set()) self.assertEqual(room_id_results.keys(), set())
def test_get_newly_joined_room(self) -> None: def test_get_newly_joined_room(self) -> None:
""" """
@ -102,7 +102,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
) )
self.assertEqual(room_id_results, {room_id}) self.assertEqual(room_id_results.keys(), {room_id})
def test_get_already_joined_room(self) -> None: def test_get_already_joined_room(self) -> None:
""" """
@ -123,7 +123,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
) )
self.assertEqual(room_id_results, {room_id}) self.assertEqual(room_id_results.keys(), {room_id})
def test_get_invited_banned_knocked_room(self) -> None: def test_get_invited_banned_knocked_room(self) -> None:
""" """
@ -179,7 +179,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Ensure that the invited, ban, and knock rooms show up # Ensure that the invited, ban, and knock rooms show up
self.assertEqual( self.assertEqual(
room_id_results, room_id_results.keys(),
{ {
invited_room_id, invited_room_id,
ban_room_id, ban_room_id,
@ -225,7 +225,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# The kicked room should show up # The kicked room should show up
self.assertEqual(room_id_results, {kick_room_id}) self.assertEqual(room_id_results.keys(), {kick_room_id})
def test_forgotten_rooms(self) -> None: def test_forgotten_rooms(self) -> None:
""" """
@ -307,7 +307,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# We shouldn't see the room because it was forgotten # We shouldn't see the room because it was forgotten
self.assertEqual(room_id_results, set()) self.assertEqual(room_id_results.keys(), set())
def test_only_newly_left_rooms_show_up(self) -> None: def test_only_newly_left_rooms_show_up(self) -> None:
""" """
@ -339,7 +339,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# Only the newly_left room should show up # Only the newly_left room should show up
self.assertEqual(room_id_results, {room_id2}) self.assertEqual(room_id_results.keys(), {room_id2})
def test_no_joins_after_to_token(self) -> None: def test_no_joins_after_to_token(self) -> None:
""" """
@ -367,7 +367,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
) )
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results.keys(), {room_id1})
def test_join_during_range_and_left_room_after_to_token(self) -> None: def test_join_during_range_and_left_room_after_to_token(self) -> None:
""" """
@ -397,7 +397,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# We should still see the room because we were joined during the # We should still see the room because we were joined during the
# from_token/to_token time period. # from_token/to_token time period.
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results.keys(), {room_id1})
def test_join_before_range_and_left_room_after_to_token(self) -> None: def test_join_before_range_and_left_room_after_to_token(self) -> None:
""" """
@ -424,7 +424,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# We should still see the room because we were joined before the `from_token` # We should still see the room because we were joined before the `from_token`
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results.keys(), {room_id1})
def test_kicked_before_range_and_left_after_to_token(self) -> None: def test_kicked_before_range_and_left_after_to_token(self) -> None:
""" """
@ -472,7 +472,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# The kicked room should still show up # The kicked room should still show up
self.assertEqual(room_id_results, {kick_room_id}) self.assertEqual(room_id_results.keys(), {kick_room_id})
def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None: def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None:
""" """
@ -509,7 +509,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# Room should still show up because it's newly_left during the from/to range # Room should still show up because it's newly_left during the from/to range
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results.keys(), {room_id1})
def test_newly_left_during_range_and_join_after_to_token(self) -> None: def test_newly_left_during_range_and_join_after_to_token(self) -> None:
""" """
@ -545,7 +545,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# Room should still show up because it's newly_left during the from/to range # Room should still show up because it's newly_left during the from/to range
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results.keys(), {room_id1})
def test_no_from_token(self) -> None: def test_no_from_token(self) -> None:
""" """
@ -586,7 +586,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# Only rooms we were joined to before the `to_token` should show up # Only rooms we were joined to before the `to_token` should show up
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results.keys(), {room_id1})
def test_from_token_ahead_of_to_token(self) -> None: def test_from_token_ahead_of_to_token(self) -> None:
""" """
@ -647,7 +647,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# #
# There won't be any newly_left rooms because the `from_token` is ahead of the # There won't be any newly_left rooms because the `from_token` is ahead of the
# `to_token` and that range will give no membership changes to check. # `to_token` and that range will give no membership changes to check.
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results.keys(), {room_id1})
def test_leave_before_range_and_join_leave_after_to_token(self) -> None: def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
""" """
@ -682,7 +682,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# Room shouldn't show up because it was left before the `from_token` # Room shouldn't show up because it was left before the `from_token`
self.assertEqual(room_id_results, set()) self.assertEqual(room_id_results.keys(), set())
def test_leave_before_range_and_join_after_to_token(self) -> None: def test_leave_before_range_and_join_after_to_token(self) -> None:
""" """
@ -716,7 +716,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# Room shouldn't show up because it was left before the `from_token` # Room shouldn't show up because it was left before the `from_token`
self.assertEqual(room_id_results, set()) self.assertEqual(room_id_results.keys(), set())
def test_join_leave_multiple_times_during_range_and_after_to_token( def test_join_leave_multiple_times_during_range_and_after_to_token(
self, self,
@ -758,7 +758,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# Room should show up because it was newly_left and joined during the from/to range # Room should show up because it was newly_left and joined during the from/to range
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results.keys(), {room_id1})
def test_join_leave_multiple_times_before_range_and_after_to_token( def test_join_leave_multiple_times_before_range_and_after_to_token(
self, self,
@ -798,7 +798,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# Room should show up because we were joined before the from/to range # Room should show up because we were joined before the from/to range
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results.keys(), {room_id1})
def test_invite_before_range_and_join_leave_after_to_token( def test_invite_before_range_and_join_leave_after_to_token(
self, self,
@ -835,7 +835,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
# Room should show up because we were invited before the from/to range # Room should show up because we were invited before the from/to range
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results.keys(), {room_id1})
def test_multiple_rooms_are_not_confused( def test_multiple_rooms_are_not_confused(
self, self,
@ -888,7 +888,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) )
self.assertEqual( self.assertEqual(
room_id_results, room_id_results.keys(),
{ {
# `room_id1` shouldn't show up because we left before the from/to range # `room_id1` shouldn't show up because we left before the from/to range
# #
@ -1099,7 +1099,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase):
) )
self.assertEqual( self.assertEqual(
room_id_results, room_id_results.keys(),
{ {
room_id1, room_id1,
# room_id2 shouldn't show up because we left before the from/to range # room_id2 shouldn't show up because we left before the from/to range