Add most missing type hints to synapse.util (#11328)

This commit is contained in:
Patrick Cloke 2021-11-16 08:47:36 -05:00 committed by GitHub
parent 3a1462f7e0
commit 7468723697
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 161 additions and 165 deletions

1
changelog.d/11328.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `synapse.util`.

View file

@ -196,92 +196,11 @@ disallow_untyped_defs = True
[mypy-synapse.streams.*]
disallow_untyped_defs = True
[mypy-synapse.util.batching_queue]
[mypy-synapse.util.*]
disallow_untyped_defs = True
[mypy-synapse.util.caches.cached_call]
disallow_untyped_defs = True
[mypy-synapse.util.caches.dictionary_cache]
disallow_untyped_defs = True
[mypy-synapse.util.caches.lrucache]
disallow_untyped_defs = True
[mypy-synapse.util.caches.response_cache]
disallow_untyped_defs = True
[mypy-synapse.util.caches.stream_change_cache]
disallow_untyped_defs = True
[mypy-synapse.util.caches.ttl_cache]
disallow_untyped_defs = True
[mypy-synapse.util.daemonize]
disallow_untyped_defs = True
[mypy-synapse.util.file_consumer]
disallow_untyped_defs = True
[mypy-synapse.util.frozenutils]
disallow_untyped_defs = True
[mypy-synapse.util.hash]
disallow_untyped_defs = True
[mypy-synapse.util.httpresourcetree]
disallow_untyped_defs = True
[mypy-synapse.util.iterutils]
disallow_untyped_defs = True
[mypy-synapse.util.linked_list]
disallow_untyped_defs = True
[mypy-synapse.util.logcontext]
disallow_untyped_defs = True
[mypy-synapse.util.logformatter]
disallow_untyped_defs = True
[mypy-synapse.util.macaroons]
disallow_untyped_defs = True
[mypy-synapse.util.manhole]
disallow_untyped_defs = True
[mypy-synapse.util.module_loader]
disallow_untyped_defs = True
[mypy-synapse.util.msisdn]
disallow_untyped_defs = True
[mypy-synapse.util.patch_inline_callbacks]
disallow_untyped_defs = True
[mypy-synapse.util.ratelimitutils]
disallow_untyped_defs = True
[mypy-synapse.util.retryutils]
disallow_untyped_defs = True
[mypy-synapse.util.rlimit]
disallow_untyped_defs = True
[mypy-synapse.util.stringutils]
disallow_untyped_defs = True
[mypy-synapse.util.templates]
disallow_untyped_defs = True
[mypy-synapse.util.threepids]
disallow_untyped_defs = True
[mypy-synapse.util.wheel_timer]
disallow_untyped_defs = True
[mypy-synapse.util.versionstring]
disallow_untyped_defs = True
[mypy-synapse.util.caches.treecache]
disallow_untyped_defs = False
[mypy-tests.handlers.test_user_directory]
disallow_untyped_defs = True

View file

@ -27,6 +27,7 @@ from typing import (
Generic,
Hashable,
Iterable,
Iterator,
Optional,
Set,
TypeVar,
@ -40,7 +41,6 @@ from typing_extensions import ContextManager
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
from twisted.python.failure import Failure
from synapse.logging.context import (
@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]):
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", [])
def callback(r):
def callback(r: _T) -> _T:
object.__setattr__(self, "_result", (True, r))
# once we have set _result, no more entries will be added to _observers,
@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]):
)
return r
def errback(f):
def errback(f: Failure) -> Optional[Failure]:
object.__setattr__(self, "_result", (False, f))
# once we have set _result, no more entries will be added to _observers,
@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]):
for observer in observers:
# This is a little bit of magic to correctly propagate stack
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
f.value.__failure__ = f # type: ignore[union-attr]
try:
observer.errback(f)
except Exception as e:
@ -314,7 +314,7 @@ class Linearizer:
# will release the lock.
@contextmanager
def _ctx_manager(_):
def _ctx_manager(_: None) -> Iterator[None]:
try:
yield
finally:
@ -355,7 +355,7 @@ class Linearizer:
new_defer = make_deferred_yieldable(defer.Deferred())
entry.deferreds[new_defer] = 1
def cb(_r):
def cb(_r: None) -> "defer.Deferred[None]":
logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
entry.count += 1
@ -371,7 +371,7 @@ class Linearizer:
# code must be synchronous, so this is the only sensible place.)
return self._clock.sleep(0)
def eb(e):
def eb(e: Failure) -> Failure:
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
logger.debug(
@ -435,7 +435,7 @@ class ReadWriteLock:
await make_deferred_yieldable(curr_writer)
@contextmanager
def _ctx_manager():
def _ctx_manager() -> Iterator[None]:
try:
yield
finally:
@ -464,7 +464,7 @@ class ReadWriteLock:
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager():
def _ctx_manager() -> Iterator[None]:
try:
yield
finally:
@ -524,7 +524,7 @@ def timeout_deferred(
delayed_call = reactor.callLater(timeout, time_it_out)
def convert_cancelled(value: failure.Failure):
def convert_cancelled(value: Failure) -> Failure:
# if the original deferred was cancelled, and our timeout has fired, then
# the reason it was cancelled was due to our timeout. Turn the CancelledError
# into a TimeoutError.
@ -534,7 +534,7 @@ def timeout_deferred(
deferred.addErrback(convert_cancelled)
def cancel_timeout(result):
def cancel_timeout(result: _T) -> _T:
# stop the pending call to cancel the deferred if it's been fired
if delayed_call.active():
delayed_call.cancel()
@ -542,11 +542,11 @@ def timeout_deferred(
deferred.addBoth(cancel_timeout)
def success_cb(val):
def success_cb(val: _T) -> None:
if not new_d.called:
new_d.callback(val)
def failure_cb(val):
def failure_cb(val: Failure) -> None:
if not new_d.called:
new_d.errback(val)
@ -557,13 +557,13 @@ def timeout_deferred(
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DoneAwaitable: # should be: Generic[R]
"""Simple awaitable that returns the provided value."""
value = attr.ib(type=Any) # should be: R
value: Any # should be: R
def __await__(self):
def __await__(self) -> Any:
return self
def __iter__(self) -> "DoneAwaitable":

View file

@ -17,7 +17,7 @@ import logging
import typing
from enum import Enum, auto
from sys import intern
from typing import Callable, Dict, Optional, Sized
from typing import Any, Callable, Dict, List, Optional, Sized
import attr
from prometheus_client.core import Gauge
@ -58,20 +58,20 @@ class EvictionReason(Enum):
time = auto()
@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class CacheMetric:
_cache = attr.ib()
_cache_type = attr.ib(type=str)
_cache_name = attr.ib(type=str)
_collect_callback = attr.ib(type=Optional[Callable])
_cache: Sized
_cache_type: str
_cache_name: str
_collect_callback: Optional[Callable]
hits = attr.ib(default=0)
misses = attr.ib(default=0)
hits: int = 0
misses: int = 0
eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
factory=collections.Counter
)
memory_usage = attr.ib(default=None)
memory_usage: Optional[int] = None
def inc_hits(self) -> None:
self.hits += 1
@ -89,13 +89,14 @@ class CacheMetric:
self.memory_usage += memory
def dec_memory_usage(self, memory: int) -> None:
assert self.memory_usage is not None
self.memory_usage -= memory
def clear_memory_usage(self) -> None:
if self.memory_usage is not None:
self.memory_usage = 0
def describe(self):
def describe(self) -> List[str]:
return []
def collect(self) -> None:
@ -118,8 +119,9 @@ class CacheMetric:
self.eviction_size_by_reason[reason]
)
cache_total.labels(self._cache_name).set(self.hits + self.misses)
if getattr(self._cache, "max_size", None):
cache_max_size.labels(self._cache_name).set(self._cache.max_size)
max_size = getattr(self._cache, "max_size", None)
if max_size:
cache_max_size.labels(self._cache_name).set(max_size)
if TRACK_MEMORY_USAGE:
# self.memory_usage can be None if nothing has been inserted
@ -193,7 +195,7 @@ KNOWN_KEYS = {
}
def intern_string(string):
def intern_string(string: Optional[str]) -> Optional[str]:
"""Takes a (potentially) unicode string and interns it if it's ascii"""
if string is None:
return None
@ -204,7 +206,7 @@ def intern_string(string):
return string
def intern_dict(dictionary):
def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
"""Takes a dictionary and interns well known keys and their values"""
return {
KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
@ -212,7 +214,7 @@ def intern_dict(dictionary):
}
def _intern_known_values(key, value):
def _intern_known_values(key: str, value: Any) -> Any:
intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
if key in intern_keys:

View file

@ -289,7 +289,7 @@ class DeferredCache(Generic[KT, VT]):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key) -> None:
def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
If the cache is backed by a regular dict, then "key" must be of

View file

@ -19,12 +19,15 @@ import logging
from typing import (
Any,
Callable,
Dict,
Generic,
Hashable,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
@ -32,6 +35,7 @@ from typing import (
from weakref import WeakValueDictionary
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase:
def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
def __init__(
self,
orig: Callable[..., Any],
num_args: Optional[int],
cache_context: bool = False,
):
self.orig = orig
arg_spec = inspect.getfullargspec(orig)
@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __init__(
self,
orig,
orig: Callable[..., Any],
max_entries: int = 1000,
cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
self.max_entries = max_entries
def __get__(self, obj, owner):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache(
cache_name=self.orig.__name__,
max_size=self.max_entries,
@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
sentinel = LruCacheDescriptor._Sentinel.sentinel
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
def _wrapped(*args: Any, **kwargs: Any) -> Any:
invalidate_callback = kwargs.pop("on_invalidate", None)
callbacks = (invalidate_callback,) if invalidate_callback else ()
@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return r1 + r2
Args:
num_args (int): number of positional arguments (excluding ``self`` and
num_args: number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
"""
def __init__(
self,
orig,
max_entries=1000,
num_args=None,
tree=False,
cache_context=False,
iterable=False,
orig: Callable[..., Any],
max_entries: int = 1000,
num_args: Optional[int] = None,
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
self.prune_unread_entries = prune_unread_entries
def __get__(self, obj, owner):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
def _wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
of results.
"""
def __init__(self, orig, cached_method_name, list_name, num_args=None):
def __init__(
self,
orig: Callable[..., Any],
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
):
"""
Args:
orig (function)
cached_method_name (str): The name of the cached method.
list_name (str): Name of the argument which is the bulk lookup list
num_args (int): number of positional arguments (excluding ``self``,
orig
cached_method_name: The name of the cached method.
list_name: Name of the argument which is the bulk lookup list
num_args: number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
% (self.list_name, cached_method_name)
)
def __get__(self, obj, objtype=None):
def __get__(
self, obj: Optional[Any], objtype: Optional[Type] = None
) -> Callable[..., Any]:
cached_method = getattr(obj, self.cached_method_name)
cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its
# invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
results = {}
def update_results_dict(res, arg):
def update_results_dict(res: Any, arg: Hashable) -> None:
results[arg] = res
# list of deferreds to wait for
@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
# otherwise a tuple is used.
if num_args == 1:
def arg_to_cache_key(arg):
def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
else:
keylist = list(keyargs)
def arg_to_cache_key(arg):
def arg_to_cache_key(arg: Hashable) -> Hashable:
keylist[self.list_pos] = arg
return tuple(keylist)
@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
key = arg_to_cache_key(arg)
cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
def complete_all(res: Dict[Hashable, Any]) -> None:
# the wrapped function has completed. It returns a
# a dict. We can now resolve the observable deferreds in
# the cache and update our own result map.
@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
deferreds_map[e].callback(val)
results[e] = val
def errback(f):
def errback(f: Failure) -> Failure:
# the wrapped function has failed. Invalidate any cache
# entries we're supposed to be populating, and fail
# their deferreds.

View file

@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload
import attr
from typing_extensions import Literal
from twisted.internet import defer
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]):
# Don't bother starting the loop if things never expire
return
def f():
def f() -> "defer.Deferred[None]":
return run_as_background_process(
"prune_cache_%s" % self._cache_name, self._prune_cache
)
@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]):
return False
@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class _CacheEntry:
time = attr.ib(type=int)
value = attr.ib()
time: int
value: Any

View file

@ -18,12 +18,13 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
def user_left_room(distributor, user, room_id):
def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None:
distributor.fire("user_left_room", user=user, room_id=room_id)
@ -63,7 +64,7 @@ class Distributor:
self.pre_registration[name] = []
self.pre_registration[name].append(observer)
def fire(self, name: str, *args, **kwargs) -> None:
def fire(self, name: str, *args: Any, **kwargs: Any) -> None:
"""Dispatches the given signal to the registered observers.
Runs the observers as a background process. Does not return a deferred.
@ -95,7 +96,7 @@ class Signal:
Each observer callable may return a Deferred."""
self.observers.append(observer)
def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
not an error to fire a signal with no observers.
@ -103,7 +104,7 @@ class Signal:
Returns a Deferred that will complete when all the observers have
completed."""
async def do(observer):
async def do(observer: Callable[..., Any]) -> Any:
try:
return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
@ -120,5 +121,5 @@ class Signal:
defer.gatherResults(deferreds, consumeErrors=True)
)
def __repr__(self):
def __repr__(self) -> str:
return "<Signal name=%r>" % (self.name,)

View file

@ -3,23 +3,52 @@
# We copy it here as we need to instantiate `GAIResolver` manually, but it is a
# private class.
from socket import (
AF_INET,
AF_INET6,
AF_UNSPEC,
SOCK_DGRAM,
SOCK_STREAM,
AddressFamily,
SocketKind,
gaierror,
getaddrinfo,
)
from typing import (
TYPE_CHECKING,
Callable,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from zope.interface import implementer
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import IHostnameResolver, IHostResolution
from twisted.internet.interfaces import (
IAddress,
IHostnameResolver,
IHostResolution,
IReactorThreads,
IResolutionReceiver,
)
from twisted.internet.threads import deferToThreadPool
if TYPE_CHECKING:
# The types below are copied from
# https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
# so that the type hints can match the interfaces.
from twisted.python.runtime import platform
if platform.supportsThreads():
from twisted.python.threadpool import ThreadPool
else:
ThreadPool = object # type: ignore[misc, assignment]
@implementer(IHostResolution)
class HostResolution:
@ -27,13 +56,13 @@ class HostResolution:
The in-progress resolution of a given hostname.
"""
def __init__(self, name):
def __init__(self, name: str):
"""
Create a L{HostResolution} with the given name.
"""
self.name = name
def cancel(self):
def cancel(self) -> NoReturn:
# IHostResolution.cancel
raise NotImplementedError()
@ -62,6 +91,17 @@ _socktypeToType = {
}
_GETADDRINFO_RESULT = List[
Tuple[
AddressFamily,
SocketKind,
int,
str,
Union[Tuple[str, int], Tuple[str, int, int, int]],
]
]
@implementer(IHostnameResolver)
class GAIResolver:
"""
@ -69,7 +109,12 @@ class GAIResolver:
L{getaddrinfo} in a thread.
"""
def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
def __init__(
self,
reactor: IReactorThreads,
getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
):
"""
Create a L{GAIResolver}.
@param reactor: the reactor to schedule result-delivery on
@ -89,14 +134,16 @@ class GAIResolver:
)
self._getaddrinfo = getaddrinfo
def resolveHostName(
# The types on IHostnameResolver is incorrect in Twisted, see
# https://twistedmatrix.com/trac/ticket/10276
def resolveHostName( # type: ignore[override]
self,
resolutionReceiver,
hostName,
portNumber=0,
addressTypes=None,
transportSemantics="TCP",
):
resolutionReceiver: IResolutionReceiver,
hostName: str,
portNumber: int = 0,
addressTypes: Optional[Sequence[Type[IAddress]]] = None,
transportSemantics: str = "TCP",
) -> IHostResolution:
"""
See L{IHostnameResolver.resolveHostName}
@param resolutionReceiver: see interface
@ -112,7 +159,7 @@ class GAIResolver:
]
socketType = _transportToSocket[transportSemantics]
def get():
def get() -> _GETADDRINFO_RESULT:
try:
return self._getaddrinfo(
hostName, portNumber, addressFamily, socketType
@ -125,7 +172,7 @@ class GAIResolver:
resolutionReceiver.resolutionBegan(resolution)
@d.addCallback
def deliverResults(result):
def deliverResults(result: _GETADDRINFO_RESULT) -> None:
for family, socktype, _proto, _cannoname, sockaddr in result:
addrType = _afToType[family]
resolutionReceiver.addressResolved(

View file

@ -64,6 +64,13 @@ in_flight = InFlightGauge(
sub_metrics=["real_time_max", "real_time_sum"],
)
# This is dynamically created in InFlightGauge.__init__.
class _InFlightMetric(Protocol):
real_time_max: float
real_time_sum: float
T = TypeVar("T", bound=Callable[..., Any])
@ -180,7 +187,7 @@ class Measure:
"""
return self._logging_context.get_resource_usage()
def _update_in_flight(self, metrics) -> None:
def _update_in_flight(self, metrics: _InFlightMetric) -> None:
"""Gets called when processing in flight metrics"""
assert self.start is not None
duration = self.clock.time() - self.start