Add type hints to the crypto module. (#8999)

This commit is contained in:
Patrick Cloke 2021-01-04 10:04:50 -05:00 committed by GitHub
parent a685bbb018
commit 1c9a850562
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 158 additions and 113 deletions

1
changelog.d/8999.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to the crypto module.

View file

@ -17,6 +17,7 @@ files =
synapse/api, synapse/api,
synapse/appservice, synapse/appservice,
synapse/config, synapse/config,
synapse/crypto,
synapse/event_auth.py, synapse/event_auth.py,
synapse/events/builder.py, synapse/events/builder.py,
synapse/events/validator.py, synapse/events/validator.py,
@ -75,6 +76,7 @@ files =
synapse/storage/background_updates.py, synapse/storage/background_updates.py,
synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py, synapse/storage/databases/main/events.py,
synapse/storage/databases/main/keys.py,
synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/pusher.py,
synapse/storage/databases/main/registration.py, synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py, synapse/storage/databases/main/stream.py,

View file

@ -227,7 +227,7 @@ class ConnectionVerifier:
# This code is based on twisted.internet.ssl.ClientTLSOptions. # This code is based on twisted.internet.ssl.ClientTLSOptions.
def __init__(self, hostname: bytes, verify_certs): def __init__(self, hostname: bytes, verify_certs: bool):
self._verify_certs = verify_certs self._verify_certs = verify_certs
_decoded = hostname.decode("ascii") _decoded = hostname.decode("ascii")

View file

@ -18,7 +18,7 @@
import collections.abc import collections.abc
import hashlib import hashlib
import logging import logging
from typing import Dict from typing import Any, Callable, Dict, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json from signedjson.sign import sign_json
@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.events.utils import prune_event, prune_event_dict from synapse.events.utils import prune_event, prune_event_dict
from synapse.types import JsonDict from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Hasher = Callable[[bytes], "hashlib._Hash"]
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
def check_event_content_hash(
event: EventBase, hash_algorithm: Hasher = hashlib.sha256
) -> bool:
"""Check whether the hash for this PDU matches the contents""" """Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm) name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
logger.debug( logger.debug(
@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
return message_hash_bytes == expected_hash return message_hash_bytes == expected_hash
def compute_content_hash(event_dict, hash_algorithm): def compute_content_hash(
event_dict: Dict[str, Any], hash_algorithm: Hasher
) -> Tuple[str, bytes]:
"""Compute the content hash of an event, which is the hash of the """Compute the content hash of an event, which is the hash of the
unredacted event. unredacted event.
Args: Args:
event_dict (dict): The unredacted event as a dict event_dict: The unredacted event as a dict
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event to hash the event
Returns: Returns:
tuple[str, bytes]: A tuple of the name of hash and the hash as raw A tuple of the name of hash and the hash as raw bytes.
bytes.
""" """
event_dict = dict(event_dict) event_dict = dict(event_dict)
event_dict.pop("age_ts", None) event_dict.pop("age_ts", None)
@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm):
return hashed.name, hashed.digest() return hashed.name, hashed.digest()
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): def compute_event_reference_hash(
event, hash_algorithm: Hasher = hashlib.sha256
) -> Tuple[str, bytes]:
"""Computes the event reference hash. This is the hash of the redacted """Computes the event reference hash. This is the hash of the redacted
event. event.
Args: Args:
event (FrozenEvent) event
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event to hash the event
Returns: Returns:
tuple[str, bytes]: A tuple of the name of hash and the hash as raw A tuple of the name of hash and the hash as raw bytes.
bytes.
""" """
tmp_event = prune_event(event) tmp_event = prune_event(event)
event_dict = tmp_event.get_pdu_json() event_dict = tmp_event.get_pdu_json()
@ -156,7 +163,7 @@ def add_hashes_and_signatures(
event_dict: JsonDict, event_dict: JsonDict,
signature_name: str, signature_name: str,
signing_key: SigningKey, signing_key: SigningKey,
): ) -> None:
"""Add content hash and sign the event """Add content hash and sign the event
Args: Args:

View file

@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
import urllib import urllib
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr import attr
from signedjson.key import ( from signedjson.key import (
@ -40,6 +42,7 @@ from synapse.api.errors import (
RequestSendFailed, RequestSendFailed,
SynapseError, SynapseError,
) )
from synapse.config.key import TrustedKeyServer
from synapse.logging.context import ( from synapse.logging.context import (
PreserveLoggingContext, PreserveLoggingContext,
make_deferred_yieldable, make_deferred_yieldable,
@ -47,11 +50,15 @@ from synapse.logging.context import (
run_in_background, run_in_background,
) )
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -61,16 +68,17 @@ class VerifyJsonRequest:
A request to verify a JSON object. A request to verify a JSON object.
Attributes: Attributes:
server_name(str): The name of the server to verify against. server_name: The name of the server to verify against.
key_ids(set[str]): The set of key_ids to that could be used to verify the json_object: The JSON object to verify.
JSON object
json_object(dict): The JSON object to verify. minimum_valid_until_ts: time at which we require the signing key to
minimum_valid_until_ts (int): time at which we require the signing key to
be valid. (0 implies we don't care) be valid. (0 implies we don't care)
request_name: The name of the request.
key_ids: The set of key_ids to that could be used to verify the JSON object
key_ready (Deferred[str, str, nacl.signing.VerifyKey]): key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no a verify key has been fetched. The deferreds' callbacks are run with no
@ -80,12 +88,12 @@ class VerifyJsonRequest:
errbacks with an M_UNAUTHORIZED SynapseError. errbacks with an M_UNAUTHORIZED SynapseError.
""" """
server_name = attr.ib() server_name = attr.ib(type=str)
json_object = attr.ib() json_object = attr.ib(type=JsonDict)
minimum_valid_until_ts = attr.ib() minimum_valid_until_ts = attr.ib(type=int)
request_name = attr.ib() request_name = attr.ib(type=str)
key_ids = attr.ib(init=False) key_ids = attr.ib(init=False, type=List[str])
key_ready = attr.ib(default=attr.Factory(defer.Deferred)) key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.key_ids = signature_ids(self.json_object, self.server_name) self.key_ids = signature_ids(self.json_object, self.server_name)
@ -96,7 +104,9 @@ class KeyLookupError(ValueError):
class Keyring: class Keyring:
def __init__(self, hs, key_fetchers=None): def __init__(
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
):
self.clock = hs.get_clock() self.clock = hs.get_clock()
if key_fetchers is None: if key_fetchers is None:
@ -112,22 +122,26 @@ class Keyring:
# completes. # completes.
# #
# These are regular, logcontext-agnostic Deferreds. # These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {} self.key_downloads = {} # type: Dict[str, defer.Deferred]
def verify_json_for_server( def verify_json_for_server(
self, server_name, json_object, validity_time, request_name self,
): server_name: str,
json_object: JsonDict,
validity_time: int,
request_name: str,
) -> defer.Deferred:
"""Verify that a JSON object has been signed by a given server """Verify that a JSON object has been signed by a given server
Args: Args:
server_name (str): name of the server which must have signed this object server_name: name of the server which must have signed this object
json_object (dict): object to be checked json_object: object to be checked
validity_time (int): timestamp at which we require the signing key to validity_time: timestamp at which we require the signing key to
be valid. (0 implies we don't care) be valid. (0 implies we don't care)
request_name (str): an identifier for this json object (eg, an event id) request_name: an identifier for this json object (eg, an event id)
for logging. for logging.
Returns: Returns:
@ -138,12 +152,14 @@ class Keyring:
requests = (req,) requests = (req,)
return make_deferred_yieldable(self._verify_objects(requests)[0]) return make_deferred_yieldable(self._verify_objects(requests)[0])
def verify_json_objects_for_server(self, server_and_json): def verify_json_objects_for_server(
self, server_and_json: Iterable[Tuple[str, dict, int, str]]
) -> List[defer.Deferred]:
"""Bulk verifies signatures of json objects, bulk fetching keys as """Bulk verifies signatures of json objects, bulk fetching keys as
necessary. necessary.
Args: Args:
server_and_json (iterable[Tuple[str, dict, int, str]): server_and_json:
Iterable of (server_name, json_object, validity_time, request_name) Iterable of (server_name, json_object, validity_time, request_name)
tuples. tuples.
@ -164,13 +180,14 @@ class Keyring:
for server_name, json_object, validity_time, request_name in server_and_json for server_name, json_object, validity_time, request_name in server_and_json
) )
def _verify_objects(self, verify_requests): def _verify_objects(
self, verify_requests: Iterable[VerifyJsonRequest]
) -> List[defer.Deferred]:
"""Does the work of verify_json_[objects_]for_server """Does the work of verify_json_[objects_]for_server
Args: Args:
verify_requests (iterable[VerifyJsonRequest]): verify_requests: Iterable of verification requests.
Iterable of verification requests.
Returns: Returns:
List<Deferred[None]>: for each input item, a deferred indicating success List<Deferred[None]>: for each input item, a deferred indicating success
@ -182,7 +199,7 @@ class Keyring:
key_lookups = [] key_lookups = []
handle = preserve_fn(_handle_key_deferred) handle = preserve_fn(_handle_key_deferred)
def process(verify_request): def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
"""Process an entry in the request list """Process an entry in the request list
Adds a key request to key_lookups, and returns a deferred which Adds a key request to key_lookups, and returns a deferred which
@ -222,18 +239,20 @@ class Keyring:
return results return results
async def _start_key_lookups(self, verify_requests): async def _start_key_lookups(
self, verify_requests: List[VerifyJsonRequest]
) -> None:
"""Sets off the key fetches for each verify request """Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved. Once each fetch completes, verify_request.key_ready will be resolved.
Args: Args:
verify_requests (List[VerifyJsonRequest]): verify_requests:
""" """
try: try:
# map from server name to a set of outstanding request ids # map from server name to a set of outstanding request ids
server_to_request_ids = {} server_to_request_ids = {} # type: Dict[str, Set[int]]
for verify_request in verify_requests: for verify_request in verify_requests:
server_name = verify_request.server_name server_name = verify_request.server_name
@ -275,11 +294,11 @@ class Keyring:
except Exception: except Exception:
logger.exception("Error starting key lookups") logger.exception("Error starting key lookups")
async def wait_for_previous_lookups(self, server_names) -> None: async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
"""Waits for any previous key lookups for the given servers to finish. """Waits for any previous key lookups for the given servers to finish.
Args: Args:
server_names (Iterable[str]): list of servers which we want to look up server_names: list of servers which we want to look up
Returns: Returns:
Resolves once all key lookups for the given servers have Resolves once all key lookups for the given servers have
@ -304,7 +323,7 @@ class Keyring:
loop_count += 1 loop_count += 1
def _get_server_verify_keys(self, verify_requests): def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
"""Tries to find at least one key for each verify request """Tries to find at least one key for each verify request
For each verify_request, verify_request.key_ready is called back with For each verify_request, verify_request.key_ready is called back with
@ -312,7 +331,7 @@ class Keyring:
with a SynapseError if none of the keys are found. with a SynapseError if none of the keys are found.
Args: Args:
verify_requests (list[VerifyJsonRequest]): list of verify requests verify_requests: list of verify requests
""" """
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@ -366,17 +385,19 @@ class Keyring:
run_in_background(do_iterations) run_in_background(do_iterations)
async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): async def _attempt_key_fetches_with_fetcher(
self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
):
"""Use a key fetcher to attempt to satisfy some key requests """Use a key fetcher to attempt to satisfy some key requests
Args: Args:
fetcher (KeyFetcher): fetcher to use to fetch the keys fetcher: fetcher to use to fetch the keys
remaining_requests (set[VerifyJsonRequest]): outstanding key requests. remaining_requests: outstanding key requests.
Any successfully-completed requests will be removed from the list. Any successfully-completed requests will be removed from the list.
""" """
# dict[str, dict[str, int]]: keys to fetch. # The keys to fetch.
# server_name -> key_id -> min_valid_ts # server_name -> key_id -> min_valid_ts
missing_keys = defaultdict(dict) missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
for verify_request in remaining_requests: for verify_request in remaining_requests:
# any completed requests should already have been removed # any completed requests should already have been removed
@ -438,16 +459,18 @@ class Keyring:
remaining_requests.difference_update(completed) remaining_requests.difference_update(completed)
class KeyFetcher: class KeyFetcher(metaclass=abc.ABCMeta):
async def get_keys(self, keys_to_fetch): @abc.abstractmethod
async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
""" """
Args: Args:
keys_to_fetch (dict[str, dict[str, int]]): keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns: Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: Map from server_name -> key_id -> FetchKeyResult
map from server_name -> key_id -> FetchKeyResult
""" """
raise NotImplementedError raise NotImplementedError
@ -455,31 +478,35 @@ class KeyFetcher:
class StoreKeyFetcher(KeyFetcher): class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store""" """KeyFetcher impl which fetches keys from our data store"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def get_keys(self, keys_to_fetch): async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
keys_to_fetch = ( key_ids_to_fetch = (
(server_name, key_id) (server_name, key_id)
for server_name, keys_for_server in keys_to_fetch.items() for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys() for key_id in keys_for_server.keys()
) )
res = await self.store.get_server_verify_keys(keys_to_fetch) res = await self.store.get_server_verify_keys(key_ids_to_fetch)
keys = {} keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for (server_name, key_id), key in res.items(): for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key keys.setdefault(server_name, {})[key_id] = key
return keys return keys
class BaseV2KeyFetcher: class BaseV2KeyFetcher(KeyFetcher):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.config = hs.get_config() self.config = hs.get_config()
async def process_v2_response(self, from_server, response_json, time_added_ms): async def process_v2_response(
self, from_server: str, response_json: JsonDict, time_added_ms: int
) -> Dict[str, FetchKeyResult]:
"""Parse a 'Server Keys' structure from the result of a /key request """Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from This is used to parse either the entirety of the response from
@ -493,16 +520,16 @@ class BaseV2KeyFetcher:
to /_matrix/key/v2/query. to /_matrix/key/v2/query.
Args: Args:
from_server (str): the name of the server producing this result: either from_server: the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query. for a /_matrix/key/v2/query.
response_json (dict): the json-decoded Server Keys response object response_json: the json-decoded Server Keys response object
time_added_ms (int): the timestamp to record in server_keys_json time_added_ms: the timestamp to record in server_keys_json
Returns: Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object Map from key_id to result object
""" """
ts_valid_until_ms = response_json["valid_until_ts"] ts_valid_until_ms = response_json["valid_until_ts"]
@ -575,21 +602,22 @@ class BaseV2KeyFetcher:
class PerspectivesKeyFetcher(BaseV2KeyFetcher): class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the "perspectives" servers""" """KeyFetcher impl which fetches keys from the "perspectives" servers"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_federation_http_client() self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers self.key_servers = self.config.key_servers
async def get_keys(self, keys_to_fetch): async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
async def get_key(key_server): async def get_key(key_server: TrustedKeyServer) -> Dict:
try: try:
result = await self.get_server_verify_key_v2_indirect( return await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server keys_to_fetch, key_server
) )
return result
except KeyLookupError as e: except KeyLookupError as e:
logger.warning( logger.warning(
"Key lookup failed from %r: %s", key_server.server_name, e "Key lookup failed from %r: %s", key_server.server_name, e
@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
) )
union_of_keys = {} union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for result in results: for result in results:
for server_name, keys in result.items(): for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys) union_of_keys.setdefault(server_name, {}).update(keys)
return union_of_keys return union_of_keys
async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): async def get_server_verify_key_v2_indirect(
self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
) -> Dict[str, Dict[str, FetchKeyResult]]:
""" """
Args: Args:
keys_to_fetch (dict[str, dict[str, int]]): keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts the keys to be fetched. server_name -> key_id -> min_valid_ts
key_server (synapse.config.key.TrustedKeyServer): notary server to query for key_server: notary server to query for the keys
the keys
Returns: Returns:
dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map Map from server_name -> key_id -> FetchKeyResult
from server_name -> key_id -> FetchKeyResult
Raises: Raises:
KeyLookupError if there was an error processing the entire response from KeyLookupError if there was an error processing the entire response from
@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e: except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,)) raise KeyLookupError("Remote server returned an error: %s" % (e,))
keys = {} keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
added_keys = [] added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]]
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
assert isinstance(query_response, dict)
for response in query_response["server_keys"]: for response in query_response["server_keys"]:
# do this first, so that we can give useful errors thereafter # do this first, so that we can give useful errors thereafter
server_name = response.get("server_name") server_name = response.get("server_name")
@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return keys return keys
def _validate_perspectives_response(self, key_server, response): def _validate_perspectives_response(
self, key_server: TrustedKeyServer, response: JsonDict
) -> None:
"""Optionally check the signature on the result of a /key/query request """Optionally check the signature on the result of a /key/query request
Args: Args:
key_server (synapse.config.key.TrustedKeyServer): the notary server that key_server: the notary server that produced this result
produced this result
response (dict): the json-decoded Server Keys response object response: the json-decoded Server Keys response object
""" """
perspective_name = key_server.server_name perspective_name = key_server.server_name
perspective_keys = key_server.verify_keys perspective_keys = key_server.verify_keys
@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
class ServerKeyFetcher(BaseV2KeyFetcher): class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers""" """KeyFetcher impl which fetches keys from the origin servers"""
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_federation_http_client() self.client = hs.get_federation_http_client()
async def get_keys(self, keys_to_fetch): async def get_keys(
self, keys_to_fetch: Dict[str, Dict[str, int]]
) -> Dict[str, Dict[str, FetchKeyResult]]:
""" """
Args: Args:
keys_to_fetch (dict[str, iterable[str]]): keys_to_fetch:
the keys to be fetched. server_name -> key_ids the keys to be fetched. server_name -> key_ids
Returns: Returns:
dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]: Map from server_name -> key_id -> FetchKeyResult
map from server_name -> key_id -> FetchKeyResult
""" """
results = {} results = {}
async def get_key(key_to_fetch_item): async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
server_name, key_ids = key_to_fetch_item server_name, key_ids = key_to_fetch_item
try: try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
await yieldable_gather_results(get_key, keys_to_fetch.items()) await yieldable_gather_results(get_key, keys_to_fetch.items())
return results return results
async def get_server_verify_key_v2_direct(self, server_name, key_ids): async def get_server_verify_key_v2_direct(
self, server_name: str, key_ids: Iterable[str]
) -> Dict[str, FetchKeyResult]:
""" """
Args: Args:
server_name (str): server_name:
key_ids (iterable[str]): key_ids:
Returns: Returns:
dict[str, FetchKeyResult]: map from key ID to lookup result Map from key ID to lookup result
Raises: Raises:
KeyLookupError if there was a problem making the lookup KeyLookupError if there was a problem making the lookup
""" """
keys = {} # type: dict[str, FetchKeyResult] keys = {} # type: Dict[str, FetchKeyResult]
for requested_key_id in key_ids: for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another. # we may have found this key as a side-effect of asking for another.
@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e: except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,)) raise KeyLookupError("Remote server returned an error: %s" % (e,))
assert isinstance(response, dict)
if response["server_name"] != server_name: if response["server_name"] != server_name:
raise KeyLookupError( raise KeyLookupError(
"Expected a response for server %r not %r" "Expected a response for server %r not %r"
@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys return keys
async def _handle_key_deferred(verify_request) -> None: async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
"""Waits for the key to become available, and then performs a verification """Waits for the key to become available, and then performs a verification
Args: Args:
verify_request (VerifyJsonRequest): verify_request:
Raises: Raises:
SynapseError if there was a problem performing the verification SynapseError if there was a problem performing the verification

View file

@ -144,7 +144,7 @@ class Authenticator:
): ):
raise FederationDeniedError(origin) raise FederationDeniedError(origin)
if not json_request["signatures"]: if origin is None or not json_request["signatures"]:
raise NoAuthenticationError( raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED 401, "Missing Authorization headers", Codes.UNAUTHORIZED
) )

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Set from typing import Dict
from signedjson.sign import sign_json from signedjson.sign import sign_json
@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
cache_misses = {} # type: Dict[str, Set[str]] # Note that the value is unused.
cache_misses = {} # type: Dict[str, Dict[str, int]]
for (server_name, key_id, from_server), results in cached.items(): for (server_name, key_id, from_server), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results] results = [(result["ts_added_ms"], result) for result in results]
if not results and key_id is not None: if not results and key_id is not None:
cache_misses.setdefault(server_name, set()).add(key_id) cache_misses.setdefault(server_name, {})[key_id] = 0
continue continue
if key_id is not None: if key_id is not None:
@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource):
) )
if miss: if miss:
cache_misses.setdefault(server_name, set()).add(key_id) cache_misses.setdefault(server_name, {})[key_id] = 0
# Cast to bytes since postgresql returns a memoryview. # Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"])) json_results.add(bytes(most_recent_result["key_json"]))
else: else:

