Refactor Sync handler to be able to return different sync responses (SyncVersion) (#17200)

Refactor Sync handler to be able to be able to return different sync
responses (`SyncVersion`). Preparation to be able support sync v2 and a
new Sliding Sync `/sync/e2ee` endpoint which returns a subset of sync
v2.

Split upon request:
https://github.com/element-hq/synapse/pull/17167#discussion_r1601497279

Split from https://github.com/element-hq/synapse/pull/17167 where we
will add `SyncVersion.E2EE_SYNC` and a new type of sync response.
This commit is contained in:
Eric Eastwood 2024-05-16 05:36:54 -05:00 committed by GitHub
parent 2359c64dec
commit d2d48cce85
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 128 additions and 25 deletions

1
changelog.d/17200.misc Normal file
View file

@ -0,0 +1 @@
Prepare sync handler to be able to return different sync responses (`SyncVersion`).

View file

@ -20,6 +20,7 @@
# #
import itertools import itertools
import logging import logging
from enum import Enum
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
AbstractSet, AbstractSet,
@ -112,6 +113,23 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
SyncRequestKey = Tuple[Any, ...] SyncRequestKey = Tuple[Any, ...]
class SyncVersion(Enum):
"""
Enum for specifying the version of sync request. This is used to key which type of
sync response that we are generating.
This is different than the `sync_type` you might see used in other code below; which
specifies the sub-type sync request (e.g. initial_sync, full_state_sync,
incremental_sync) and is really only relevant for the `/sync` v2 endpoint.
"""
# These string values are semantically significant because they are used in the the
# metrics
# Traditional `/sync` endpoint
SYNC_V2 = "sync_v2"
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class SyncConfig: class SyncConfig:
user: UserID user: UserID
@ -309,6 +327,7 @@ class SyncHandler:
self, self,
requester: Requester, requester: Requester,
sync_config: SyncConfig, sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken] = None, since_token: Optional[StreamToken] = None,
timeout: int = 0, timeout: int = 0,
full_state: bool = False, full_state: bool = False,
@ -316,6 +335,17 @@ class SyncHandler:
"""Get the sync for a client if we have new data for it now. Otherwise """Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result. return an empty sync result.
Args:
requester: The user requesting the sync response.
sync_config: Config/info necessary to process the sync request.
sync_version: Determines what kind of sync response to generate.
since_token: The point in the stream to sync from.
timeout: How long to wait for new data to arrive before giving up.
full_state: Whether to return the full state for each room.
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
""" """
# If the user is not part of the mau group, then check that limits have # If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain # not been exceeded (if not part of the group by this point, almost certain
@ -327,6 +357,7 @@ class SyncHandler:
sync_config.request_key, sync_config.request_key,
self._wait_for_sync_for_user, self._wait_for_sync_for_user,
sync_config, sync_config,
sync_version,
since_token, since_token,
timeout, timeout,
full_state, full_state,
@ -338,6 +369,7 @@ class SyncHandler:
async def _wait_for_sync_for_user( async def _wait_for_sync_for_user(
self, self,
sync_config: SyncConfig, sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken], since_token: Optional[StreamToken],
timeout: int, timeout: int,
full_state: bool, full_state: bool,
@ -363,9 +395,11 @@ class SyncHandler:
else: else:
sync_type = "incremental_sync" sync_type = "incremental_sync"
sync_label = f"{sync_version}:{sync_type}"
context = current_context() context = current_context()
if context: if context:
context.tag = sync_type context.tag = sync_label
# if we have a since token, delete any to-device messages before that token # if we have a since token, delete any to-device messages before that token
# (since we now know that the device has received them) # (since we now know that the device has received them)
@ -384,14 +418,16 @@ class SyncHandler:
# we are going to return immediately, so don't bother calling # we are going to return immediately, so don't bother calling
# notifier.wait_for_events. # notifier.wait_for_events.
result: SyncResult = await self.current_sync_for_user( result: SyncResult = await self.current_sync_for_user(
sync_config, since_token, full_state=full_state sync_config, sync_version, since_token, full_state=full_state
) )
else: else:
# Otherwise, we wait for something to happen and report it to the user. # Otherwise, we wait for something to happen and report it to the user.
async def current_sync_callback( async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken before_token: StreamToken, after_token: StreamToken
) -> SyncResult: ) -> SyncResult:
return await self.current_sync_for_user(sync_config, since_token) return await self.current_sync_for_user(
sync_config, sync_version, since_token
)
result = await self.notifier.wait_for_events( result = await self.notifier.wait_for_events(
sync_config.user.to_string(), sync_config.user.to_string(),
@ -416,13 +452,14 @@ class SyncHandler:
lazy_loaded = "true" lazy_loaded = "true"
else: else:
lazy_loaded = "false" lazy_loaded = "false"
non_empty_sync_counter.labels(sync_type, lazy_loaded).inc() non_empty_sync_counter.labels(sync_label, lazy_loaded).inc()
return result return result
async def current_sync_for_user( async def current_sync_for_user(
self, self,
sync_config: SyncConfig, sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken] = None, since_token: Optional[StreamToken] = None,
full_state: bool = False, full_state: bool = False,
) -> SyncResult: ) -> SyncResult:
@ -431,12 +468,26 @@ class SyncHandler:
This is a wrapper around `generate_sync_result` which starts an open tracing This is a wrapper around `generate_sync_result` which starts an open tracing
span to track the sync. See `generate_sync_result` for the next part of your span to track the sync. See `generate_sync_result` for the next part of your
indoctrination. indoctrination.
Args:
sync_config: Config/info necessary to process the sync request.
sync_version: Determines what kind of sync response to generate.
since_token: The point in the stream to sync from.p.
full_state: Whether to return the full state for each room.
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
""" """
with start_active_span("sync.current_sync_for_user"): with start_active_span("sync.current_sync_for_user"):
log_kv({"since_token": since_token}) log_kv({"since_token": since_token})
sync_result = await self.generate_sync_result( # Go through the `/sync` v2 path
sync_config, since_token, full_state if sync_version == SyncVersion.SYNC_V2:
) sync_result: SyncResult = await self.generate_sync_result(
sync_config, since_token, full_state
)
else:
raise Exception(
f"Unknown sync_version (this is a Synapse problem): {sync_version}"
)
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result)) set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result return sync_result

