From 9d1e4942ab728ebfe09ff9a63c66708ceaaf7591 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2020 14:03:08 +0100 Subject: [PATCH] Fix typing for notifier (#8064) --- changelog.d/8064.misc | 1 + .../federation/sender/transaction_manager.py | 7 ++++-- synapse/notifier.py | 12 ++++++---- synapse/types.py | 23 +++++++++++++------ synapse/util/metrics.py | 9 +++++--- tox.ini | 2 ++ 6 files changed, 38 insertions(+), 16 deletions(-) create mode 100644 changelog.d/8064.misc diff --git a/changelog.d/8064.misc b/changelog.d/8064.misc new file mode 100644 index 0000000000..41a27e5d72 --- /dev/null +++ b/changelog.d/8064.misc @@ -0,0 +1 @@ +Add type hints to `Notifier`. diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 8280f8b900..c7f6cb3d73 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Tuple from canonicaljson import json @@ -54,7 +54,10 @@ class TransactionManager(object): @measure_func("_send_new_transaction") async def send_new_transaction( - self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu] + self, + destination: str, + pending_pdus: List[Tuple[EventBase, int]], + pending_edus: List[Edu], ): # Make a transaction-sending opentracing span. This span follows on from diff --git a/synapse/notifier.py b/synapse/notifier.py index 694efe7116..dfb096e589 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -25,6 +25,7 @@ from typing import ( Set, Tuple, TypeVar, + Union, ) from prometheus_client import Counter @@ -186,7 +187,7 @@ class Notifier(object): self.store = hs.get_datastore() self.pending_new_room_events = ( [] - ) # type: List[Tuple[int, EventBase, Collection[str]]] + ) # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -246,7 +247,7 @@ class Notifier(object): event: EventBase, room_stream_id: int, max_room_stream_id: int, - extra_users: Collection[str] = [], + extra_users: Collection[Union[str, UserID]] = [], ): """ Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -282,7 +283,10 @@ class Notifier(object): self._on_new_room_event(event, room_stream_id, extra_users) def _on_new_room_event( - self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = [] + self, + event: EventBase, + room_stream_id: int, + extra_users: Collection[Union[str, UserID]] = [], ): """Notify any user streams that are interested in this room event""" # poke any interested application service. @@ -310,7 +314,7 @@ class Notifier(object): self, stream_key: str, new_token: int, - users: Collection[str] = [], + users: Collection[Union[str, UserID]] = [], rooms: Collection[str] = [], ): """ Used to inform listeners that something has happened event wise. diff --git a/synapse/types.py b/synapse/types.py index 238b938064..9e580f4295 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -13,11 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import re import string import sys from collections import namedtuple -from typing import Any, Dict, Tuple, TypeVar +from typing import Any, Dict, Tuple, Type, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -33,7 +34,7 @@ else: T_co = TypeVar("T_co", covariant=True) - class Collection(Iterable[T_co], Container[T_co], Sized): + class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore __slots__ = () @@ -141,6 +142,9 @@ def get_localpart_from_id(string): return string[1:idx] +DS = TypeVar("DS", bound="DomainSpecificString") + + class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))): """Common base class among ID/name strings that have a local part and a domain name, prefixed with a sigil. @@ -151,6 +155,10 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 'domain' : The domain part of the name """ + __metaclass__ = abc.ABCMeta + + SIGIL = abc.abstractproperty() # type: str # type: ignore + # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) @@ -166,7 +174,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom return self @classmethod - def from_string(cls, s: str): + def from_string(cls: Type[DS], s: str) -> DS: """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError( @@ -190,12 +198,12 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom # names on one HS return cls(localpart=parts[0], domain=domain) - def to_string(self): + def to_string(self) -> str: """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) @classmethod - def is_valid(cls, s): + def is_valid(cls: Type[DS], s: str) -> bool: try: cls.from_string(s) return True @@ -235,8 +243,9 @@ class GroupID(DomainSpecificString): SIGIL = "+" @classmethod - def from_string(cls, s): - group_id = super(GroupID, cls).from_string(s) + def from_string(cls: Type[DS], s: str) -> DS: + group_id = super().from_string(s) # type: DS # type: ignore + if not group_id.localpart: raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index a805f51df1..13775b43f9 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -15,6 +15,7 @@ import logging from functools import wraps +from typing import Any, Callable, Optional, TypeVar, cast from prometheus_client import Counter @@ -57,8 +58,10 @@ in_flight = InFlightGauge( sub_metrics=["real_time_max", "real_time_sum"], ) +T = TypeVar("T", bound=Callable[..., Any]) -def measure_func(name=None): + +def measure_func(name: Optional[str] = None) -> Callable[[T], T]: """ Used to decorate an async function with a `Measure` context manager. @@ -76,7 +79,7 @@ def measure_func(name=None): """ - def wrapper(func): + def wrapper(func: T) -> T: block_name = func.__name__ if name is None else name @wraps(func) @@ -85,7 +88,7 @@ def measure_func(name=None): r = await func(self, *args, **kwargs) return r - return measured_func + return cast(T, measured_func) return wrapper diff --git a/tox.ini b/tox.ini index 217590edef..45e129580f 100644 --- a/tox.ini +++ b/tox.ini @@ -212,7 +212,9 @@ commands = mypy \ synapse/storage/state.py \ synapse/storage/util \ synapse/streams \ + synapse/types.py \ synapse/util/caches/stream_change_cache.py \ + synapse/util/metrics.py \ tests/replication \ tests/test_utils \ tests/rest/client/v2_alpha/test_auth.py \