Add type hints to E2E handler. (#9232)

This finishes adding type hints to the `synapse.handlers` module.
This commit is contained in:
Patrick Cloke 2021-01-28 08:34:19 -05:00 committed by GitHub
parent 34efb4c604
commit a78016dadf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 198 additions and 177 deletions

1
changelog.d/9232.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to handlers code.

View file

@ -23,47 +23,7 @@ files =
synapse/events/validator.py, synapse/events/validator.py,
synapse/events/spamcheck.py, synapse/events/spamcheck.py,
synapse/federation, synapse/federation,
synapse/handlers/_base.py, synapse/handlers,
synapse/handlers/account_data.py,
synapse/handlers/account_validity.py,
synapse/handlers/acme.py,
synapse/handlers/acme_issuing_service.py,
synapse/handlers/admin.py,
synapse/handlers/appservice.py,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
synapse/handlers/deactivate_account.py,
synapse/handlers/device.py,
synapse/handlers/devicemessage.py,
synapse/handlers/directory.py,
synapse/handlers/events.py,
synapse/handlers/federation.py,
synapse/handlers/groups_local.py,
synapse/handlers/identity.py,
synapse/handlers/initial_sync.py,
synapse/handlers/message.py,
synapse/handlers/oidc_handler.py,
synapse/handlers/pagination.py,
synapse/handlers/password_policy.py,
synapse/handlers/presence.py,
synapse/handlers/profile.py,
synapse/handlers/read_marker.py,
synapse/handlers/receipts.py,
synapse/handlers/register.py,
synapse/handlers/room.py,
synapse/handlers/room_list.py,
synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py,
synapse/handlers/saml_handler.py,
synapse/handlers/search.py,
synapse/handlers/set_password.py,
synapse/handlers/sso.py,
synapse/handlers/state_deltas.py,
synapse/handlers/stats.py,
synapse/handlers/sync.py,
synapse/handlers/typing.py,
synapse/handlers/user_directory.py,
synapse/handlers/ui_auth,
synapse/http/client.py, synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py, synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/well_known_resolver.py, synapse/http/federation/well_known_resolver.py,

View file

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api import errors from synapse.api import errors
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@trace @trace
async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
""" """
Retrieve the given user's devices Retrieve the given user's devices
@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
return devices return devices
@trace @trace
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: async def get_device(self, user_id: str, device_id: str) -> JsonDict:
""" Retrieve the given device """ Retrieve the given device
Args: Args:
@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips( def _update_device_from_client_ips(
device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]] device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
) -> None: ) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {}) ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@ -946,8 +946,8 @@ class DeviceListUpdater:
async def process_cross_signing_key_update( async def process_cross_signing_key_update(
self, self,
user_id: str, user_id: str,
master_key: Optional[Dict[str, Any]], master_key: Optional[JsonDict],
self_signing_key: Optional[Dict[str, Any]], self_signing_key: Optional[JsonDict],
) -> List[str]: ) -> List[str]:
"""Process the given new master and self-signing key for the given remote user. """Process the given new master and self-signing key for the given remote user.

View file

