From 1e05a05f03ad4d9f001edc5d2035455314b5126d Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 7 May 2024 18:16:35 -0500 Subject: [PATCH] Add Sliding Sync `/sync/e2ee` endpoint for To-Device messages Based on: - MSC3575: Sliding Sync (aka Sync v3): https://github.com/matrix-org/matrix-spec-proposals/pull/3575 - MSC3885: Sliding Sync Extension: To-Device messages: https://github.com/matrix-org/matrix-spec-proposals/pull/3885 - MSC3884: Sliding Sync Extension: E2EE: https://github.com/matrix-org/matrix-spec-proposals/pull/3884 --- synapse/handlers/sync.py | 110 ++++++++++++++++++++---- synapse/rest/client/sync.py | 111 +++++++++++++++++++++++-- tests/rest/client/test_sendtodevice.py | 6 +- tests/rest/client/test_sliding_sync.py | 74 +++++++++++++++++ 4 files changed, 276 insertions(+), 25 deletions(-) create mode 100644 tests/rest/client/test_sliding_sync.py diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 8ff45a3353..0183393e34 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -18,6 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # +from enum import Enum import itertools import logging from typing import ( @@ -112,12 +113,21 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 SyncRequestKey = Tuple[Any, ...] +class SyncType(Enum): + """Enum for specifying the type of sync request.""" + + # These string values are semantically significant and are used in the the metrics + INITIAL_SYNC = "initial_sync" + FULL_STATE_SYNC = "full_state_sync" + INCREMENTAL_SYNC = "incremental_sync" + E2EE_SYNC = "e2ee_sync" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class SyncConfig: user: UserID filter_collection: FilterCollection is_guest: bool - request_key: SyncRequestKey device_id: Optional[str] @@ -263,6 +273,15 @@ class SyncResult: ) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class E2eeSyncResult: + next_batch: StreamToken + to_device: List[JsonDict] + # device_lists: DeviceListUpdates + # device_one_time_keys_count: JsonMapping + # device_unused_fallback_key_types: List[str] + + class SyncHandler: def __init__(self, hs: "HomeServer"): self.hs_config = hs.config @@ -309,6 +328,8 @@ class SyncHandler: self, requester: Requester, sync_config: SyncConfig, + sync_type: SyncType, + request_key: SyncRequestKey, since_token: Optional[StreamToken] = None, timeout: int = 0, full_state: bool = False, @@ -316,6 +337,9 @@ class SyncHandler: """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. + + Args: + request_key: The key to use for caching the response. """ # If the user is not part of the mau group, then check that limits have # not been exceeded (if not part of the group by this point, almost certain @@ -324,9 +348,10 @@ class SyncHandler: await self.auth_blocking.check_auth_blocking(requester=requester) res = await self.response_cache.wrap( - sync_config.request_key, + request_key, self._wait_for_sync_for_user, sync_config, + sync_type, since_token, timeout, full_state, @@ -338,6 +363,7 @@ class SyncHandler: async def _wait_for_sync_for_user( self, sync_config: SyncConfig, + sync_type: SyncType, since_token: Optional[StreamToken], timeout: int, full_state: bool, @@ -356,13 +382,6 @@ class SyncHandler: Computing the body of the response begins in the next method, `current_sync_for_user`. """ - if since_token is None: - sync_type = "initial_sync" - elif full_state: - sync_type = "full_state_sync" - else: - sync_type = "incremental_sync" - context = current_context() if context: context.tag = sync_type @@ -384,14 +403,16 @@ class SyncHandler: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. result: SyncResult = await self.current_sync_for_user( - sync_config, since_token, full_state=full_state + sync_config, sync_type, since_token, full_state=full_state ) else: # Otherwise, we wait for something to happen and report it to the user. async def current_sync_callback( before_token: StreamToken, after_token: StreamToken ) -> SyncResult: - return await self.current_sync_for_user(sync_config, since_token) + return await self.current_sync_for_user( + sync_config, sync_type, since_token + ) result = await self.notifier.wait_for_events( sync_config.user.to_string(), @@ -423,6 +444,7 @@ class SyncHandler: async def current_sync_for_user( self, sync_config: SyncConfig, + sync_type: SyncType, since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: @@ -434,9 +456,25 @@ class SyncHandler: """ with start_active_span("sync.current_sync_for_user"): log_kv({"since_token": since_token}) - sync_result = await self.generate_sync_result( - sync_config, since_token, full_state - ) + + # Go through the `/sync` v2 path + if sync_type in { + SyncType.INITIAL_SYNC, + SyncType.FULL_STATE_SYNC, + SyncType.INCREMENTAL_SYNC, + }: + sync_result = await self.generate_sync_result( + sync_config, since_token, full_state + ) + # Go through the MSC3575 Sliding Sync `/sync/e2ee` path + elif sync_type == SyncType.E2EE_SYNC: + sync_result = await self.generate_e2ee_sync_result( + sync_config, since_token + ) + else: + raise Exception( + f"Unknown sync_type (this is a Synapse problem): {sync_type}" + ) set_tag(SynapseTags.SYNC_RESULT, bool(sync_result)) return sync_result @@ -1751,6 +1789,50 @@ class SyncHandler: next_batch=sync_result_builder.now_token, ) + async def generate_e2ee_sync_result( + self, + sync_config: SyncConfig, + since_token: Optional[StreamToken] = None, + ) -> SyncResult: + """Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.""" + + user_id = sync_config.user.to_string() + # TODO: Should we exclude app services here? There could be an argument to allow + # them since the appservice doesn't have to make a massive initial sync. + # (related to https://github.com/matrix-org/matrix-doc/issues/1144) + + # NB: The now_token gets changed by some of the generate_sync_* methods, + # this is due to some of the underlying streams not supporting the ability + # to query up to a given point. + # Always use the `now_token` in `SyncResultBuilder` + now_token = self.event_sources.get_current_token() + log_kv({"now_token": now_token}) + + joined_room_ids = await self.store.get_rooms_for_user(user_id) + + sync_result_builder = SyncResultBuilder( + sync_config, + full_state=False, + since_token=since_token, + now_token=now_token, + joined_room_ids=joined_room_ids, + # Dummy values to fill out `SyncResultBuilder` + excluded_room_ids=frozenset({}), + forced_newly_joined_room_ids=frozenset({}), + membership_change_events=frozenset({}), + ) + + await self._generate_sync_entry_for_to_device(sync_result_builder) + + return E2eeSyncResult( + to_device=sync_result_builder.to_device, + # to_device: List[JsonDict] + # device_lists: DeviceListUpdates + # device_one_time_keys_count: JsonMapping + # device_unused_fallback_key_types: List[str] + next_batch=sync_result_builder.now_token, + ) + @measure_func("_generate_sync_entry_for_device_list") async def _generate_sync_entry_for_device_list( self, diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 2acff4494a..3b09b20dc7 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -41,6 +41,7 @@ from synapse.handlers.sync import ( KnockedSyncResult, SyncConfig, SyncResult, + SyncType, ) from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string @@ -198,7 +199,6 @@ class SyncRestServlet(RestServlet): user=user, filter_collection=filter_collection, is_guest=requester.is_guest, - request_key=request_key, device_id=device_id, ) @@ -206,6 +206,13 @@ class SyncRestServlet(RestServlet): if since is not None: since_token = await StreamToken.from_string(self.store, since) + if since_token is None: + sync_type = SyncType.INITIAL_SYNC + elif full_state: + sync_type = SyncType.FULL_STATE_SYNC + else: + sync_type = SyncType.INCREMENTAL_SYNC + # send any outstanding server notices to the user. await self._server_notices_sender.on_user_syncing(user.to_string()) @@ -221,6 +228,8 @@ class SyncRestServlet(RestServlet): sync_result = await self.sync_handler.wait_for_sync_for_user( requester, sync_config, + sync_type, + request_key, since_token=since_token, timeout=timeout, full_state=full_state, @@ -554,27 +563,111 @@ class SyncRestServlet(RestServlet): return result -class SlidingSyncRestServlet(RestServlet): +class SlidingSyncE2eeRestServlet(RestServlet): """ - API endpoint TODO - Useful for cases like TODO + API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part + of Sliding Sync but doesn't have any sliding window component. It's just a way to + get E2EE events without having to sit through a initial sync. And not have + encryption events backed up by the main sync response. + + GET parameters:: + timeout(int): How long to wait for new events in milliseconds. + since(batch_token): Batch token when asking for incremental deltas. + + Response JSON:: + { + "next_batch": // batch token for the next /sync + "to_device": { + // list of to-device events + "events": [ + { + "content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" }, + "type": "m.room.encrypted", + "sender": "@alice:example.com", + } + // ... + ] + }, + "device_one_time_keys_count": { + "signed_curve25519": 50 + }, + "device_lists": { + "changed": ["@alice:example.com"], + "left": ["@bob:example.com"] + }, + "device_unused_fallback_key_types": [ + "signed_curve25519" + ] + } """ PATTERNS = (re.compile("^/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee$"),) def __init__(self, hs: "HomeServer"): super().__init__() - self._auth = hs.get_auth() + self.auth = hs.get_auth() self.store = hs.get_datastores().main + self.filtering = hs.get_filtering() + self.sync_handler = hs.get_sync_handler() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - return 200, { - "foo": "bar", - } + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user = requester.user + device_id = requester.device_id + + timeout = parse_integer(request, "timeout", default=0) + since = parse_string(request, "since") + + sync_config = SyncConfig( + user=user, + # Filtering doesn't apply to this endpoint so just use a default to fill in + # the SyncConfig + filter_collection=self.filtering.DEFAULT_FILTER_COLLECTION, + is_guest=requester.is_guest, + device_id=device_id, + ) + sync_type = SyncType.E2EE_SYNC + + since_token = None + if since is not None: + since_token = await StreamToken.from_string(self.store, since) + + # Request cache key + request_key = ( + sync_type, + user, + timeout, + since, + ) + + # Gather data for the response + sync_result = await self.sync_handler.wait_for_sync_for_user( + requester, + sync_config, + sync_type, + request_key, + since_token=since_token, + timeout=timeout, + full_state=False, + ) + + # The client may have disconnected by now; don't bother to serialize the + # response if so. + if request._disconnected: + logger.info("Client has disconnected; not serializing response.") + return 200, {} + + response: JsonDict = defaultdict(dict) + response["next_batch"] = await sync_result.next_batch.to_string(self.store) + + if sync_result.to_device: + response["to_device"] = {"events": sync_result.to_device} + + return 200, response def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SyncRestServlet(hs).register(http_server) if hs.config.experimental.msc3575_enabled: - SlidingSyncRestServlet(hs).register(http_server) + SlidingSyncE2eeRestServlet(hs).register(http_server) diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index 2f994ad553..48decccf38 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -67,7 +67,9 @@ class SendToDeviceTestCase(HomeserverTestCase): } self.assertEqual(channel.json_body["to_device"], expected_result) - # it should re-appear if we do another sync + # it should re-appear if we do another sync because the to-device message is not + # deleted until we acknowledge it by sending a `?since=...` parameter in the + # next sync request corresponding to the `next_batch` value from the response. channel = self.make_request("GET", "/sync", access_token=user2_tok) self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.json_body["to_device"], expected_result) @@ -99,7 +101,7 @@ class SendToDeviceTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 200, chan.result) - # now sync: we should get two of the three + # now sync: we should get two of the three (because burst_count=2) channel = self.make_request("GET", "/sync", access_token=user2_tok) self.assertEqual(channel.code, 200, channel.result) msgs = channel.json_body["to_device"]["events"] diff --git a/tests/rest/client/test_sliding_sync.py b/tests/rest/client/test_sliding_sync.py new file mode 100644 index 0000000000..59c47a175a --- /dev/null +++ b/tests/rest/client/test_sliding_sync.py @@ -0,0 +1,74 @@ +from synapse.api.constants import EduTypes +from synapse.rest import admin +from synapse.rest.client import login, sendtodevice, sync +from synapse.types import JsonDict + +from tests.unittest import HomeserverTestCase, override_config + + +class SendToDeviceTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + sendtodevice.register_servlets, + sync.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc3575_enabled": True} + return config + + def test_user_to_user(self) -> None: + """A to-device message from one user to another should get delivered""" + + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send the message + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # check it appears + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + expected_result = { + "events": [ + { + "sender": user1, + "type": "m.test", + "content": test_msg, + } + ] + } + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should re-appear if we do another sync because the to-device message is not + # deleted until we acknowledge it by sending a `?since=...` parameter in the + # next sync request corresponding to the `next_batch` value from the response. + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should *not* appear if we do an incremental sync + sync_token = channel.json_body["next_batch"] + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee?since={sync_token}", + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])