Add unstable /keys/claim endpoint which always returns fallback keys. (#15462)

It can be useful to always return the fallback key when attempting to
claim keys. This adds an unstable endpoint for `/keys/claim` which
always returns fallback keys in addition to one-time-keys.

The fallback key(s) are not marked as "used" unless there are no
corresponding OTKs.

This is currently defined in MSC3983 (although likely to be split out
to a separate MSC). The endpoint shape may change or be requested
differently (i.e. a keyword parameter on the current endpoint), but the
core logic should be reasonable.
This commit is contained in:
Patrick Cloke 2023-04-25 13:30:41 -04:00 committed by GitHub
parent b39b02c26e
commit 8e9739449d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 371 additions and 29 deletions

1
changelog.d/15462.misc Normal file
View file

@ -0,0 +1 @@
Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request.

View file

@ -1005,7 +1005,7 @@ class FederationServer(FederationBase):
@trace @trace
async def on_claim_client_keys( async def on_claim_client_keys(
self, origin: str, content: JsonDict self, origin: str, content: JsonDict, always_include_fallback_keys: bool
) -> Dict[str, Any]: ) -> Dict[str, Any]:
query = [] query = []
for user_id, device_keys in content.get("one_time_keys", {}).items(): for user_id, device_keys in content.get("one_time_keys", {}).items():
@ -1013,7 +1013,9 @@ class FederationServer(FederationBase):
query.append((user_id, device_id, algorithm)) query.append((user_id, device_id, algorithm))
log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(query) results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys
)
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results: for result in results:

View file

@ -25,6 +25,7 @@ from synapse.federation.transport.server._base import (
from synapse.federation.transport.server.federation import ( from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES, FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet, FederationAccountStatusServlet,
FederationUnstableClientKeysClaimServlet,
) )
from synapse.http.server import HttpServer, JsonResource from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -298,6 +299,11 @@ def register_servlets(
and not hs.config.experimental.msc3720_enabled and not hs.config.experimental.msc3720_enabled
): ):
continue continue
if (
servletclass == FederationUnstableClientKeysClaimServlet
and not hs.config.experimental.msc3983_appservice_otk_claims
):
continue
servletclass( servletclass(
hs=hs, hs=hs,

View file

@ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST( async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
response = await self.handler.on_claim_client_keys(origin, content) response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=False
)
return 200, response
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
"""
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
always includes fallback keys in the response.
"""
PREFIX = FEDERATION_UNSTABLE_PREFIX
PATH = "/user/keys/claim"
CATEGORY = "Federation requests"
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=True
)
return 200, response return 200, response

View file

@ -842,9 +842,7 @@ class ApplicationServicesHandler:
async def claim_e2e_one_time_keys( async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]] self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[ ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
]:
"""Claim one time keys from application services. """Claim one time keys from application services.
Users which are exclusively owned by an application service are sent a Users which are exclusively owned by an application service are sent a
@ -856,7 +854,7 @@ class ApplicationServicesHandler:
Returns: Returns:
A tuple of: A tuple of:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. A map of user ID -> a map device ID -> a map of key ID -> JSON.
A copy of the input which has not been fulfilled (either because A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support they are not appservice users or the appservice does not support
@ -897,12 +895,11 @@ class ApplicationServicesHandler:
) )
# Patch together the results -- they are all independent (since they # Patch together the results -- they are all independent (since they
# require exclusive control over the users). They get returned as a list # require exclusive control over the users, which is the outermost key).
# and the caller combines them. claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
for success, result in results: for success, result in results:
if success: if success:
claimed_keys.append(result[0]) claimed_keys.update(result[0])
missing.extend(result[1]) missing.extend(result[1])
return claimed_keys, missing return claimed_keys, missing

View file