View file

@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore):
) )
async def get_server_verify_keys( async def get_server_verify_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]] self, server_name_and_key_ids: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: ) -> Dict[Tuple[str, str], FetchKeyResult]:
""" """
Args: Args:
server_name_and_key_ids: server_name_and_key_ids:
@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore):
""" """
keys = {} keys = {}
def _get_keys(txn, batch): def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`.""" """Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch) # batch_iter always returns tuples so it's safe to do len(batch)
@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore):
# `ts_valid_until_ms`. # `ts_valid_until_ms`.
ts_valid_until_ms = 0 ts_valid_until_ms = 0
res = FetchKeyResult( keys[(server_name, key_id)] = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms, valid_until_ts=ts_valid_until_ms,
) )
keys[(server_name, key_id)] = res
def _txn(txn): def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
for batch in batch_iter(server_name_and_key_ids, 50): for batch in batch_iter(server_name_and_key_ids, 50):
_get_keys(txn, batch) _get_keys(txn, batch)
return keys return keys

View file

@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
return val return val
def test_verify_json_objects_for_server_awaits_previous_requests(self): def test_verify_json_objects_for_server_awaits_previous_requests(self):
mock_fetcher = keyring.KeyFetcher() mock_fetcher = Mock()
mock_fetcher.get_keys = Mock() mock_fetcher.get_keys = Mock()
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Tests that we correctly handle key requests for keys we've stored """Tests that we correctly handle key requests for keys we've stored
with a null `ts_valid_until_ms` with a null `ts_valid_until_ms`
""" """
mock_fetcher = keyring.KeyFetcher() mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
kr = keyring.Keyring( kr = keyring.Keyring(
@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
} }
} }
mock_fetcher = keyring.KeyFetcher() mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(side_effect=get_keys) mock_fetcher.get_keys = Mock(side_effect=get_keys)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
} }
} }
mock_fetcher1 = keyring.KeyFetcher() mock_fetcher1 = Mock()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1) mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
mock_fetcher2 = keyring.KeyFetcher() mock_fetcher2 = Mock()
mock_fetcher2.get_keys = Mock(side_effect=get_keys2) mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2)) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))