View file

@ -40,6 +40,7 @@ from synapse.handlers.sync import (
KnockedSyncResult, KnockedSyncResult,
SyncConfig, SyncConfig,
SyncResult, SyncResult,
SyncVersion,
) )
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
@ -232,6 +233,7 @@ class SyncRestServlet(RestServlet):
sync_result = await self.sync_handler.wait_for_sync_for_user( sync_result = await self.sync_handler.wait_for_sync_for_user(
requester, requester,
sync_config, sync_config,
SyncVersion.SYNC_V2,
since_token=since_token, since_token=since_token,
timeout=timeout, timeout=timeout,
full_state=full_state, full_state=full_state,

View file

@ -36,7 +36,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict, StreamToken, create_requester from synapse.types import JsonDict, StreamToken, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config from tests.handlers.test_sync import SyncVersion, generate_sync_config
from tests.unittest import ( from tests.unittest import (
FederatingHomeserverTestCase, FederatingHomeserverTestCase,
HomeserverTestCase, HomeserverTestCase,
@ -521,7 +521,7 @@ def sync_presence(
sync_config = generate_sync_config(requester.user.to_string()) sync_config = generate_sync_config(requester.user.to_string())
sync_result = testcase.get_success( sync_result = testcase.get_success(
testcase.hs.get_sync_handler().wait_for_sync_for_user( testcase.hs.get_sync_handler().wait_for_sync_for_user(
requester, sync_config, since_token requester, sync_config, SyncVersion.SYNC_V2, since_token
) )
) )

View file

@ -31,7 +31,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json from synapse.federation.federation_base import event_from_pdu_json
from synapse.handlers.sync import SyncConfig, SyncResult from synapse.handlers.sync import SyncConfig, SyncResult, SyncVersion
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import knock, login, room from synapse.rest.client import knock, login, room
from synapse.server import HomeServer from synapse.server import HomeServer
@ -73,13 +73,21 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Check that the happy case does not throw errors # Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1)) self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success( self.get_success(
self.sync_handler.wait_for_sync_for_user(requester, sync_config) self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
sync_version=SyncVersion.SYNC_V2,
)
) )
# Test that global lock works # Test that global lock works
self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled = True
e = self.get_failure( e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(requester, sync_config), self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
sync_version=SyncVersion.SYNC_V2,
),
ResourceLimitError, ResourceLimitError,
) )
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@ -90,7 +98,11 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
requester = create_requester(user_id2) requester = create_requester(user_id2)
e = self.get_failure( e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(requester, sync_config), self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
sync_version=SyncVersion.SYNC_V2,
),
ResourceLimitError, ResourceLimitError,
) )
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@ -109,7 +121,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
requester = create_requester(user) requester = create_requester(user)
initial_result = self.get_success( initial_result = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
requester, sync_config=generate_sync_config(user, device_id="dev") requester,
sync_config=generate_sync_config(user, device_id="dev"),
sync_version=SyncVersion.SYNC_V2,
) )
) )
@ -140,7 +154,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# The rooms should appear in the sync response. # The rooms should appear in the sync response.
result = self.get_success( result = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
requester, sync_config=generate_sync_config(user) requester,
sync_config=generate_sync_config(user),
sync_version=SyncVersion.SYNC_V2,
) )
) )
self.assertIn(joined_room, [r.room_id for r in result.joined]) self.assertIn(joined_room, [r.room_id for r in result.joined])
@ -152,6 +168,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
requester, requester,
sync_config=generate_sync_config(user, device_id="dev"), sync_config=generate_sync_config(user, device_id="dev"),
sync_version=SyncVersion.SYNC_V2,
since_token=initial_result.next_batch, since_token=initial_result.next_batch,
) )
) )
@ -180,7 +197,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Get a new request key. # Get a new request key.
result = self.get_success( result = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
requester, sync_config=generate_sync_config(user) requester,
sync_config=generate_sync_config(user),
sync_version=SyncVersion.SYNC_V2,
) )
) )
self.assertNotIn(joined_room, [r.room_id for r in result.joined]) self.assertNotIn(joined_room, [r.room_id for r in result.joined])
@ -192,6 +211,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
requester, requester,
sync_config=generate_sync_config(user, device_id="dev"), sync_config=generate_sync_config(user, device_id="dev"),
sync_version=SyncVersion.SYNC_V2,
since_token=initial_result.next_batch, since_token=initial_result.next_batch,
) )
) )
@ -231,7 +251,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Do a sync as Alice to get the latest event in the room. # Do a sync as Alice to get the latest event in the room.
alice_sync_result: SyncResult = self.get_success( alice_sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
create_requester(owner), generate_sync_config(owner) create_requester(owner),
generate_sync_config(owner),
sync_version=SyncVersion.SYNC_V2,
) )
) )
self.assertEqual(len(alice_sync_result.joined), 1) self.assertEqual(len(alice_sync_result.joined), 1)
@ -251,7 +273,11 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
eve_requester = create_requester(eve) eve_requester = create_requester(eve)
eve_sync_config = generate_sync_config(eve) eve_sync_config = generate_sync_config(eve)
eve_sync_after_ban: SyncResult = self.get_success( eve_sync_after_ban: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(eve_requester, eve_sync_config) self.sync_handler.wait_for_sync_for_user(
eve_requester,
eve_sync_config,
sync_version=SyncVersion.SYNC_V2,
)
) )
# Sanity check this sync result. We shouldn't be joined to the room. # Sanity check this sync result. We shouldn't be joined to the room.
@ -268,6 +294,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
eve_requester, eve_requester,
eve_sync_config, eve_sync_config,
sync_version=SyncVersion.SYNC_V2,
since_token=eve_sync_after_ban.next_batch, since_token=eve_sync_after_ban.next_batch,
) )
) )
@ -279,6 +306,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
eve_requester, eve_requester,
eve_sync_config, eve_sync_config,
sync_version=SyncVersion.SYNC_V2,
since_token=None, since_token=None,
) )
) )
@ -310,7 +338,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Do an initial sync as Alice to get a known starting point. # Do an initial sync as Alice to get a known starting point.
initial_sync_result = self.get_success( initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
alice_requester, generate_sync_config(alice) alice_requester,
generate_sync_config(alice),
sync_version=SyncVersion.SYNC_V2,
) )
) )
last_room_creation_event_id = ( last_room_creation_event_id = (
@ -338,6 +368,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.hs, {"room": {"timeline": {"limit": 2}}} self.hs, {"room": {"timeline": {"limit": 2}}}
), ),
), ),
sync_version=SyncVersion.SYNC_V2,
since_token=initial_sync_result.next_batch, since_token=initial_sync_result.next_batch,
) )
) )
@ -380,7 +411,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Do an initial sync as Alice to get a known starting point. # Do an initial sync as Alice to get a known starting point.
initial_sync_result = self.get_success( initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
alice_requester, generate_sync_config(alice) alice_requester,
generate_sync_config(alice),
sync_version=SyncVersion.SYNC_V2,
) )
) )
last_room_creation_event_id = ( last_room_creation_event_id = (
@ -418,6 +451,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
}, },
), ),
), ),
sync_version=SyncVersion.SYNC_V2,
since_token=initial_sync_result.next_batch, since_token=initial_sync_result.next_batch,
) )
) )
@ -461,7 +495,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Do an initial sync as Alice to get a known starting point. # Do an initial sync as Alice to get a known starting point.
initial_sync_result = self.get_success( initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
alice_requester, generate_sync_config(alice) alice_requester,
generate_sync_config(alice),
sync_version=SyncVersion.SYNC_V2,
) )
) )
last_room_creation_event_id = ( last_room_creation_event_id = (
@ -486,6 +522,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.hs, {"room": {"timeline": {"limit": 1}}} self.hs, {"room": {"timeline": {"limit": 1}}}
), ),
), ),
sync_version=SyncVersion.SYNC_V2,
since_token=initial_sync_result.next_batch, since_token=initial_sync_result.next_batch,
) )
) )
@ -515,6 +552,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.hs, {"room": {"timeline": {"limit": 1}}} self.hs, {"room": {"timeline": {"limit": 1}}}
), ),
), ),
sync_version=SyncVersion.SYNC_V2,
since_token=incremental_sync.next_batch, since_token=incremental_sync.next_batch,
) )
) )
@ -574,7 +612,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Do an initial sync to get a known starting point. # Do an initial sync to get a known starting point.
initial_sync_result = self.get_success( initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
alice_requester, generate_sync_config(alice) alice_requester,
generate_sync_config(alice),
sync_version=SyncVersion.SYNC_V2,
) )
) )
last_room_creation_event_id = ( last_room_creation_event_id = (
@ -598,6 +638,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.hs, {"room": {"timeline": {"limit": 1}}} self.hs, {"room": {"timeline": {"limit": 1}}}
), ),
), ),
sync_version=SyncVersion.SYNC_V2,
) )
) )
room_sync = initial_sync_result.joined[0] room_sync = initial_sync_result.joined[0]
@ -618,6 +659,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
alice_requester, alice_requester,
generate_sync_config(alice), generate_sync_config(alice),
sync_version=SyncVersion.SYNC_V2,
since_token=initial_sync_result.next_batch, since_token=initial_sync_result.next_batch,
) )
) )
@ -668,7 +710,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success( initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
bob_requester, generate_sync_config(bob) bob_requester,
generate_sync_config(bob),
sync_version=SyncVersion.SYNC_V2,
) )
) )
@ -699,6 +743,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
generate_sync_config( generate_sync_config(
bob, filter_collection=FilterCollection(self.hs, filter_dict) bob, filter_collection=FilterCollection(self.hs, filter_dict)
), ),
sync_version=SyncVersion.SYNC_V2,
since_token=None if initial_sync else initial_sync_result.next_batch, since_token=None if initial_sync else initial_sync_result.next_batch,
) )
).archived[0] ).archived[0]
@ -791,7 +836,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# but that it does not come down /sync in public room # but that it does not come down /sync in public room
sync_result: SyncResult = self.get_success( sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
create_requester(user), generate_sync_config(user) create_requester(user),
generate_sync_config(user),
sync_version=SyncVersion.SYNC_V2,
) )
) )
event_ids = [] event_ids = []
@ -837,7 +884,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
private_sync_result: SyncResult = self.get_success( private_sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user( self.sync_handler.wait_for_sync_for_user(
create_requester(user2), generate_sync_config(user2) create_requester(user2),
generate_sync_config(user2),
sync_version=SyncVersion.SYNC_V2,
) )
) )
priv_event_ids = [] priv_event_ids = []