From 35808b38db7e89a9be341353b19cf14407e379a6 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 13 Jun 2024 14:53:23 -0500 Subject: [PATCH] Fix filtering --- synapse/handlers/sliding_sync.py | 28 ++++++++++++++++++---------- tests/handlers/test_sliding_sync.py | 23 +++++++++++++++++------ 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index ff0d3ba4f3..67471ac362 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -18,7 +18,7 @@ # # import logging -from typing import TYPE_CHECKING, AbstractSet, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from immutabledict import immutabledict @@ -210,14 +210,16 @@ class SlidingSyncHandler: for list_key, list_config in sync_config.lists.items(): # Apply filters - filtered_room_ids = room_id_set + filtered_sync_room_map = sync_room_map if list_config.filters is not None: - filtered_room_ids = await self.filter_rooms( - sync_config.user, room_id_set, list_config.filters, to_token + filtered_sync_room_map = await self.filter_rooms( + sync_config.user, sync_room_map, list_config.filters, to_token ) # TODO: Apply sorts - sorted_room_info = await self.sort_rooms(filtered_room_map, to_token) + sorted_room_info = await self.sort_rooms( + filtered_sync_room_map, to_token + ) ops: List[SlidingSyncResult.SlidingWindowList.Operation] = [] if list_config.ranges: @@ -509,18 +511,23 @@ class SlidingSyncHandler: async def filter_rooms( self, user: UserID, - room_id_set: AbstractSet[str], + sync_room_map: Dict[str, RoomsForUser], filters: SlidingSyncConfig.SlidingSyncList.Filters, to_token: StreamToken, - ) -> AbstractSet[str]: + ) -> Dict[str, RoomsForUser]: """ Filter rooms based on the sync request. Args: user: User to filter rooms for - room_id_set: Set of room IDs to filter down + sync_room_map: Dictionary of room IDs to sort along with membership + information in the room at the time of `to_token`. filters: Filters to apply to_token: We filter based on the state of the room at this token + + Returns: + A filtered dictionary of room IDs along with membership information in the + room at the time of `to_token`. """ user_id = user.to_string() @@ -529,7 +536,7 @@ class SlidingSyncHandler: # TODO: Exclude partially stated rooms unless the `required_state` has # `["m.room.member", "$LAZY"]` - filtered_room_id_set = set(room_id_set) + filtered_room_id_set = set(sync_room_map.keys()) # Filter for Direct-Message (DM) rooms if filters.is_dm is not None: @@ -585,7 +592,8 @@ class SlidingSyncHandler: if filters.not_tags: raise NotImplementedError() - return filtered_room_id_set + # Assemble a new sync room map but only with the `filtered_room_id_set` + return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set} async def sort_rooms( self, diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index 755086e43d..f6b6d579ae 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -1216,11 +1216,20 @@ class FilterRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() + # Get the rooms the user should be syncing with + sync_room_map = self.get_success( + self.sliding_sync_handler.get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, + ) + ) + # Try with `is_dm=True` - truthy_filtered_room_ids = self.get_success( + truthy_filtered_room_map = self.get_success( self.sliding_sync_handler.filter_rooms( UserID.from_string(user1_id), - {room_id, dm_room_id}, + sync_room_map, SlidingSyncConfig.SlidingSyncList.Filters( is_dm=True, ), @@ -1228,13 +1237,13 @@ class FilterRoomsTestCase(HomeserverTestCase): ) ) - self.assertEqual(truthy_filtered_room_ids, {dm_room_id}) + self.assertEqual(truthy_filtered_room_map.keys(), {dm_room_id}) # Try with `is_dm=False` - falsy_filtered_room_ids = self.get_success( + falsy_filtered_room_map = self.get_success( self.sliding_sync_handler.filter_rooms( UserID.from_string(user1_id), - {room_id, dm_room_id}, + sync_room_map, SlidingSyncConfig.SlidingSyncList.Filters( is_dm=False, ), @@ -1242,7 +1251,7 @@ class FilterRoomsTestCase(HomeserverTestCase): ) ) - self.assertEqual(falsy_filtered_room_ids, {room_id}) + self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) class SortRoomsTestCase(HomeserverTestCase): @@ -1287,6 +1296,7 @@ class SortRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() + # Get the rooms the user should be syncing with sync_room_map = self.get_success( self.sliding_sync_handler.get_sync_room_ids_for_user( UserID.from_string(user1_id), @@ -1295,6 +1305,7 @@ class SortRoomsTestCase(HomeserverTestCase): ) ) + # Sort the rooms (what we're testing) sorted_room_info = self.get_success( self.sliding_sync_handler.sort_rooms( sync_room_map=sync_room_map,