@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
import attr import attr
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import ( from synapse.types import (
JsonDict,
UserID, UserID,
get_domain_from_id, get_domain_from_id,
get_verify_key_from_cross_signing_key, get_verify_key_from_cross_signing_key,
@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class E2eKeysHandler: class E2eKeysHandler:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_federation_client() self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@ -78,7 +82,9 @@ class E2eKeysHandler:
) )
@trace @trace
async def query_devices(self, query_body, timeout, from_user_id): async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str
) -> JsonDict:
""" Handle a device key query from a client """ Handle a device key query from a client
{ {
@ -98,12 +104,14 @@ class E2eKeysHandler:
} }
Args: Args:
from_user_id (str): the user making the query. This is used when from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users adding cross-signing signatures to limit what signatures users
can see. can see.
""" """
device_keys_query = query_body.get("device_keys", {}) device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]
# separate users by domain. # separate users by domain.
# make a map from domain to user_id to device_ids # make a map from domain to user_id to device_ids
@ -121,7 +129,8 @@ class E2eKeysHandler:
set_tag("remote_key_query", remote_queries) set_tag("remote_key_query", remote_queries)
# First get local devices. # First get local devices.
failures = {} # A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
results = {} results = {}
if local_query: if local_query:
local_result = await self.query_local_devices(local_query) local_result = await self.query_local_devices(local_query)
@ -135,9 +144,10 @@ class E2eKeysHandler:
) )
# Now attempt to get any remote devices from our local cache. # Now attempt to get any remote devices from our local cache.
remote_queries_not_in_cache = {} # A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries: if remote_queries:
query_list = [] query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items(): for user_id, device_ids in remote_queries.items():
if device_ids: if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids) query_list.extend((user_id, device_id) for device_id in device_ids)
@ -284,15 +294,15 @@ class E2eKeysHandler:
return ret return ret
async def get_cross_signing_keys_from_cache( async def get_cross_signing_keys_from_cache(
self, query, from_user_id self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]: ) -> Dict[str, Dict[str, dict]]:
"""Get cross-signing keys for users from the database """Get cross-signing keys for users from the database
Args: Args:
query (Iterable[string]) an iterable of user IDs. A dict whose keys query: an iterable of user IDs. A dict whose keys
are user IDs satisfies this, so the query format used for are user IDs satisfies this, so the query format used for
query_devices can be used here. query_devices can be used here.
from_user_id (str): the user making the query. This is used when from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users adding cross-signing signatures to limit what signatures users
can see. can see.
@ -315,14 +325,12 @@ class E2eKeysHandler:
if "self_signing" in user_info: if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"] self_signing_keys[user_id] = user_info["self_signing"]
if ( # users can see other users' master and self-signing keys, but can
from_user_id in keys # only see their own user-signing keys
and keys[from_user_id] is not None if from_user_id:
and "user_signing" in keys[from_user_id] from_user_key = keys.get(from_user_id)
): if from_user_key and "user_signing" in from_user_key:
# users can see other users' master and self-signing keys, but can user_signing_keys[from_user_id] = from_user_key["user_signing"]
# only see their own user-signing keys
user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
return { return {
"master_keys": master_keys, "master_keys": master_keys,
@ -344,9 +352,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details A map from user_id -> device_id -> device details
""" """
set_tag("local_query", query) set_tag("local_query", query)
local_query = [] local_query = [] # type: List[Tuple[str, Optional[str]]]
result_dict = {} result_dict = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_ids in query.items(): for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids # we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)): if not self.is_mine(UserID.from_string(user_id)):
@ -380,10 +388,14 @@ class E2eKeysHandler:
log_kv(results) log_kv(results)
return result_dict return result_dict
async def on_federation_query_client_keys(self, query_body): async def on_federation_query_client_keys(
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
""" Handle a device key query from a federated server """ Handle a device key query from a federated server
""" """
device_keys_query = query_body.get("device_keys", {}) device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Optional[List[str]]]
res = await self.query_local_devices(device_keys_query) res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res} ret = {"device_keys": res}
@ -397,31 +409,34 @@ class E2eKeysHandler:
return ret return ret
@trace @trace
async def claim_one_time_keys(self, query, timeout): async def claim_one_time_keys(
local_query = [] self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
remote_queries = {} ) -> JsonDict:
local_query = [] # type: List[Tuple[str, str, str]]
remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
for user_id, device_keys in query.get("one_time_keys", {}).items(): for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)): if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items(): for device_id, algorithm in one_time_keys.items():
local_query.append((user_id, device_id, algorithm)) local_query.append((user_id, device_id, algorithm))
else: else:
domain = get_domain_from_id(user_id) domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", local_query) set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries) set_tag("remote_key_query", remote_queries)
results = await self.store.claim_e2e_one_time_keys(local_query) results = await self.store.claim_e2e_one_time_keys(local_query)
json_result = {} # A map of user ID -> device ID -> key ID -> key.
failures = {} json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
failures = {} # type: Dict[str, JsonDict]
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, keys in device_keys.items(): for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items(): for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = { json_result.setdefault(user_id, {})[device_id] = {
key_id: json_decoder.decode(json_bytes) key_id: json_decoder.decode(json_str)
} }
@trace @trace
@ -468,7 +483,9 @@ class E2eKeysHandler:
return {"one_time_keys": json_result, "failures": failures} return {"one_time_keys": json_result, "failures": failures}
@tag_args @tag_args
async def upload_keys_for_user(self, user_id, device_id, keys): async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -543,8 +560,8 @@ class E2eKeysHandler:
return {"one_time_key_counts": result} return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user( async def _upload_one_time_keys_for_user(
self, user_id, device_id, time_now, one_time_keys self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
): ) -> None:
logger.info( logger.info(
"Adding one_time_keys %r for device %r for user %r at %d", "Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(), one_time_keys.keys(),
@ -585,12 +602,14 @@ class E2eKeysHandler:
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
async def upload_signing_keys_for_user(self, user_id, keys): async def upload_signing_keys_for_user(
self, user_id: str, keys: JsonDict
) -> JsonDict:
"""Upload signing keys for cross-signing """Upload signing keys for cross-signing
Args: Args:
user_id (string): the user uploading the keys user_id: the user uploading the keys
keys (dict[string, dict]): the signing keys keys: the signing keys
""" """
# if a master key is uploaded, then check it. Otherwise, load the # if a master key is uploaded, then check it. Otherwise, load the
@ -667,16 +686,17 @@ class E2eKeysHandler:
return {} return {}
async def upload_signatures_for_device_keys(self, user_id, signatures): async def upload_signatures_for_device_keys(
self, user_id: str, signatures: JsonDict
) -> JsonDict:
"""Upload device signatures for cross-signing """Upload device signatures for cross-signing
Args: Args:
user_id (string): the user uploading the signatures user_id: the user uploading the signatures
signatures (dict[string, dict[string, dict]]): map of users to signatures: map of users to devices to signed keys. This is the submission
devices to signed keys. This is the submission from the user; an from the user; an exception will be raised if it is malformed.
exception will be raised if it is malformed.
Returns: Returns:
dict: response to be sent back to the client. The response will have The response to be sent back to the client. The response will have
a "failures" key, which will be a dict mapping users to devices a "failures" key, which will be a dict mapping users to devices
to errors for the signatures that failed. to errors for the signatures that failed.
Raises: Raises:
@ -719,7 +739,9 @@ class E2eKeysHandler:
return {"failures": failures} return {"failures": failures}
async def _process_self_signatures(self, user_id, signatures): async def _process_self_signatures(
self, user_id: str, signatures: JsonDict
) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of the user's own keys. """Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms: Signatures of the user's own keys from this API come in two forms:
@ -731,15 +753,14 @@ class E2eKeysHandler:
signatures (dict[string, dict]): map of devices to signed keys signatures (dict[string, dict]): map of devices to signed keys
Returns: Returns:
(list[SignatureListItem], dict[string, dict[string, dict]]): A tuple of a list of signatures to store, and a map of users to
a list of signatures to store, and a map of users to devices to failure devices to failure reasons
reasons
Raises: Raises:
SynapseError: if the input is malformed SynapseError: if the input is malformed
""" """
signature_list = [] signature_list = [] # type: List[SignatureListItem]
failures = {} failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures: if not signatures:
return signature_list, failures return signature_list, failures
@ -834,19 +855,24 @@ class E2eKeysHandler:
return signature_list, failures return signature_list, failures
def _check_master_key_signature( def _check_master_key_signature(
self, user_id, master_key_id, signed_master_key, stored_master_key, devices self,
): user_id: str,
master_key_id: str,
signed_master_key: JsonDict,
stored_master_key: JsonDict,
devices: Dict[str, Dict[str, JsonDict]],
) -> List["SignatureListItem"]:
"""Check signatures of a user's master key made by their devices. """Check signatures of a user's master key made by their devices.
Args: Args:
user_id (string): the user whose master key is being checked user_id: the user whose master key is being checked
master_key_id (string): the ID of the user's master key master_key_id: the ID of the user's master key
signed_master_key (dict): the user's signed master key that was uploaded signed_master_key: the user's signed master key that was uploaded
stored_master_key (dict): our previously-stored copy of the user's master key stored_master_key: our previously-stored copy of the user's master key
devices (iterable(dict)): the user's devices devices: the user's devices
Returns: Returns:
list[SignatureListItem]: a list of signatures to store A list of signatures to store
Raises: Raises:
SynapseError: if a signature is invalid SynapseError: if a signature is invalid
@ -877,25 +903,26 @@ class E2eKeysHandler:
return master_key_signature_list return master_key_signature_list
async def _process_other_signatures(self, user_id, signatures): async def _process_other_signatures(
self, user_id: str, signatures: Dict[str, dict]
) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of other users' keys. These will be the """Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing target user's master keys, signed by the uploading user's user-signing
key. key.
Args: Args:
user_id (string): the user uploading the keys user_id: the user uploading the keys
signatures (dict[string, dict]): map of users to devices to signed keys signatures: map of users to devices to signed keys
Returns: Returns:
(list[SignatureListItem], dict[string, dict[string, dict]]): A list of signatures to store, and a map of users to devices to failure
a list of signatures to store, and a map of users to devices to failure
reasons reasons
Raises: Raises:
SynapseError: if the input is malformed SynapseError: if the input is malformed
""" """
signature_list = [] signature_list = [] # type: List[SignatureListItem]
failures = {} failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures: if not signatures:
return signature_list, failures return signature_list, failures
@ -983,7 +1010,7 @@ class E2eKeysHandler:
async def _get_e2e_cross_signing_verify_key( async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None self, user_id: str, key_type: str, from_user_id: str = None
): ) -> Tuple[JsonDict, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key. """Fetch locally or remotely query for a cross-signing public key.
First, attempt to fetch the cross-signing public key from storage. First, attempt to fetch the cross-signing public key from storage.
@ -997,8 +1024,7 @@ class E2eKeysHandler:
This affects what signatures are fetched. This affects what signatures are fetched.
Returns: Returns:
dict, str, VerifyKey: the raw key data, the key ID, and the The raw key data, the key ID, and the signedjson verify key
signedjson verify key
Raises: Raises:
NotFoundError: if the key is not found NotFoundError: if the key is not found
@ -1135,16 +1161,18 @@ class E2eKeysHandler:
return desired_key, desired_key_id, desired_verify_key return desired_key, desired_key_id, desired_verify_key
def _check_cross_signing_key(key, user_id, key_type, signing_key=None): def _check_cross_signing_key(
key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
) -> None:
"""Check a cross-signing key uploaded by a user. Performs some basic sanity """Check a cross-signing key uploaded by a user. Performs some basic sanity
checking, and ensures that it is signed, if a signature is required. checking, and ensures that it is signed, if a signature is required.
Args: Args:
key (dict): the key data to verify key: the key data to verify
user_id (str): the user whose key is being checked user_id: the user whose key is being checked
key_type (str): the type of key that the key should be key_type: the type of key that the key should be
signing_key (VerifyKey): (optional) the signing key that the key should signing_key: the signing key that the key should be signed with. If
be signed with. If omitted, signatures will not be checked. omitted, signatures will not be checked.
""" """
if ( if (
key.get("user_id") != user_id key.get("user_id") != user_id
@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
) )
def _check_device_signature(user_id, verify_key, signed_device, stored_device): def _check_device_signature(
user_id: str,
verify_key: VerifyKey,
signed_device: JsonDict,
stored_device: JsonDict,
) -> None:
"""Check that a signature on a device or cross-signing key is correct and """Check that a signature on a device or cross-signing key is correct and
matches the copy of the device/key that we have stored. Throws an matches the copy of the device/key that we have stored. Throws an
exception if an error is detected. exception if an error is detected.
Args: Args:
user_id (str): the user ID whose signature is being checked user_id: the user ID whose signature is being checked
verify_key (VerifyKey): the key to verify the device with verify_key: the key to verify the device with
signed_device (dict): the uploaded signed device data signed_device: the uploaded signed device data
stored_device (dict): our previously stored copy of the device stored_device: our previously stored copy of the device
Raises: Raises:
SynapseError: if the signature was invalid or the sent device is not the SynapseError: if the signature was invalid or the sent device is not the
@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
def _exception_to_failure(e): def _exception_to_failure(e: Exception) -> JsonDict:
if isinstance(e, SynapseError): if isinstance(e, SynapseError):
return {"status": e.code, "errcode": e.errcode, "message": str(e)} return {"status": e.code, "errcode": e.errcode, "message": str(e)}
@ -1218,7 +1251,7 @@ def _exception_to_failure(e):
return {"status": 503, "message": str(e)} return {"status": 503, "message": str(e)}
def _one_time_keys_match(old_key_json, new_key): def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
old_key = json_decoder.decode(old_key_json) old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly # if either is a string rather than an object, they must match exactly
@ -1239,16 +1272,16 @@ class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys. """An item in the signature list as used by upload_signatures_for_device_keys.
""" """
signing_key_id = attr.ib() signing_key_id = attr.ib(type=str)
target_user_id = attr.ib() target_user_id = attr.ib(type=str)
target_device_id = attr.ib() target_device_id = attr.ib(type=str)
signature = attr.ib() signature = attr.ib(type=JsonDict)
class SigningKeyEduUpdater: class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB""" """Handles incoming signing key updates from federation and updates the DB"""
def __init__(self, hs, e2e_keys_handler): def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_federation_client() self.federation = hs.get_federation_client()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key") self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled. # user_id -> list of updates waiting to be handled.
self._pending_updates = {} self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB, # Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious # but they're useful to have them about to reduce the number of spurious
@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
iterable=True, iterable=True,
) )
async def incoming_signing_key_update(self, origin, edu_content): async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
) -> None:
"""Called on incoming signing key update from federation. Responsible for """Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list. parsing the EDU and adding to pending updates list.
Args: Args:
origin (string): the server that sent the EDU origin: the server that sent the EDU
edu_content (dict): the contents of the EDU edu_content: the contents of the EDU
""" """
user_id = edu_content.pop("user_id") user_id = edu_content.pop("user_id")
@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
await self._handle_signing_key_updates(user_id) await self._handle_signing_key_updates(user_id)
async def _handle_signing_key_updates(self, user_id): async def _handle_signing_key_updates(self, user_id: str) -> None:
"""Actually handle pending updates. """Actually handle pending updates.
Args: Args:
user_id (string): the user whose updates we are processing user_id: the user whose updates we are processing
""" """
device_handler = self.e2e_keys_handler.device_handler device_handler = self.e2e_keys_handler.device_handler
@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates # This can happen since we batch updates
return return
device_ids = [] device_ids = [] # type: List[str]
logger.info("pending updates: %r", pending_updates) logger.info("pending updates: %r", pending_updates)

View file

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, List, Optional
from synapse.api.errors import ( from synapse.api.errors import (
Codes, Codes,
@ -24,8 +25,12 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.logging.opentracing import log_kv, trace from synapse.logging.opentracing import log_kv, trace
from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
The actual payload of the encrypted keys is completely opaque to the handler. The actual payload of the encrypted keys is completely opaque to the handler.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
# Used to lock whenever a client is uploading key data. This prevents collisions # Used to lock whenever a client is uploading key data. This prevents collisions
@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
self._upload_linearizer = Linearizer("upload_room_keys_lock") self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace @trace
async def get_room_keys(self, user_id, version, room_id=None, session_id=None): async def get_room_keys(
self,
user_id: str,
version: str,
room_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> List[JsonDict]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given """Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session. room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details. See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
Args: Args:
user_id(str): the user whose keys we're getting user_id: the user whose keys we're getting
version(str): the version ID of the backup we're getting keys from version: the version ID of the backup we're getting keys from
room_id(string): room ID to get keys for, for None to get keys for all rooms room_id: room ID to get keys for, for None to get keys for all rooms
session_id(string): session ID to get keys for, for None to get keys for all session_id: session ID to get keys for, for None to get keys for all
sessions sessions
Raises: Raises:
NotFoundError: if the backup version does not exist NotFoundError: if the backup version does not exist
Returns: Returns:
A deferred list of dicts giving the session_data and message metadata for A list of dicts giving the session_data and message metadata for
these room keys. these room keys.
""" """
@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
return results return results
@trace @trace
async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): async def delete_room_keys(
self,
user_id: str,
version: str,
room_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> JsonDict:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session. room or a given session.
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
Args: Args:
user_id(str): the user whose backup we're deleting user_id: the user whose backup we're deleting
version(str): the version ID of the backup we're deleting version: the version ID of the backup we're deleting
room_id(string): room ID to delete keys for, for None to delete keys for all room_id: room ID to delete keys for, for None to delete keys for all
rooms rooms
session_id(string): session ID to delete keys for, for None to delete keys session_id: session ID to delete keys for, for None to delete keys
for all sessions for all sessions
Raises: Raises:
NotFoundError: if the backup version does not exist NotFoundError: if the backup version does not exist
@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count} return {"etag": str(version_etag), "count": count}
@trace @trace
async def upload_room_keys(self, user_id, version, room_keys): async def upload_room_keys(
self, user_id: str, version: str, room_keys: JsonDict
) -> JsonDict:
"""Bulk upload a list of room keys into a given backup version, asserting """Bulk upload a list of room keys into a given backup version, asserting
that the given version is the current backup version. room_keys are merged that the given version is the current backup version. room_keys are merged
into the current backup as described in RoomKeysServlet.on_PUT(). into the current backup as described in RoomKeysServlet.on_PUT().
Args: Args:
user_id(str): the user whose backup we're setting user_id: the user whose backup we're setting
version(str): the version ID of the backup we're updating version: the version ID of the backup we're updating
room_keys(dict): a nested dict describing the room_keys we're setting: room_keys: a nested dict describing the room_keys we're setting:
{ {
"rooms": { "rooms": {
@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count} return {"etag": str(version_etag), "count": count}
@staticmethod @staticmethod
def _should_replace_room_key(current_room_key, room_key): def _should_replace_room_key(
current_room_key: Optional[JsonDict], room_key: JsonDict
) -> bool:
""" """
Determine whether to replace a given current_room_key (if any) Determine whether to replace a given current_room_key (if any)
with a newly uploaded room_key backup with a newly uploaded room_key backup
Args: Args:
current_room_key (dict): Optional, the current room_key dict if any current_room_key: Optional, the current room_key dict if any
room_key (dict): The new room_key dict which may or may not be fit to room_key : The new room_key dict which may or may not be fit to
replace the current_room_key replace the current_room_key
Returns: Returns:
@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
return True return True
@trace @trace
async def create_version(self, user_id, version_info): async def create_version(self, user_id: str, version_info: JsonDict) -> str:
"""Create a new backup version. This automatically becomes the new """Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be backup version for the user's keys; previous backups will no longer be
writeable to. writeable to.
Args: Args:
user_id(str): the user whose backup version we're creating user_id: the user whose backup version we're creating
version_info(dict): metadata about the new version being created version_info: metadata about the new version being created
{ {
"algorithm": "m.megolm_backup.v1", "algorithm": "m.megolm_backup.v1",
@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
} }
Returns: Returns:
A deferred of a string that gives the new version number. The new version number.
""" """
# TODO: Validate the JSON to make sure it has the right keys. # TODO: Validate the JSON to make sure it has the right keys.
@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
) )
return new_version return new_version
async def get_version_info(self, user_id, version=None): async def get_version_info(
self, user_id: str, version: Optional[str] = None
) -> JsonDict:
"""Get the info about a given version of the user's backup """Get the info about a given version of the user's backup
Args: Args:
user_id(str): the user whose current backup version we're querying user_id: the user whose current backup version we're querying
version(str): Optional; if None gives the most recent version version: Optional; if None gives the most recent version
otherwise a historical one. otherwise a historical one.
Raises: Raises:
NotFoundError: if the requested backup version doesn't exist NotFoundError: if the requested backup version doesn't exist
Returns: Returns:
A deferred of a info dict that gives the info about the new version. A info dict that gives the info about the new version.
{ {
"version": "1234", "version": "1234",
@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
return res return res
@trace @trace
async def delete_version(self, user_id, version=None): async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
"""Deletes a given version of the user's e2e_room_keys backup """Deletes a given version of the user's e2e_room_keys backup
Args: Args:
@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
raise raise
@trace @trace
async def update_version(self, user_id, version, version_info): async def update_version(
self, user_id: str, version: str, version_info: JsonDict
) -> JsonDict:
"""Update the info about a given version of the user's backup """Update the info about a given version of the user's backup
Args: Args:
user_id(str): the user whose current backup version we're updating user_id: the user whose current backup version we're updating
version(str): the backup version we're updating version: the backup version we're updating
version_info(dict): the new information about the backup version_info: the new information about the backup
Raises: Raises:
NotFoundError: if the requested backup version doesn't exist NotFoundError: if the requested backup version doesn't exist
Returns: Returns:
A deferred of an empty dict. An empty dict.
""" """
if "version" not in version_info: if "version" not in version_info:
version_info["version"] = version version_info["version"] = version

View file

@ -791,7 +791,7 @@ def tag_args(func):
@wraps(func) @wraps(func)
def _tag_args_inner(*args, **kwargs): def _tag_args_inner(*args, **kwargs):
argspec = inspect.getargspec(func) argspec = inspect.getfullargspec(func)
for i, arg in enumerate(argspec.args[1:]): for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i]) set_tag("ARG_" + arg, args[i])
set_tag("args", args[len(argspec.args) :]) set_tag("args", args[len(argspec.args) :])

View file

@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_keys_bulk( async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None self, user_ids: List[str], from_user_id: Optional[str] = None
) -> Dict[str, Dict[str, dict]]: ) -> Dict[str, Optional[Dict[str, dict]]]:
"""Returns the cross-signing keys for a set of users. """Returns the cross-signing keys for a set of users.
Args: Args:
@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def claim_e2e_one_time_keys( async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]] self, query_list: Iterable[Tuple[str, str, str]]
) -> Dict[str, Dict[str, Dict[str, bytes]]]: ) -> Dict[str, Dict[str, Dict[str, str]]]:
"""Take a list of one time keys out of the database. """Take a list of one time keys out of the database.
Args: Args: