diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 352e00c79c..d22191f1b9 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -24,7 +24,14 @@ from immutabledict import immutabledict from synapse.api.constants import Membership from synapse.events import EventBase -from synapse.types import Requester, RoomStreamToken, StreamToken, UserID +from synapse.storage.roommember import RoomsForUser +from synapse.types import ( + PersistedEventPosition, + Requester, + RoomStreamToken, + StreamToken, + UserID, +) from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult if TYPE_CHECKING: @@ -33,6 +40,27 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def convert_event_to_rooms_for_user(event: EventBase) -> RoomsForUser: + """ + Quick helper to convert an event to a `RoomsForUser` object. + """ + # These fields should be present for all persisted events + assert event.internal_metadata.stream_ordering is not None + assert event.internal_metadata.instance_name is not None + + return RoomsForUser( + room_id=event.room_id, + sender=event.sender, + membership=event.membership, + event_id=event.event_id, + event_pos=PersistedEventPosition( + event.internal_metadata.instance_name, + event.internal_metadata.stream_ordering, + ), + room_version_id=event.room_version.identifier, + ) + + def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> bool: """ Returns True if the membership event should be included in the sync response, @@ -151,25 +179,25 @@ class SlidingSyncHandler: # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() - # Get all of the room IDs that the user should be able to see in the sync - # response - room_id_set = await self.get_sync_room_ids_for_user( - sync_config.user, - from_token=from_token, - to_token=to_token, - ) - # Assemble sliding window lists lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {} if sync_config.lists: + # Get all of the room IDs that the user should be able to see in the sync + # response + sync_room_map = await self.get_sync_room_ids_for_user( + sync_config.user, + from_token=from_token, + to_token=to_token, + ) + for list_key, list_config in sync_config.lists.items(): # TODO: Apply filters # # TODO: Exclude partially stated rooms unless the `required_state` has # `["m.room.member", "$LAZY"]` - filtered_room_ids = room_id_set + filtered_room_map = sync_room_map # TODO: Apply sorts - sorted_room_ids = await self.sort_rooms(filtered_room_ids, to_token) + sorted_room_ids = await self.sort_rooms(filtered_room_map, to_token) ops: List[SlidingSyncResult.SlidingWindowList.Operation] = [] if list_config.ranges: @@ -200,7 +228,7 @@ class SlidingSyncHandler: user: UserID, to_token: StreamToken, from_token: Optional[StreamToken] = None, - ) -> AbstractSet[str]: + ) -> Dict[str, RoomsForUser]: """ Fetch room IDs that should be listed for this user in the sync response (the full room list that will be filtered, sorted, and sliced). @@ -217,6 +245,15 @@ class SlidingSyncHandler: `forgotten` flag to the `room_memberships` table in Synapse. There isn't a way to tell when a room was forgotten at the moment so we can't factor it into the from/to range. + + Args: + user: User to fetch rooms for + to_token: The token to fetch rooms up to. + from_token: The point in the stream to sync from. + + Returns: + A dictionary of room IDs that should be listed in the sync response along + with membership information in that room at the time of `to_token`. """ user_id = user.to_string() @@ -236,11 +273,11 @@ class SlidingSyncHandler: # If the user has never joined any rooms before, we can just return an empty list if not room_for_user_list: - return set() + return {} # Our working list of rooms that can show up in the sync response sync_room_id_set = { - room_for_user.room_id + room_for_user.room_id: room_for_user for room_for_user in room_for_user_list if filter_membership_for_sync( membership=room_for_user.membership, @@ -390,7 +427,9 @@ class SlidingSyncHandler: not was_last_membership_already_included and should_prev_membership_be_included ): - sync_room_id_set.add(room_id) + sync_room_id_set[room_id] = convert_event_to_rooms_for_user( + last_membership_change_after_to_token + ) # 1b) Remove rooms that the user joined (hasn't left) after the `to_token` # # For example, if the last membership event after the `to_token` is a "join" @@ -401,7 +440,7 @@ class SlidingSyncHandler: was_last_membership_already_included and not should_prev_membership_be_included ): - sync_room_id_set.discard(room_id) + del sync_room_id_set[room_id] # 2) ----------------------------------------------------- # We fix-up newly_left rooms after the first fixup because it may have removed @@ -436,13 +475,15 @@ class SlidingSyncHandler: # include newly_left rooms because the last event that the user should see # is their own leave event if last_membership_change_in_from_to_range.membership == Membership.LEAVE: - sync_room_id_set.add(room_id) + sync_room_id_set[room_id] = convert_event_to_rooms_for_user( + last_membership_change_in_from_to_range + ) return sync_room_id_set async def sort_rooms( self, - room_id_set: AbstractSet[str], + sync_room_map: Dict[str, RoomsForUser], to_token: StreamToken, ) -> List[str]: """ @@ -450,11 +491,39 @@ class SlidingSyncHandler: a stable sort, we tie-break by room ID. Args: - room_id_set: Set of room IDs to sort - to_token: We sort based on the events in the room at this token + sync_room_map: Dictionary of room IDs to sort along with membership + information in the room at the time of `to_token`. + to_token: We sort based on the events in the room at this token (<= `to_token`) """ - # TODO: `get_last_event_in_room_before_stream_ordering()` - # TODO: Handle when people are left/banned from the room and shouldn't see past that point + # Assemble a map of room ID to the `stream_ordering` of the last activity that the + # user should see in the room (<= `to_token`) + last_activity_in_room_map: Dict[str, int] = {} + for room_id, room_for_user in sync_room_map.items(): + # If they are fully-joined to the room, let's find the latest activity + # at/before the `to_token`. + if room_for_user.membership == Membership.JOIN: + last_event_result = ( + await self.store.get_last_event_in_room_before_stream_ordering( + room_id, to_token.room_key + ) + ) - return list(room_id_set) + # If the room has no events at/before the `to_token`, this is probably a + # mistake in the code that generates the `sync_room_map` since that should + # only give us rooms that the user had membership in during the token range. + assert last_event_result is not None + + _, event_pos = last_event_result + + last_activity_in_room_map[room_id] = event_pos.stream + else: + # Otherwise, if the user left/banned from the room, they shouldn't see + # past that point. (same for invites/knocks) + last_activity_in_room_map[room_id] = room_for_user.event_pos.stream + + return sorted( + sync_room_map.keys(), + # Sort by the last activity (stream_ordering) in the room, tie-break on room_id + key=lambda room_id: (last_activity_in_room_map[room_id], room_id), + ) diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index 6389afdd60..354d48c41d 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -78,7 +78,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) ) - self.assertEqual(room_id_results, set()) + self.assertEqual(room_id_results.keys(), set()) def test_get_newly_joined_room(self) -> None: """ @@ -102,7 +102,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) ) - self.assertEqual(room_id_results, {room_id}) + self.assertEqual(room_id_results.keys(), {room_id}) def test_get_already_joined_room(self) -> None: """ @@ -123,7 +123,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) ) - self.assertEqual(room_id_results, {room_id}) + self.assertEqual(room_id_results.keys(), {room_id}) def test_get_invited_banned_knocked_room(self) -> None: """ @@ -179,7 +179,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): # Ensure that the invited, ban, and knock rooms show up self.assertEqual( - room_id_results, + room_id_results.keys(), { invited_room_id, ban_room_id, @@ -225,7 +225,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # The kicked room should show up - self.assertEqual(room_id_results, {kick_room_id}) + self.assertEqual(room_id_results.keys(), {kick_room_id}) def test_forgotten_rooms(self) -> None: """ @@ -307,7 +307,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # We shouldn't see the room because it was forgotten - self.assertEqual(room_id_results, set()) + self.assertEqual(room_id_results.keys(), set()) def test_only_newly_left_rooms_show_up(self) -> None: """ @@ -339,7 +339,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # Only the newly_left room should show up - self.assertEqual(room_id_results, {room_id2}) + self.assertEqual(room_id_results.keys(), {room_id2}) def test_no_joins_after_to_token(self) -> None: """ @@ -367,7 +367,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) ) - self.assertEqual(room_id_results, {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1}) def test_join_during_range_and_left_room_after_to_token(self) -> None: """ @@ -397,7 +397,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): # We should still see the room because we were joined during the # from_token/to_token time period. - self.assertEqual(room_id_results, {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1}) def test_join_before_range_and_left_room_after_to_token(self) -> None: """ @@ -424,7 +424,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # We should still see the room because we were joined before the `from_token` - self.assertEqual(room_id_results, {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1}) def test_kicked_before_range_and_left_after_to_token(self) -> None: """ @@ -472,7 +472,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # We shouldn't see the room because it was forgotten - self.assertEqual(room_id_results, {kick_room_id}) + self.assertEqual(room_id_results.keys(), {kick_room_id}) def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None: """ @@ -509,7 +509,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # Room should still show up because it's newly_left during the from/to range - self.assertEqual(room_id_results, {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1}) def test_newly_left_during_range_and_join_after_to_token(self) -> None: """ @@ -545,7 +545,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # Room should still show up because it's newly_left during the from/to range - self.assertEqual(room_id_results, {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1}) def test_no_from_token(self) -> None: """ @@ -586,7 +586,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # Only rooms we were joined to before the `to_token` should show up - self.assertEqual(room_id_results, {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1}) def test_from_token_ahead_of_to_token(self) -> None: """ @@ -647,7 +647,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): # # There won't be any newly_left rooms because the `from_token` is ahead of the # `to_token` and that range will give no membership changes to check. - self.assertEqual(room_id_results, {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1}) def test_leave_before_range_and_join_leave_after_to_token(self) -> None: """ @@ -682,7 +682,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # Room shouldn't show up because it was left before the `from_token` - self.assertEqual(room_id_results, set()) + self.assertEqual(room_id_results.keys(), set()) def test_leave_before_range_and_join_after_to_token(self) -> None: """ @@ -716,7 +716,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # Room shouldn't show up because it was left before the `from_token` - self.assertEqual(room_id_results, set()) + self.assertEqual(room_id_results.keys(), set()) def test_join_leave_multiple_times_during_range_and_after_to_token( self, @@ -758,7 +758,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # Room should show up because it was newly_left and joined during the from/to range - self.assertEqual(room_id_results, {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1}) def test_join_leave_multiple_times_before_range_and_after_to_token( self, @@ -798,7 +798,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # Room should show up because we were joined before the from/to range - self.assertEqual(room_id_results, {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1}) def test_invite_before_range_and_join_leave_after_to_token( self, @@ -835,7 +835,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # Room should show up because we were invited before the from/to range - self.assertEqual(room_id_results, {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1}) def test_multiple_rooms_are_not_confused( self, @@ -888,7 +888,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) self.assertEqual( - room_id_results, + room_id_results.keys(), { # `room_id1` shouldn't show up because we left before the from/to range # @@ -1099,7 +1099,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase): ) self.assertEqual( - room_id_results, + room_id_results.keys(), { room_id1, # room_id2 shouldn't show up because we left before the from/to range