Merge branch 'develop' into madlittlemods/msc3575-sliding-sync-0.0.1

This commit is contained in:
Eric Eastwood 2024-06-06 12:10:21 -05:00
commit c89f012d8b
6 changed files with 109 additions and 48 deletions

1
changelog.d/17271.misc Normal file
View file

@ -0,0 +1 @@
Handle OTK uploads off master.

1
changelog.d/17273.misc Normal file
View file

@ -0,0 +1 @@
Don't try and resync devices for remote users whose servers are marked as down.

1
changelog.d/17275.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where OTKs were not always included in `/sync` response when using workers.

View file

@ -35,6 +35,7 @@ from synapse.api.errors import CodeMessageException, Codes, NotFoundError, Synap
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.types import (
JsonDict,
JsonMapping,
@ -45,7 +46,10 @@ from synapse.types import (
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.retryutils import (
NotRetryingDestination,
filter_destinations_by_retry_limiter,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -86,6 +90,12 @@ class E2eKeysHandler:
edu_updater.incoming_signing_key_update,
)
self.device_key_uploader = self.upload_device_keys_for_user
else:
self.device_key_uploader = (
ReplicationUploadKeysForUserRestServlet.make_client(hs)
)
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
@ -268,10 +278,8 @@ class E2eKeysHandler:
"%d destinations to query devices for", len(remote_queries_not_in_cache)
)
async def _query(
destination_queries: Tuple[str, Dict[str, Iterable[str]]]
) -> None:
destination, queries = destination_queries
async def _query(destination: str) -> None:
queries = remote_queries_not_in_cache[destination]
return await self._query_devices_for_destination(
results,
cross_signing_keys,
@ -281,9 +289,20 @@ class E2eKeysHandler:
timeout,
)
# Only try and fetch keys for destinations that are not marked as
# down.
filtered_destinations = await filter_destinations_by_retry_limiter(
remote_queries_not_in_cache.keys(),
self.clock,
self.store,
# Let's give an arbitrary grace period for those hosts that are
# only recently down
retry_due_within_ms=60 * 1000,
)
await concurrently_execute(
_query,
remote_queries_not_in_cache.items(),
filtered_destinations,
10,
delay_cancellation=True,
)
@ -784,36 +803,17 @@ class E2eKeysHandler:
"one_time_keys": A mapping from algorithm to number of keys for that
algorithm, including those previously persisted.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id,
user_id,
time_now,
await self.device_key_uploader(
user_id=user_id,
device_id=device_id,
keys={"device_keys": device_keys},
)
log_kv(
{
"message": "Updating device_keys for user.",
"user_id": user_id,
"device_id": device_id,
}
)
# TODO: Sign the JSON with the server key
changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
await self.device_handler.notify_device_update(user_id, [device_id])
else:
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
log_kv(
@ -849,6 +849,49 @@ class E2eKeysHandler:
{"message": "Did not update fallback_keys", "reason": "no keys given"}
)
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
@tag_args
async def upload_device_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> None:
"""
Args:
user_id: user whose keys are being uploaded.
device_id: device whose keys are being uploaded.
device_keys: the `device_keys` of an /keys/upload request.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec()
device_keys = keys["device_keys"]
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id,
user_id,
time_now,
)
log_kv(
{
"message": "Updating device_keys for user.",
"user_id": user_id,
"device_id": device_id,
}
)
# TODO: Sign the JSON with the server key
changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
await self.device_handler.notify_device_update(user_id, [device_id])
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
@ -856,11 +899,6 @@ class E2eKeysHandler:
# keys without a corresponding device.
await self.device_handler.check_device_registered(user_id, device_id)
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
) -> None:

View file

@ -285,7 +285,11 @@ class SyncResult:
)
@staticmethod
def empty(next_batch: StreamToken) -> "SyncResult":
def empty(
next_batch: StreamToken,
device_one_time_keys_count: JsonMapping,
device_unused_fallback_key_types: List[str],
) -> "SyncResult":
"Return a new empty result"
return SyncResult(
next_batch=next_batch,
@ -297,8 +301,8 @@ class SyncResult:
archived=[],
to_device=[],
device_lists=DeviceListUpdates(),
device_one_time_keys_count={},
device_unused_fallback_key_types=[],
device_one_time_keys_count=device_one_time_keys_count,
device_unused_fallback_key_types=device_unused_fallback_key_types,
)
@ -523,7 +527,28 @@ class SyncHandler:
logger.warning(
"Timed out waiting for worker to catch up. Returning empty response"
)
return SyncResult.empty(since_token)
device_id = sync_config.device_id
one_time_keys_count: JsonMapping = {}
unused_fallback_key_types: List[str] = []
if device_id:
user_id = sync_config.user.to_string()
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(
user_id, device_id
)
)
cache_context.should_cache = False # Don't cache empty responses
return SyncResult.empty(
since_token, one_time_keys_count, unused_fallback_key_types
)
# If we've spent significant time waiting to catch up, take it off
# the timeout.

View file

@ -36,7 +36,6 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken
from synapse.util.cancellation import cancellable
@ -105,13 +104,8 @@ class KeyUploadServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
if hs.config.worker.worker_app is None:
# if main process
self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
else:
# then a worker
self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
@ -151,9 +145,10 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating"
)
result = await self.key_uploader(
result = await self.e2e_keys_handler.upload_keys_for_user(
user_id=user_id, device_id=device_id, keys=body
)
return 200, result