@ -563,7 +563,9 @@ class E2eKeysHandler:
return ret return ret
async def claim_local_one_time_keys( async def claim_local_one_time_keys(
self, local_query: List[Tuple[str, str, str]] self,
local_query: List[Tuple[str, str, str]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users. """Claim one time keys for local users.
@ -573,6 +575,7 @@ class E2eKeysHandler:
Args: Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm). local_query: An iterable of tuples of (user ID, device ID, algorithm).
always_include_fallback_keys: True to always include fallback keys.
Returns: Returns:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
@ -583,24 +586,73 @@ class E2eKeysHandler:
# If the application services have not provided any keys via the C-S # If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys. # API, query it directly for one-time keys.
if self._query_appservices_for_otks: if self._query_appservices_for_otks:
# TODO Should this query for fallback keys of uploaded OTKs if
# always_include_fallback_keys is True? The MSC is ambiguous.
( (
appservice_results, appservice_results,
not_found, not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
else: else:
appservice_results = [] appservice_results = {}
# Calculate which user ID / device ID / algorithm tuples to get fallback
# keys for. This can be either only missing results *or* all results
# (which don't already have a fallback key).
if always_include_fallback_keys:
# Build the fallback query as any part of the original query where
# the appservice didn't respond with a fallback key.
fallback_query = []
# Iterate each item in the original query and search the results
# from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a
# fallback key.
for user_id, device_id, algorithm in local_query:
# Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False
for key_id, key_json in as_result.items():
if key_id.startswith(f"{algorithm}:"):
# A OTK or fallback key was found for this query.
found_otk = True
# A fallback key was found for this query, no need to
# query further.
if key_json.get("fallback", False):
break
else:
# No fallback key was found from appservices, query for it.
# Only mark the fallback key as used if no OTK was found
# (from either the database or appservices).
mark_as_used = not found_otk and not any(
key_id.startswith(f"{algorithm}:")
for key_id in otk_results.get(user_id, {})
.get(device_id, {})
.keys()
)
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
else:
# All fallback keys get marked as used.
fallback_query = [
(user_id, device_id, algorithm, True)
for user_id, device_id, algorithm in not_found
]
# For each user that does not have a one-time keys available, see if # For each user that does not have a one-time keys available, see if
# there is a fallback key. # there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(not_found) fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
# Return the results in order, each item from the input query should # Return the results in order, each item from the input query should
# only appear once in the combined list. # only appear once in the combined list.
return (otk_results, *appservice_results, fallback_results) return (otk_results, appservice_results, fallback_results)
@trace @trace
async def claim_one_time_keys( async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int] self,
query: Dict[str, Dict[str, Dict[str, str]]],
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict: ) -> JsonDict:
local_query: List[Tuple[str, str, str]] = [] local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@ -617,7 +669,9 @@ class E2eKeysHandler:
set_tag("local_key_query", str(local_query)) set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries)) set_tag("remote_key_query", str(remote_queries))
results = await self.claim_local_one_time_keys(local_query) results = await self.claim_local_one_time_keys(
local_query, always_include_fallback_keys
)
# A map of user ID -> device ID -> key ID -> key. # A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
@ -625,7 +679,9 @@ class E2eKeysHandler:
for user_id, device_keys in result.items(): for user_id, device_keys in result.items():
for device_id, keys in device_keys.items(): for device_id, keys in device_keys.items():
for key_id, key in keys.items(): for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key} json_result.setdefault(user_id, {}).setdefault(
device_id, {}
).update({key_id: key})
# Remote failures. # Remote failures.
failures: Dict[str, JsonDict] = {} failures: Dict[str, JsonDict] = {}

View file

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re
from typing import TYPE_CHECKING, Any, Optional, Tuple from typing import TYPE_CHECKING, Any, Optional, Tuple
from synapse.api.errors import InvalidAPICallError, SynapseError from synapse.api.errors import InvalidAPICallError, SynapseError
@ -288,7 +289,33 @@ class OneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=False
)
return 200, result
class UnstableOneTimeKeyServlet(RestServlet):
"""
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
fallback keys in the response.
"""
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=True
)
return 200, result return 200, result
@ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyQueryServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server)
if hs.config.experimental.msc3983_appservice_otk_claims:
UnstableOneTimeKeyServlet(hs).register(http_server)
if hs.config.worker.worker_app is None: if hs.config.worker.worker_app is None:
SigningKeyUploadServlet(hs).register(http_server) SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server) SignaturesUploadServlet(hs).register(http_server)

View file

@ -1149,18 +1149,19 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return results, missing return results, missing
async def claim_e2e_fallback_keys( async def claim_e2e_fallback_keys(
self, query_list: Iterable[Tuple[str, str, str]] self, query_list: Iterable[Tuple[str, str, str, bool]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Take a list of fallback keys out of the database. """Take a list of fallback keys out of the database.
Args: Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm). query_list: An iterable of tuples of
(user ID, device ID, algorithm, whether the key should be marked as used).
Returns: Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON. A map of user ID -> a map device ID -> a map of key ID -> JSON.
""" """
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm in query_list: for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one( row = await self.db_pool.simple_select_one(
table="e2e_fallback_keys_json", table="e2e_fallback_keys_json",
keyvalues={ keyvalues={
@ -1180,7 +1181,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
used = row["used"] used = row["used"]
# Mark fallback key as used if not already. # Mark fallback key as used if not already.
if not used: if not used and mark_as_used:
await self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="e2e_fallback_keys_json", table="e2e_fallback_keys_json",
keyvalues={ keyvalues={

View file

@ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res2 = self.get_success( res2 = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
) )
) )
self.assertEqual( self.assertEqual(
@ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# key # key
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
) )
) )
self.assertEqual( self.assertEqual(
@ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# claiming an OTK again should return the same fallback key # claiming an OTK again should return the same fallback key
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
) )
) )
self.assertEqual( self.assertEqual(
@ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
) )
) )
self.assertEqual( self.assertEqual(
@ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
) )
) )
self.assertEqual( self.assertEqual(
@ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
) )
) )
self.assertEqual( self.assertEqual(
@ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
) )
def test_fallback_key_always_returned(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
otk = {"alg1:k2": "key2"}
# we shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
# Upload a OTK & fallback key.
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"one_time_keys": otk, "fallback_keys": fallback_key},
)
)
# we should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(fallback_res, ["alg1"])
# Claiming an OTK and requesting to always return the fallback key should
# return both.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}},
},
)
# This should not mark the key as used.
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(fallback_res, ["alg1"])
# Claiming an OTK again should return only the fallback key.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# And mark it as used.
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(fallback_res, [])
def test_replace_master_key(self) -> None: def test_replace_master_key(self) -> None:
"""uploading a new signing key should make the old signing key unavailable""" """uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
@ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
} }
}, },
timeout=None, timeout=None,
always_include_fallback_keys=False,
) )
) )
self.assertEqual( self.assertEqual(
@ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
@override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
def test_query_appservice_with_fallback(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id_1 = "xyz"
fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}}
otk = {"alg1:k2": {"desc": "key2"}}
as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}}
as_otk = {"alg1:k4": {"desc": "key4"}}
# Inject an appservice interested in this user.
appservice = ApplicationService(
token="i_am_an_app_service",
id="1234",
namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
# Note: this user does not have to match the regex above
sender="@as_main:test",
)
self.hs.get_datastores().main.services_cache = [appservice]
self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
[appservice]
)
# Setup a response.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
)
# Claim OTKs, which will ask the appservice and do nothing else.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {
local_user: {device_id_1: {**as_otk, **as_fallback_key}}
},
},
)
# Now upload a fallback key.
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(res, [])
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id_1,
{"fallback_keys": fallback_key},
)
)
# we should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(fallback_res, ["alg1"])
# The appservice will return only the OTK.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_1: as_otk}}, [])
)
# Claim OTKs, which should return the OTK from the appservice and the
# uploaded fallback key.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {
local_user: {device_id_1: {**as_otk, **fallback_key}}
},
},
)
# But the fallback key should not be marked as used.
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(fallback_res, ["alg1"])
# Now upload a OTK.
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id_1,
{"one_time_keys": otk},
)
)
# Claim OTKs, which will return information only from the database.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}},
},
)
# But the fallback key should not be marked as used.
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(fallback_res, ["alg1"])
# Finally, return only the fallback key from the appservice.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_1: as_fallback_key}}, [])
)
# Claim OTKs, which will return only the fallback key from the database.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {local_user: {device_id_1: as_fallback_key}},
},
)
@override_config({"experimental_features": {"msc3984_appservice_key_query": True}}) @override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
def test_query_local_devices_appservice(self) -> None: def test_query_local_devices_appservice(self) -> None:
"""Test that querying of appservices for keys overrides responses from the database.""" """Test that querying of appservices for keys overrides responses from the database."""