Add Sliding Sync /sync/e2ee endpoint for To-Device messages

Based on:

 - MSC3575: Sliding Sync (aka Sync v3): https://github.com/matrix-org/matrix-spec-proposals/pull/3575
 - MSC3885: Sliding Sync Extension: To-Device messages: https://github.com/matrix-org/matrix-spec-proposals/pull/3885
 - MSC3884: Sliding Sync Extension: E2EE: https://github.com/matrix-org/matrix-spec-proposals/pull/3884
This commit is contained in:
Eric Eastwood 2024-05-07 18:16:35 -05:00
parent f9e6e53130
commit 1e05a05f03
4 changed files with 276 additions and 25 deletions

View file

@ -18,6 +18,7 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
from enum import Enum
import itertools import itertools
import logging import logging
from typing import ( from typing import (
@ -112,12 +113,21 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
SyncRequestKey = Tuple[Any, ...] SyncRequestKey = Tuple[Any, ...]
class SyncType(Enum):
"""Enum for specifying the type of sync request."""
# These string values are semantically significant and are used in the the metrics
INITIAL_SYNC = "initial_sync"
FULL_STATE_SYNC = "full_state_sync"
INCREMENTAL_SYNC = "incremental_sync"
E2EE_SYNC = "e2ee_sync"
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class SyncConfig: class SyncConfig:
user: UserID user: UserID
filter_collection: FilterCollection filter_collection: FilterCollection
is_guest: bool is_guest: bool
request_key: SyncRequestKey
device_id: Optional[str] device_id: Optional[str]
@ -263,6 +273,15 @@ class SyncResult:
) )
@attr.s(slots=True, frozen=True, auto_attribs=True)
class E2eeSyncResult:
next_batch: StreamToken
to_device: List[JsonDict]
# device_lists: DeviceListUpdates
# device_one_time_keys_count: JsonMapping
# device_unused_fallback_key_types: List[str]
class SyncHandler: class SyncHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config self.hs_config = hs.config
@ -309,6 +328,8 @@ class SyncHandler:
self, self,
requester: Requester, requester: Requester,
sync_config: SyncConfig, sync_config: SyncConfig,
sync_type: SyncType,
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None, since_token: Optional[StreamToken] = None,
timeout: int = 0, timeout: int = 0,
full_state: bool = False, full_state: bool = False,
@ -316,6 +337,9 @@ class SyncHandler:
"""Get the sync for a client if we have new data for it now. Otherwise """Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result. return an empty sync result.
Args:
request_key: The key to use for caching the response.
""" """
# If the user is not part of the mau group, then check that limits have # If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain # not been exceeded (if not part of the group by this point, almost certain
@ -324,9 +348,10 @@ class SyncHandler:
await self.auth_blocking.check_auth_blocking(requester=requester) await self.auth_blocking.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap( res = await self.response_cache.wrap(
sync_config.request_key, request_key,
self._wait_for_sync_for_user, self._wait_for_sync_for_user,
sync_config, sync_config,
sync_type,
since_token, since_token,
timeout, timeout,
full_state, full_state,
@ -338,6 +363,7 @@ class SyncHandler:
async def _wait_for_sync_for_user( async def _wait_for_sync_for_user(
self, self,
sync_config: SyncConfig, sync_config: SyncConfig,
sync_type: SyncType,
since_token: Optional[StreamToken], since_token: Optional[StreamToken],
timeout: int, timeout: int,
full_state: bool, full_state: bool,
@ -356,13 +382,6 @@ class SyncHandler:
Computing the body of the response begins in the next method, Computing the body of the response begins in the next method,
`current_sync_for_user`. `current_sync_for_user`.
""" """
if since_token is None:
sync_type = "initial_sync"
elif full_state:
sync_type = "full_state_sync"
else:
sync_type = "incremental_sync"
context = current_context() context = current_context()
if context: if context:
context.tag = sync_type context.tag = sync_type
@ -384,14 +403,16 @@ class SyncHandler:
# we are going to return immediately, so don't bother calling # we are going to return immediately, so don't bother calling
# notifier.wait_for_events. # notifier.wait_for_events.
result: SyncResult = await self.current_sync_for_user( result: SyncResult = await self.current_sync_for_user(
sync_config, since_token, full_state=full_state sync_config, sync_type, since_token, full_state=full_state
) )
else: else:
# Otherwise, we wait for something to happen and report it to the user. # Otherwise, we wait for something to happen and report it to the user.
async def current_sync_callback( async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken before_token: StreamToken, after_token: StreamToken
) -> SyncResult: ) -> SyncResult:
return await self.current_sync_for_user(sync_config, since_token) return await self.current_sync_for_user(
sync_config, sync_type, since_token
)
result = await self.notifier.wait_for_events( result = await self.notifier.wait_for_events(
sync_config.user.to_string(), sync_config.user.to_string(),
@ -423,6 +444,7 @@ class SyncHandler:
async def current_sync_for_user( async def current_sync_for_user(
self, self,
sync_config: SyncConfig, sync_config: SyncConfig,
sync_type: SyncType,
since_token: Optional[StreamToken] = None, since_token: Optional[StreamToken] = None,
full_state: bool = False, full_state: bool = False,
) -> SyncResult: ) -> SyncResult:
@ -434,9 +456,25 @@ class SyncHandler:
""" """
with start_active_span("sync.current_sync_for_user"): with start_active_span("sync.current_sync_for_user"):
log_kv({"since_token": since_token}) log_kv({"since_token": since_token})
sync_result = await self.generate_sync_result(
sync_config, since_token, full_state # Go through the `/sync` v2 path
) if sync_type in {
SyncType.INITIAL_SYNC,
SyncType.FULL_STATE_SYNC,
SyncType.INCREMENTAL_SYNC,
}:
sync_result = await self.generate_sync_result(
sync_config, since_token, full_state
)
# Go through the MSC3575 Sliding Sync `/sync/e2ee` path
elif sync_type == SyncType.E2EE_SYNC:
sync_result = await self.generate_e2ee_sync_result(
sync_config, since_token
)
else:
raise Exception(
f"Unknown sync_type (this is a Synapse problem): {sync_type}"
)
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result)) set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result return sync_result
@ -1751,6 +1789,50 @@ class SyncHandler:
next_batch=sync_result_builder.now_token, next_batch=sync_result_builder.now_token,
) )
async def generate_e2ee_sync_result(
self,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
) -> SyncResult:
"""Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result."""
user_id = sync_config.user.to_string()
# TODO: Should we exclude app services here? There could be an argument to allow
# them since the appservice doesn't have to make a massive initial sync.
# (related to https://github.com/matrix-org/matrix-doc/issues/1144)
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
# Always use the `now_token` in `SyncResultBuilder`
now_token = self.event_sources.get_current_token()
log_kv({"now_token": now_token})
joined_room_ids = await self.store.get_rooms_for_user(user_id)
sync_result_builder = SyncResultBuilder(
sync_config,
full_state=False,
since_token=since_token,
now_token=now_token,
joined_room_ids=joined_room_ids,
# Dummy values to fill out `SyncResultBuilder`
excluded_room_ids=frozenset({}),
forced_newly_joined_room_ids=frozenset({}),
membership_change_events=frozenset({}),
)
await self._generate_sync_entry_for_to_device(sync_result_builder)
return E2eeSyncResult(
to_device=sync_result_builder.to_device,
# to_device: List[JsonDict]
# device_lists: DeviceListUpdates
# device_one_time_keys_count: JsonMapping
# device_unused_fallback_key_types: List[str]
next_batch=sync_result_builder.now_token,
)
@measure_func("_generate_sync_entry_for_device_list") @measure_func("_generate_sync_entry_for_device_list")
async def _generate_sync_entry_for_device_list( async def _generate_sync_entry_for_device_list(
self, self,

View file

@ -41,6 +41,7 @@ from synapse.handlers.sync import (
KnockedSyncResult, KnockedSyncResult,
SyncConfig, SyncConfig,
SyncResult, SyncResult,
SyncType,
) )
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
@ -198,7 +199,6 @@ class SyncRestServlet(RestServlet):
user=user, user=user,
filter_collection=filter_collection, filter_collection=filter_collection,
is_guest=requester.is_guest, is_guest=requester.is_guest,
request_key=request_key,
device_id=device_id, device_id=device_id,
) )
@ -206,6 +206,13 @@ class SyncRestServlet(RestServlet):
if since is not None: if since is not None:
since_token = await StreamToken.from_string(self.store, since) since_token = await StreamToken.from_string(self.store, since)
if since_token is None:
sync_type = SyncType.INITIAL_SYNC
elif full_state:
sync_type = SyncType.FULL_STATE_SYNC
else:
sync_type = SyncType.INCREMENTAL_SYNC
# send any outstanding server notices to the user. # send any outstanding server notices to the user.
await self._server_notices_sender.on_user_syncing(user.to_string()) await self._server_notices_sender.on_user_syncing(user.to_string())
@ -221,6 +228,8 @@ class SyncRestServlet(RestServlet):
sync_result = await self.sync_handler.wait_for_sync_for_user( sync_result = await self.sync_handler.wait_for_sync_for_user(
requester, requester,
sync_config, sync_config,
sync_type,
request_key,
since_token=since_token, since_token=since_token,
timeout=timeout, timeout=timeout,
full_state=full_state, full_state=full_state,
@ -554,27 +563,111 @@ class SyncRestServlet(RestServlet):
return result return result
class SlidingSyncRestServlet(RestServlet): class SlidingSyncE2eeRestServlet(RestServlet):
""" """
API endpoint TODO API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
Useful for cases like TODO of Sliding Sync but doesn't have any sliding window component. It's just a way to
get E2EE events without having to sit through a initial sync. And not have
encryption events backed up by the main sync response.
GET parameters::
timeout(int): How long to wait for new events in milliseconds.
since(batch_token): Batch token when asking for incremental deltas.
Response JSON::
{
"next_batch": // batch token for the next /sync
"to_device": {
// list of to-device events
"events": [
{
"content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
"type": "m.room.encrypted",
"sender": "@alice:example.com",
}
// ...
]
},
"device_one_time_keys_count": {
"signed_curve25519": 50
},
"device_lists": {
"changed": ["@alice:example.com"],
"left": ["@bob:example.com"]
},
"device_unused_fallback_key_types": [
"signed_curve25519"
]
}
""" """
PATTERNS = (re.compile("^/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee$"),) PATTERNS = (re.compile("^/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee$"),)
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self._auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.filtering = hs.get_filtering()
self.sync_handler = hs.get_sync_handler()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return 200, { requester = await self.auth.get_user_by_req(request, allow_guest=True)
"foo": "bar", user = requester.user
} device_id = requester.device_id
timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since")
sync_config = SyncConfig(
user=user,
# Filtering doesn't apply to this endpoint so just use a default to fill in
# the SyncConfig
filter_collection=self.filtering.DEFAULT_FILTER_COLLECTION,
is_guest=requester.is_guest,
device_id=device_id,
)
sync_type = SyncType.E2EE_SYNC
since_token = None
if since is not None:
since_token = await StreamToken.from_string(self.store, since)
# Request cache key
request_key = (
sync_type,
user,
timeout,
since,
)
# Gather data for the response
sync_result = await self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
sync_type,
request_key,
since_token=since_token,
timeout=timeout,
full_state=False,
)
# The client may have disconnected by now; don't bother to serialize the
# response if so.
if request._disconnected:
logger.info("Client has disconnected; not serializing response.")
return 200, {}
response: JsonDict = defaultdict(dict)
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
if sync_result.to_device:
response["to_device"] = {"events": sync_result.to_device}
return 200, response
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server) SyncRestServlet(hs).register(http_server)
if hs.config.experimental.msc3575_enabled: if hs.config.experimental.msc3575_enabled:
SlidingSyncRestServlet(hs).register(http_server) SlidingSyncE2eeRestServlet(hs).register(http_server)

View file

@ -67,7 +67,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
} }
self.assertEqual(channel.json_body["to_device"], expected_result) self.assertEqual(channel.json_body["to_device"], expected_result)
# it should re-appear if we do another sync # it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request("GET", "/sync", access_token=user2_tok) channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result) self.assertEqual(channel.json_body["to_device"], expected_result)
@ -99,7 +101,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
) )
self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.code, 200, chan.result)
# now sync: we should get two of the three # now sync: we should get two of the three (because burst_count=2)
channel = self.make_request("GET", "/sync", access_token=user2_tok) channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"] msgs = channel.json_body["to_device"]["events"]

View file

@ -0,0 +1,74 @@
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
from synapse.types import JsonDict
from tests.unittest import HomeserverTestCase, override_config
class SendToDeviceTestCase(HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
sendtodevice.register_servlets,
sync.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = {"msc3575_enabled": True}
return config
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send the message
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
{
"sender": user1,
"type": "m.test",
"content": test_msg,
}
]
}
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])