Add typing information to the device handler. (#8407)

This commit is contained in:
Patrick Cloke 2020-10-07 08:58:21 -04:00 committed by GitHub
parent 9ca6341969
commit b460a088c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 38 deletions

1
changelog.d/8407.misc Normal file
View file

@ -0,0 +1 @@
Add typing information to the device handler.

View file

@ -17,6 +17,7 @@ files =
synapse/federation, synapse/federation,
synapse/handlers/auth.py, synapse/handlers/auth.py,
synapse/handlers/cas_handler.py, synapse/handlers/cas_handler.py,
synapse/handlers/device.py,
synapse/handlers/directory.py, synapse/handlers/directory.py,
synapse/handlers/events.py, synapse/handlers/events.py,
synapse/handlers/federation.py, synapse/handlers/federation.py,

View file

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api import errors from synapse.api import errors
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -29,8 +29,10 @@ from synapse.api.errors import (
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import ( from synapse.types import (
Collection,
JsonDict, JsonDict,
StreamToken, StreamToken,
UserID,
get_domain_from_id, get_domain_from_id,
get_verify_key_from_cross_signing_key, get_verify_key_from_cross_signing_key,
) )
@ -42,13 +44,16 @@ from synapse.util.retryutils import NotRetryingDestination
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_DEVICE_DISPLAY_NAME_LEN = 100 MAX_DEVICE_DISPLAY_NAME_LEN = 100
class DeviceWorkerHandler(BaseHandler): class DeviceWorkerHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.hs = hs self.hs = hs
@ -106,7 +111,9 @@ class DeviceWorkerHandler(BaseHandler):
@trace @trace
@measure_func("device.get_user_ids_changed") @measure_func("device.get_user_ids_changed")
async def get_user_ids_changed(self, user_id: str, from_token: StreamToken): async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken
) -> JsonDict:
"""Get list of users that have had the devices updated, or have newly """Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in. joined a room, that `user_id` may be interested in.
""" """
@ -222,8 +229,8 @@ class DeviceWorkerHandler(BaseHandler):
possibly_joined = possibly_changed & users_who_share_room possibly_joined = possibly_changed & users_who_share_room
possibly_left = (possibly_changed | possibly_left) - users_who_share_room possibly_left = (possibly_changed | possibly_left) - users_who_share_room
else: else:
possibly_joined = [] possibly_joined = set()
possibly_left = [] possibly_left = set()
result = {"changed": list(possibly_joined), "left": list(possibly_left)} result = {"changed": list(possibly_joined), "left": list(possibly_left)}
@ -231,7 +238,7 @@ class DeviceWorkerHandler(BaseHandler):
return result return result
async def on_federation_query_user_devices(self, user_id): async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query( stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
user_id user_id
) )
@ -250,7 +257,7 @@ class DeviceWorkerHandler(BaseHandler):
class DeviceHandler(DeviceWorkerHandler): class DeviceHandler(DeviceWorkerHandler):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
@ -265,7 +272,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)
def _check_device_name_length(self, name: str): def _check_device_name_length(self, name: Optional[str]):
""" """
Checks whether a device name is longer than the maximum allowed length. Checks whether a device name is longer than the maximum allowed length.
@ -284,8 +291,11 @@ class DeviceHandler(DeviceWorkerHandler):
) )
async def check_device_registered( async def check_device_registered(
self, user_id, device_id, initial_device_display_name=None self,
): user_id: str,
device_id: Optional[str],
initial_device_display_name: Optional[str] = None,
) -> str:
""" """
If the given device has not been registered, register it with the If the given device has not been registered, register it with the
supplied display name. supplied display name.
@ -293,12 +303,11 @@ class DeviceHandler(DeviceWorkerHandler):
If no device_id is supplied, we make one up. If no device_id is supplied, we make one up.
Args: Args:
user_id (str): @user:id user_id: @user:id
device_id (str | None): device id supplied by client device_id: device id supplied by client
initial_device_display_name (str | None): device display name from initial_device_display_name: device display name from client
client
Returns: Returns:
str: device id (generated if none was supplied) device id (generated if none was supplied)
""" """
self._check_device_name_length(initial_device_display_name) self._check_device_name_length(initial_device_display_name)
@ -317,15 +326,15 @@ class DeviceHandler(DeviceWorkerHandler):
# times in case of a clash. # times in case of a clash.
attempts = 0 attempts = 0
while attempts < 5: while attempts < 5:
device_id = stringutils.random_string(10).upper() new_device_id = stringutils.random_string(10).upper()
new_device = await self.store.store_device( new_device = await self.store.store_device(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=new_device_id,
initial_device_display_name=initial_device_display_name, initial_device_display_name=initial_device_display_name,
) )
if new_device: if new_device:
await self.notify_device_update(user_id, [device_id]) await self.notify_device_update(user_id, [new_device_id])
return device_id return new_device_id
attempts += 1 attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.") raise errors.StoreError(500, "Couldn't generate a device ID.")
@ -434,7 +443,9 @@ class DeviceHandler(DeviceWorkerHandler):
@trace @trace
@measure_func("notify_device_update") @measure_func("notify_device_update")
async def notify_device_update(self, user_id, device_ids): async def notify_device_update(
self, user_id: str, device_ids: Collection[str]
) -> None:
"""Notify that a user's device(s) has changed. Pokes the notifier, and """Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local. remote servers if the user is local.
""" """
@ -446,7 +457,7 @@ class DeviceHandler(DeviceWorkerHandler):
user_id user_id
) )
hosts = set() hosts = set() # type: Set[str]
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room) hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name) hosts.discard(self.server_name)
@ -498,7 +509,7 @@ class DeviceHandler(DeviceWorkerHandler):
self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
async def user_left_room(self, user, room_id): async def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string() user_id = user.to_string()
room_ids = await self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids: if not room_ids:
@ -586,7 +597,9 @@ class DeviceHandler(DeviceWorkerHandler):
return {"success": True} return {"success": True}
def _update_device_from_client_ips(device, client_ips): def _update_device_from_client_ips(
device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {}) ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@ -594,7 +607,7 @@ def _update_device_from_client_ips(device, client_ips):
class DeviceListUpdater: class DeviceListUpdater:
"Handles incoming device list updates from federation and updates the DB" "Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs, device_handler): def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_federation_client() self.federation = hs.get_federation_client()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -603,7 +616,9 @@ class DeviceListUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_device_list") self._remote_edu_linearizer = Linearizer(name="remote_device_list")
# user_id -> list of updates waiting to be handled. # user_id -> list of updates waiting to be handled.
self._pending_updates = {} self._pending_updates = (
{}
) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB, # Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious # but they're useful to have them about to reduce the number of spurious
@ -626,7 +641,9 @@ class DeviceListUpdater:
) )
@trace @trace
async def incoming_device_list_update(self, origin, edu_content): async def incoming_device_list_update(
self, origin: str, edu_content: JsonDict
) -> None:
"""Called on incoming device list update from federation. Responsible """Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list. for parsing the EDU and adding to pending updates list.
""" """
@ -687,7 +704,7 @@ class DeviceListUpdater:
await self._handle_device_updates(user_id) await self._handle_device_updates(user_id)
@measure_func("_incoming_device_list_update") @measure_func("_incoming_device_list_update")
async def _handle_device_updates(self, user_id): async def _handle_device_updates(self, user_id: str) -> None:
"Actually handle pending updates." "Actually handle pending updates."
with (await self._remote_edu_linearizer.queue(user_id)): with (await self._remote_edu_linearizer.queue(user_id)):
@ -735,7 +752,9 @@ class DeviceListUpdater:
stream_id for _, stream_id, _, _ in pending_updates stream_id for _, stream_id, _, _ in pending_updates
) )
async def _need_to_do_resync(self, user_id, updates): async def _need_to_do_resync(
self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]]
) -> bool:
"""Given a list of updates for a user figure out if we need to do a full """Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta. resync, or whether we have enough data that we can just apply the delta.
""" """
@ -766,7 +785,7 @@ class DeviceListUpdater:
return False return False
@trace @trace
async def _maybe_retry_device_resync(self): async def _maybe_retry_device_resync(self) -> None:
"""Retry to resync device lists that are out of sync, except if another retry is """Retry to resync device lists that are out of sync, except if another retry is
in progress. in progress.
""" """
@ -809,7 +828,7 @@ class DeviceListUpdater:
async def user_device_resync( async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[dict]: ) -> Optional[JsonDict]:
"""Fetches all devices for a user and updates the device cache with them. """Fetches all devices for a user and updates the device cache with them.
Args: Args:
@ -833,7 +852,7 @@ class DeviceListUpdater:
# it later. # it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id) await self.store.mark_remote_user_device_cache_as_stale(user_id)
return return None
except (RequestSendFailed, HttpResponseException) as e: except (RequestSendFailed, HttpResponseException) as e:
logger.warning( logger.warning(
"Failed to handle device list update for %s: %s", user_id, e, "Failed to handle device list update for %s: %s", user_id, e,
@ -850,12 +869,12 @@ class DeviceListUpdater:
# next time we get a device list update for this user_id. # next time we get a device list update for this user_id.
# This makes it more likely that the device lists will # This makes it more likely that the device lists will
# eventually become consistent. # eventually become consistent.
return return None
except FederationDeniedError as e: except FederationDeniedError as e:
set_tag("error", True) set_tag("error", True)
log_kv({"reason": "FederationDeniedError"}) log_kv({"reason": "FederationDeniedError"})
logger.info(e) logger.info(e)
return return None
except Exception as e: except Exception as e:
set_tag("error", True) set_tag("error", True)
log_kv( log_kv(
@ -868,7 +887,7 @@ class DeviceListUpdater:
# it later. # it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id) await self.store.mark_remote_user_device_cache_as_stale(user_id)
return return None
log_kv({"result": result}) log_kv({"result": result})
stream_id = result["stream_id"] stream_id = result["stream_id"]
devices = result["devices"] devices = result["devices"]
@ -929,7 +948,7 @@ class DeviceListUpdater:
user_id: str, user_id: str,
master_key: Optional[Dict[str, Any]], master_key: Optional[Dict[str, Any]],
self_signing_key: Optional[Dict[str, Any]], self_signing_key: Optional[Dict[str, Any]],
) -> list: ) -> List[str]:
"""Process the given new master and self-signing key for the given remote user. """Process the given new master and self-signing key for the given remote user.
Args: Args:

View file

@ -911,7 +911,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
async def store_device( async def store_device(
self, user_id: str, device_id: str, initial_device_display_name: str self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
) -> bool: ) -> bool:
"""Ensure the given device is known; add it to the store if not """Ensure the given device is known; add it to the store if not
@ -1029,7 +1029,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
) )
async def update_remote_device_list_cache_entry( async def update_remote_device_list_cache_entry(
self, user_id: str, device_id: str, content: JsonDict, stream_id: int self, user_id: str, device_id: str, content: JsonDict, stream_id: str
) -> None: ) -> None:
"""Updates a single device in the cache of a remote user's devicelist. """Updates a single device in the cache of a remote user's devicelist.
@ -1057,7 +1057,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str, user_id: str,
device_id: str, device_id: str,
content: JsonDict, content: JsonDict,
stream_id: int, stream_id: str,
) -> None: ) -> None:
if content.get("deleted"): if content.get("deleted"):
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(