From 76ce7a903413e56e2c4ca46e4c61c273a872fc95 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 6 Jun 2024 15:19:21 -0500 Subject: [PATCH] Add `is_dm` filtering to Sliding Sync `/sync` Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync --- changelog.d/17244.feature | 1 + synapse/handlers/sliding_sync.py | 85 ++++++++++++++- tests/handlers/test_sliding_sync.py | 156 +++++++++++++++++++++++++++- tests/rest/client/test_sync.py | 127 ++++++++++++++++++++++ 4 files changed, 363 insertions(+), 6 deletions(-) create mode 100644 changelog.d/17244.feature diff --git a/changelog.d/17244.feature b/changelog.d/17244.feature new file mode 100644 index 0000000000..5c16342c11 --- /dev/null +++ b/changelog.d/17244.feature @@ -0,0 +1 @@ +Add `is_dm` filtering to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 34ae21ba50..08c6aadff6 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -31,7 +31,7 @@ if TYPE_CHECKING or HAS_PYDANTIC_V2: else: from pydantic import Extra -from synapse.api.constants import Membership +from synapse.api.constants import AccountDataTypes, Membership from synapse.events import EventBase from synapse.rest.client.models import SlidingSyncBody from synapse.types import JsonMapping, Requester, RoomStreamToken, StreamToken, UserID @@ -332,11 +332,15 @@ class SlidingSyncHandler: lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {} if sync_config.lists: for list_key, list_config in sync_config.lists.items(): - # TODO: Apply filters - # - # TODO: Exclude partially stated rooms unless the `required_state` has - # `["m.room.member", "$LAZY"]` + # Apply filters filtered_room_ids = room_id_set + if list_config.filters is not None: + # TODO: To be absolutely correct, this could also take into account + # from/to tokens but some of the streams don't support looking back + # in time (like global account_data). + filtered_room_ids = await self.filter_rooms( + sync_config.user, room_id_set, list_config.filters + ) # TODO: Apply sorts sorted_room_ids = sorted(filtered_room_ids) @@ -608,3 +612,74 @@ class SlidingSyncHandler: sync_room_id_set.add(room_id) return sync_room_id_set + + async def filter_rooms( + self, + user: UserID, + room_id_set: AbstractSet[str], + filters: SlidingSyncConfig.SlidingSyncList.Filters, + ) -> AbstractSet[str]: + """ + Filter rooms based on the sync request. + """ + user_id = user.to_string() + + # TODO: Apply filters + # + # TODO: Exclude partially stated rooms unless the `required_state` has + # `["m.room.member", "$LAZY"]` + + filtered_room_id_set = set(room_id_set) + + # Filter for Direct-Message (DM) rooms + if filters.is_dm is not None: + # We're using global account data (`m.direct`) instead of checking for + # `is_direct` on membership events because that property only appears for + # the invitee membership event (doesn't show up for the inviter). Account + # data is set by the client so it needs to be scrutinized. + dm_map = await self.store.get_global_account_data_by_type_for_user( + user_id, AccountDataTypes.DIRECT + ) + logger.warn("dm_map: %s", dm_map) + # Flatten out the map + dm_room_id_set = set() + if dm_map: + for room_ids in dm_map.values(): + # Account data should be a list of room IDs. Ignore anything else + if isinstance(room_ids, list): + for room_id in room_ids: + if isinstance(room_id, str): + dm_room_id_set.add(room_id) + + if filters.is_dm: + # Only DM rooms please + filtered_room_id_set = filtered_room_id_set.intersection(dm_room_id_set) + else: + # Only non-DM rooms please + filtered_room_id_set = filtered_room_id_set.difference(dm_room_id_set) + + if filters.spaces: + raise NotImplementedError() + + if filters.is_encrypted: + raise NotImplementedError() + + if filters.is_invite: + raise NotImplementedError() + + if filters.room_types: + raise NotImplementedError() + + if filters.not_room_types: + raise NotImplementedError() + + if filters.room_name_like: + raise NotImplementedError() + + if filters.tags: + raise NotImplementedError() + + if filters.not_tags: + raise NotImplementedError() + + return filtered_room_id_set diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index 5c27474b96..220683b9d6 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -22,7 +22,7 @@ from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.rest import admin from synapse.rest.client import knock, login, room @@ -1116,3 +1116,157 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase): room_id3, }, ) + + +class FilterRoomsTestCase(HomeserverTestCase): + """ + Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms + correctly. + """ + + servlets = [ + admin.register_servlets, + knock.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + # Enable sliding sync + config["experimental_features"] = {"msc3575_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sliding_sync_handler = self.hs.get_sliding_sync_handler() + self.store = self.hs.get_datastores().main + + def _create_dm_room( + self, + inviter_user_id: str, + inviter_tok: str, + invitee_user_id: str, + invitee_tok: str, + ) -> str: + """ + Helper to create a DM room as the "inviter" and invite the "invitee" user to the room. The + "invitee" user also will join the room. The `m.direct` account data will be set + for both users. + """ + + # Create a room and send an invite the other user + room_id = self.helper.create_room_as( + inviter_user_id, + is_public=False, + tok=inviter_tok, + ) + self.helper.invite( + room_id, + src=inviter_user_id, + targ=invitee_user_id, + tok=inviter_tok, + extra_data={"is_direct": True}, + ) + # Person that was invited joins the room + self.helper.join(room_id, invitee_user_id, tok=invitee_tok) + + # Mimic the client setting the room as a direct message in the global account + # data + self.get_success( + self.store.add_account_data_for_user( + invitee_user_id, + AccountDataTypes.DIRECT, + {inviter_user_id: [room_id]}, + ) + ) + self.get_success( + self.store.add_account_data_for_user( + inviter_user_id, + AccountDataTypes.DIRECT, + {invitee_user_id: [room_id]}, + ) + ) + + return room_id + + def test_filter_dm_rooms(self) -> None: + """ + Test filter for DM rooms + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create a normal room + room_id = self.helper.create_room_as( + user1_id, + is_public=False, + tok=user1_tok, + ) + + # Create a DM room + dm_room_id = self._create_dm_room( + inviter_user_id=user1_id, + inviter_tok=user1_tok, + invitee_user_id=user2_id, + invitee_tok=user2_tok, + ) + + # TODO: Better way to avoid the circular import? (see + # https://github.com/element-hq/synapse/pull/17187#discussion_r1619492779) + from synapse.handlers.sliding_sync import SlidingSyncConfig + + filters = SlidingSyncConfig.SlidingSyncList.Filters( + is_dm=True, + ) + + # Try filtering the rooms + filtered_room_ids = self.get_success( + self.sliding_sync_handler.filter_rooms( + UserID.from_string(user1_id), {room_id, dm_room_id}, filters + ) + ) + + self.assertEqual(filtered_room_ids, {dm_room_id}) + + def test_filter_non_dm_rooms(self) -> None: + """ + Test filter for non-DM rooms + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create a normal room + room_id = self.helper.create_room_as( + user1_id, + is_public=False, + tok=user1_tok, + ) + + # Create a DM room + dm_room_id = self._create_dm_room( + inviter_user_id=user1_id, + inviter_tok=user1_tok, + invitee_user_id=user2_id, + invitee_tok=user2_tok, + ) + + # TODO: Better way to avoid the circular import? (see + # https://github.com/element-hq/synapse/pull/17187#discussion_r1619492779) + from synapse.handlers.sliding_sync import SlidingSyncConfig + + filters = SlidingSyncConfig.SlidingSyncList.Filters( + is_dm=False, + ) + + # Try filtering the rooms + filtered_room_ids = self.get_success( + self.sliding_sync_handler.filter_rooms( + UserID.from_string(user1_id), {room_id, dm_room_id}, filters + ) + ) + + self.assertEqual(filtered_room_ids, {room_id}) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index a20a3fb40d..40870b2cfe 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -27,6 +27,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( + AccountDataTypes, EventContentFields, EventTypes, ReceiptTypes, @@ -1226,10 +1227,59 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync" self.store = hs.get_datastores().main self.event_sources = hs.get_event_sources() + def _create_dm_room( + self, + inviter_user_id: str, + inviter_tok: str, + invitee_user_id: str, + invitee_tok: str, + ) -> str: + """ + Helper to create a DM room as the "inviter" and invite the "invitee" user to the + room. The "invitee" user also will join the room. The `m.direct` account data + will be set for both users. + """ + + # Create a room and send an invite the other user + room_id = self.helper.create_room_as( + inviter_user_id, + is_public=False, + tok=inviter_tok, + ) + self.helper.invite( + room_id, + src=inviter_user_id, + targ=invitee_user_id, + tok=inviter_tok, + extra_data={"is_direct": True}, + ) + # Person that was invited joins the room + self.helper.join(room_id, invitee_user_id, tok=invitee_tok) + + # Mimic the client setting the room as a direct message in the global account + # data + self.get_success( + self.store.add_account_data_for_user( + invitee_user_id, + AccountDataTypes.DIRECT, + {inviter_user_id: [room_id]}, + ) + ) + self.get_success( + self.store.add_account_data_for_user( + inviter_user_id, + AccountDataTypes.DIRECT, + {invitee_user_id: [room_id]}, + ) + ) + + return room_id + def test_sync_list(self) -> None: """ Test that room IDs show up in the Sliding Sync lists @@ -1336,3 +1386,80 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.assertEqual( channel.json_body["next_pos"], future_position_token_serialized ) + + def test_filter_list(self) -> None: + """ + Test that filters apply to lists + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create a DM room + dm_room_id = self._create_dm_room( + inviter_user_id=user1_id, + inviter_tok=user1_tok, + invitee_user_id=user2_id, + invitee_tok=user2_tok, + ) + + # Create a normal room + room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) + + # Make the Sliding Sync request + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "dms": { + "ranges": [[0, 99]], + "sort": ["by_recency"], + "required_state": [], + "timeline_limit": 1, + "filters": {"is_dm": True}, + }, + "foo-list": { + "ranges": [[0, 99]], + "sort": ["by_recency"], + "required_state": [], + "timeline_limit": 1, + "filters": {"is_dm": False}, + }, + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Make sure it has the foo-list we requested + self.assertListEqual( + list(channel.json_body["lists"].keys()), + ["dms", "foo-list"], + channel.json_body["lists"].keys(), + ) + + # Make sure the list includes the room we are joined to + self.assertListEqual( + list(channel.json_body["lists"]["dms"]["ops"]), + [ + { + "op": "SYNC", + "range": [0, 99], + "room_ids": [dm_room_id], + } + ], + list(channel.json_body["lists"]["dms"]), + ) + self.assertListEqual( + list(channel.json_body["lists"]["foo-list"]["ops"]), + [ + { + "op": "SYNC", + "range": [0, 99], + "room_ids": [room_id], + } + ], + list(channel.json_body["lists"]["foo-list"]), + )