Add most missing type hints to synapse.util (#11328)

This commit is contained in:
Patrick Cloke 2021-11-16 08:47:36 -05:00 committed by GitHub
parent 3a1462f7e0
commit 7468723697
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 161 additions and 165 deletions

1
changelog.d/11328.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `synapse.util`.

View file

@ -196,92 +196,11 @@ disallow_untyped_defs = True
[mypy-synapse.streams.*] [mypy-synapse.streams.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.util.batching_queue] [mypy-synapse.util.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.util.caches.cached_call] [mypy-synapse.util.caches.treecache]
disallow_untyped_defs = True disallow_untyped_defs = False
[mypy-synapse.util.caches.dictionary_cache]
disallow_untyped_defs = True
[mypy-synapse.util.caches.lrucache]
disallow_untyped_defs = True
[mypy-synapse.util.caches.response_cache]
disallow_untyped_defs = True
[mypy-synapse.util.caches.stream_change_cache]
disallow_untyped_defs = True
[mypy-synapse.util.caches.ttl_cache]
disallow_untyped_defs = True
[mypy-synapse.util.daemonize]
disallow_untyped_defs = True
[mypy-synapse.util.file_consumer]
disallow_untyped_defs = True
[mypy-synapse.util.frozenutils]
disallow_untyped_defs = True
[mypy-synapse.util.hash]
disallow_untyped_defs = True
[mypy-synapse.util.httpresourcetree]
disallow_untyped_defs = True
[mypy-synapse.util.iterutils]
disallow_untyped_defs = True
[mypy-synapse.util.linked_list]
disallow_untyped_defs = True
[mypy-synapse.util.logcontext]
disallow_untyped_defs = True
[mypy-synapse.util.logformatter]
disallow_untyped_defs = True
[mypy-synapse.util.macaroons]
disallow_untyped_defs = True
[mypy-synapse.util.manhole]
disallow_untyped_defs = True
[mypy-synapse.util.module_loader]
disallow_untyped_defs = True
[mypy-synapse.util.msisdn]
disallow_untyped_defs = True
[mypy-synapse.util.patch_inline_callbacks]
disallow_untyped_defs = True
[mypy-synapse.util.ratelimitutils]
disallow_untyped_defs = True
[mypy-synapse.util.retryutils]
disallow_untyped_defs = True
[mypy-synapse.util.rlimit]
disallow_untyped_defs = True
[mypy-synapse.util.stringutils]
disallow_untyped_defs = True
[mypy-synapse.util.templates]
disallow_untyped_defs = True
[mypy-synapse.util.threepids]
disallow_untyped_defs = True
[mypy-synapse.util.wheel_timer]
disallow_untyped_defs = True
[mypy-synapse.util.versionstring]
disallow_untyped_defs = True
[mypy-tests.handlers.test_user_directory] [mypy-tests.handlers.test_user_directory]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@ -27,6 +27,7 @@ from typing import (
Generic, Generic,
Hashable, Hashable,
Iterable, Iterable,
Iterator,
Optional, Optional,
Set, Set,
TypeVar, TypeVar,
@ -40,7 +41,6 @@ from typing_extensions import ContextManager
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
from twisted.python.failure import Failure from twisted.python.failure import Failure
from synapse.logging.context import ( from synapse.logging.context import (
@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]):
object.__setattr__(self, "_result", None) object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", []) object.__setattr__(self, "_observers", [])
def callback(r): def callback(r: _T) -> _T:
object.__setattr__(self, "_result", (True, r)) object.__setattr__(self, "_result", (True, r))
# once we have set _result, no more entries will be added to _observers, # once we have set _result, no more entries will be added to _observers,
@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]):
) )
return r return r
def errback(f): def errback(f: Failure) -> Optional[Failure]:
object.__setattr__(self, "_result", (False, f)) object.__setattr__(self, "_result", (False, f))
# once we have set _result, no more entries will be added to _observers, # once we have set _result, no more entries will be added to _observers,
@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]):
for observer in observers: for observer in observers:
# This is a little bit of magic to correctly propagate stack # This is a little bit of magic to correctly propagate stack
# traces when we `await` on one of the observer deferreds. # traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f f.value.__failure__ = f # type: ignore[union-attr]
try: try:
observer.errback(f) observer.errback(f)
except Exception as e: except Exception as e:
@ -314,7 +314,7 @@ class Linearizer:
# will release the lock. # will release the lock.
@contextmanager @contextmanager
def _ctx_manager(_): def _ctx_manager(_: None) -> Iterator[None]:
try: try:
yield yield
finally: finally:
@ -355,7 +355,7 @@ class Linearizer:
new_defer = make_deferred_yieldable(defer.Deferred()) new_defer = make_deferred_yieldable(defer.Deferred())
entry.deferreds[new_defer] = 1 entry.deferreds[new_defer] = 1
def cb(_r): def cb(_r: None) -> "defer.Deferred[None]":
logger.debug("Acquired linearizer lock %r for key %r", self.name, key) logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
entry.count += 1 entry.count += 1
@ -371,7 +371,7 @@ class Linearizer:
# code must be synchronous, so this is the only sensible place.) # code must be synchronous, so this is the only sensible place.)
return self._clock.sleep(0) return self._clock.sleep(0)
def eb(e): def eb(e: Failure) -> Failure:
logger.info("defer %r got err %r", new_defer, e) logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError): if isinstance(e, CancelledError):
logger.debug( logger.debug(
@ -435,7 +435,7 @@ class ReadWriteLock:
await make_deferred_yieldable(curr_writer) await make_deferred_yieldable(curr_writer)
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager() -> Iterator[None]:
try: try:
yield yield
finally: finally:
@ -464,7 +464,7 @@ class ReadWriteLock:
await make_deferred_yieldable(defer.gatherResults(to_wait_on)) await make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager() -> Iterator[None]:
try: try:
yield yield
finally: finally:
@ -524,7 +524,7 @@ def timeout_deferred(
delayed_call = reactor.callLater(timeout, time_it_out) delayed_call = reactor.callLater(timeout, time_it_out)
def convert_cancelled(value: failure.Failure): def convert_cancelled(value: Failure) -> Failure:
# if the original deferred was cancelled, and our timeout has fired, then # if the original deferred was cancelled, and our timeout has fired, then
# the reason it was cancelled was due to our timeout. Turn the CancelledError # the reason it was cancelled was due to our timeout. Turn the CancelledError
# into a TimeoutError. # into a TimeoutError.
@ -534,7 +534,7 @@ def timeout_deferred(
deferred.addErrback(convert_cancelled) deferred.addErrback(convert_cancelled)
def cancel_timeout(result): def cancel_timeout(result: _T) -> _T:
# stop the pending call to cancel the deferred if it's been fired # stop the pending call to cancel the deferred if it's been fired
if delayed_call.active(): if delayed_call.active():
delayed_call.cancel() delayed_call.cancel()
@ -542,11 +542,11 @@ def timeout_deferred(
deferred.addBoth(cancel_timeout) deferred.addBoth(cancel_timeout)
def success_cb(val): def success_cb(val: _T) -> None:
if not new_d.called: if not new_d.called:
new_d.callback(val) new_d.callback(val)
def failure_cb(val): def failure_cb(val: Failure) -> None:
if not new_d.called: if not new_d.called:
new_d.errback(val) new_d.errback(val)
@ -557,13 +557,13 @@ def timeout_deferred(
# This class can't be generic because it uses slots with attrs. # This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313 # See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class DoneAwaitable: # should be: Generic[R] class DoneAwaitable: # should be: Generic[R]
"""Simple awaitable that returns the provided value.""" """Simple awaitable that returns the provided value."""
value = attr.ib(type=Any) # should be: R value: Any # should be: R
def __await__(self): def __await__(self) -> Any:
return self return self
def __iter__(self) -> "DoneAwaitable": def __iter__(self) -> "DoneAwaitable":

View file

@ -17,7 +17,7 @@ import logging
import typing import typing
from enum import Enum, auto from enum import Enum, auto
from sys import intern from sys import intern
from typing import Callable, Dict, Optional, Sized from typing import Any, Callable, Dict, List, Optional, Sized
import attr import attr
from prometheus_client.core import Gauge from prometheus_client.core import Gauge
@ -58,20 +58,20 @@ class EvictionReason(Enum):
time = auto() time = auto()
@attr.s(slots=True) @attr.s(slots=True, auto_attribs=True)
class CacheMetric: class CacheMetric:
_cache = attr.ib() _cache: Sized
_cache_type = attr.ib(type=str) _cache_type: str
_cache_name = attr.ib(type=str) _cache_name: str
_collect_callback = attr.ib(type=Optional[Callable]) _collect_callback: Optional[Callable]
hits = attr.ib(default=0) hits: int = 0
misses = attr.ib(default=0) misses: int = 0
eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib( eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
factory=collections.Counter factory=collections.Counter
) )
memory_usage = attr.ib(default=None) memory_usage: Optional[int] = None
def inc_hits(self) -> None: def inc_hits(self) -> None:
self.hits += 1 self.hits += 1
@ -89,13 +89,14 @@ class CacheMetric:
self.memory_usage += memory self.memory_usage += memory
def dec_memory_usage(self, memory: int) -> None: def dec_memory_usage(self, memory: int) -> None:
assert self.memory_usage is not None
self.memory_usage -= memory self.memory_usage -= memory
def clear_memory_usage(self) -> None: def clear_memory_usage(self) -> None:
if self.memory_usage is not None: if self.memory_usage is not None:
self.memory_usage = 0 self.memory_usage = 0
def describe(self): def describe(self) -> List[str]:
return [] return []
def collect(self) -> None: def collect(self) -> None:
@ -118,8 +119,9 @@ class CacheMetric:
self.eviction_size_by_reason[reason] self.eviction_size_by_reason[reason]
) )
cache_total.labels(self._cache_name).set(self.hits + self.misses) cache_total.labels(self._cache_name).set(self.hits + self.misses)
if getattr(self._cache, "max_size", None): max_size = getattr(self._cache, "max_size", None)
cache_max_size.labels(self._cache_name).set(self._cache.max_size) if max_size:
cache_max_size.labels(self._cache_name).set(max_size)
if TRACK_MEMORY_USAGE: if TRACK_MEMORY_USAGE:
# self.memory_usage can be None if nothing has been inserted # self.memory_usage can be None if nothing has been inserted
@ -193,7 +195,7 @@ KNOWN_KEYS = {
} }
def intern_string(string): def intern_string(string: Optional[str]) -> Optional[str]:
"""Takes a (potentially) unicode string and interns it if it's ascii""" """Takes a (potentially) unicode string and interns it if it's ascii"""
if string is None: if string is None:
return None return None
@ -204,7 +206,7 @@ def intern_string(string):
return string return string
def intern_dict(dictionary): def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
"""Takes a dictionary and interns well known keys and their values""" """Takes a dictionary and interns well known keys and their values"""
return { return {
KNOWN_KEYS.get(key, key): _intern_known_values(key, value) KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
@ -212,7 +214,7 @@ def intern_dict(dictionary):
} }
def _intern_known_values(key, value): def _intern_known_values(key: str, value: Any) -> Any:
intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key") intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
if key in intern_keys: if key in intern_keys:

View file

@ -289,7 +289,7 @@ class DeferredCache(Generic[KT, VT]):
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks) self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key) -> None: def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries """Delete a key, or tree of entries
If the cache is backed by a regular dict, then "key" must be of If the cache is backed by a regular dict, then "key" must be of

View file

@ -19,12 +19,15 @@ import logging
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict,
Generic, Generic,
Hashable,
Iterable, Iterable,
Mapping, Mapping,
Optional, Optional,
Sequence, Sequence,
Tuple, Tuple,
Type,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -32,6 +35,7 @@ from typing import (
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from twisted.internet import defer from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase: class _CacheDescriptorBase:
def __init__(self, orig: Callable[..., Any], num_args, cache_context=False): def __init__(
self,
orig: Callable[..., Any],
num_args: Optional[int],
cache_context: bool = False,
):
self.orig = orig self.orig = orig
arg_spec = inspect.getfullargspec(orig) arg_spec = inspect.getfullargspec(orig)
@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __init__( def __init__(
self, self,
orig, orig: Callable[..., Any],
max_entries: int = 1000, max_entries: int = 1000,
cache_context: bool = False, cache_context: bool = False,
): ):
super().__init__(orig, num_args=None, cache_context=cache_context) super().__init__(orig, num_args=None, cache_context=cache_context)
self.max_entries = max_entries self.max_entries = max_entries
def __get__(self, obj, owner): def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache( cache: LruCache[CacheKey, Any] = LruCache(
cache_name=self.orig.__name__, cache_name=self.orig.__name__,
max_size=self.max_entries, max_size=self.max_entries,
@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
sentinel = LruCacheDescriptor._Sentinel.sentinel sentinel = LruCacheDescriptor._Sentinel.sentinel
@functools.wraps(self.orig) @functools.wraps(self.orig)
def _wrapped(*args, **kwargs): def _wrapped(*args: Any, **kwargs: Any) -> Any:
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
callbacks = (invalidate_callback,) if invalidate_callback else () callbacks = (invalidate_callback,) if invalidate_callback else ()
@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return r1 + r2 return r1 + r2
Args: Args:
num_args (int): number of positional arguments (excluding ``self`` and num_args: number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named ``cache_context``) to use as cache keys. Defaults to all named
args of the function. args of the function.
""" """
def __init__( def __init__(
self, self,
orig, orig: Callable[..., Any],
max_entries=1000, max_entries: int = 1000,
num_args=None, num_args: Optional[int] = None,
tree=False, tree: bool = False,
cache_context=False, cache_context: bool = False,
iterable=False, iterable: bool = False,
prune_unread_entries: bool = True, prune_unread_entries: bool = True,
): ):
super().__init__(orig, num_args=num_args, cache_context=cache_context) super().__init__(orig, num_args=num_args, cache_context=cache_context)
@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable self.iterable = iterable
self.prune_unread_entries = prune_unread_entries self.prune_unread_entries = prune_unread_entries
def __get__(self, obj, owner): def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache( cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.orig.__name__, name=self.orig.__name__,
max_entries=self.max_entries, max_entries=self.max_entries,
@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
get_cache_key = self.cache_key_builder get_cache_key = self.cache_key_builder
@functools.wraps(self.orig) @functools.wraps(self.orig)
def _wrapped(*args, **kwargs): def _wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated # whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
of results. of results.
""" """
def __init__(self, orig, cached_method_name, list_name, num_args=None): def __init__(
self,
orig: Callable[..., Any],
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
):
""" """
Args: Args:
orig (function) orig
cached_method_name (str): The name of the cached method. cached_method_name: The name of the cached method.
list_name (str): Name of the argument which is the bulk lookup list list_name: Name of the argument which is the bulk lookup list
num_args (int): number of positional arguments (excluding ``self``, num_args: number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all but including list_name) to use as cache keys. Defaults to all
named args of the function. named args of the function.
""" """
@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
% (self.list_name, cached_method_name) % (self.list_name, cached_method_name)
) )
def __get__(self, obj, objtype=None): def __get__(
self, obj: Optional[Any], objtype: Optional[Type] = None
) -> Callable[..., Any]:
cached_method = getattr(obj, self.cached_method_name) cached_method = getattr(obj, self.cached_method_name)
cache: DeferredCache[CacheKey, Any] = cached_method.cache cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args num_args = cached_method.num_args
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its # If we're passed a cache_context then we'll want to call its
# invalidate() whenever we are invalidated # invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
results = {} results = {}
def update_results_dict(res, arg): def update_results_dict(res: Any, arg: Hashable) -> None:
results[arg] = res results[arg] = res
# list of deferreds to wait for # list of deferreds to wait for
@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
# otherwise a tuple is used. # otherwise a tuple is used.
if num_args == 1: if num_args == 1:
def arg_to_cache_key(arg): def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg return arg
else: else:
keylist = list(keyargs) keylist = list(keyargs)
def arg_to_cache_key(arg): def arg_to_cache_key(arg: Hashable) -> Hashable:
keylist[self.list_pos] = arg keylist[self.list_pos] = arg
return tuple(keylist) return tuple(keylist)
@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
key = arg_to_cache_key(arg) key = arg_to_cache_key(arg)
cache.set(key, deferred, callback=invalidate_callback) cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res): def complete_all(res: Dict[Hashable, Any]) -> None:
# the wrapped function has completed. It returns a # the wrapped function has completed. It returns a
# a dict. We can now resolve the observable deferreds in # a dict. We can now resolve the observable deferreds in
# the cache and update our own result map. # the cache and update our own result map.
@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
deferreds_map[e].callback(val) deferreds_map[e].callback(val)
results[e] = val results[e] = val
def errback(f): def errback(f: Failure) -> Failure:
# the wrapped function has failed. Invalidate any cache # the wrapped function has failed. Invalidate any cache
# entries we're supposed to be populating, and fail # entries we're supposed to be populating, and fail
# their deferreds. # their deferreds.

View file

@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload
import attr import attr
from typing_extensions import Literal from typing_extensions import Literal
from twisted.internet import defer
from synapse.config import cache as cache_config from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock from synapse.util import Clock
@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]):
# Don't bother starting the loop if things never expire # Don't bother starting the loop if things never expire
return return
def f(): def f() -> "defer.Deferred[None]":
return run_as_background_process( return run_as_background_process(
"prune_cache_%s" % self._cache_name, self._prune_cache "prune_cache_%s" % self._cache_name, self._prune_cache
) )
@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]):
return False return False
@attr.s(slots=True) @attr.s(slots=True, auto_attribs=True)
class _CacheEntry: class _CacheEntry:
time = attr.ib(type=int) time: int
value = attr.ib() value: Any

View file

@ -18,12 +18,13 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def user_left_room(distributor, user, room_id): def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None:
distributor.fire("user_left_room", user=user, room_id=room_id) distributor.fire("user_left_room", user=user, room_id=room_id)
@ -63,7 +64,7 @@ class Distributor:
self.pre_registration[name] = [] self.pre_registration[name] = []
self.pre_registration[name].append(observer) self.pre_registration[name].append(observer)
def fire(self, name: str, *args, **kwargs) -> None: def fire(self, name: str, *args: Any, **kwargs: Any) -> None:
"""Dispatches the given signal to the registered observers. """Dispatches the given signal to the registered observers.
Runs the observers as a background process. Does not return a deferred. Runs the observers as a background process. Does not return a deferred.
@ -95,7 +96,7 @@ class Signal:
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]": def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
not an error to fire a signal with no observers. not an error to fire a signal with no observers.
@ -103,7 +104,7 @@ class Signal:
Returns a Deferred that will complete when all the observers have Returns a Deferred that will complete when all the observers have
completed.""" completed."""
async def do(observer): async def do(observer: Callable[..., Any]) -> Any:
try: try:
return await maybe_awaitable(observer(*args, **kwargs)) return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e: except Exception as e:
@ -120,5 +121,5 @@ class Signal:
defer.gatherResults(deferreds, consumeErrors=True) defer.gatherResults(deferreds, consumeErrors=True)
) )
def __repr__(self): def __repr__(self) -> str:
return "<Signal name=%r>" % (self.name,) return "<Signal name=%r>" % (self.name,)

View file

@ -3,23 +3,52 @@
# We copy it here as we need to instantiate `GAIResolver` manually, but it is a # We copy it here as we need to instantiate `GAIResolver` manually, but it is a
# private class. # private class.
from socket import ( from socket import (
AF_INET, AF_INET,
AF_INET6, AF_INET6,
AF_UNSPEC, AF_UNSPEC,
SOCK_DGRAM, SOCK_DGRAM,
SOCK_STREAM, SOCK_STREAM,
AddressFamily,
SocketKind,
gaierror, gaierror,
getaddrinfo, getaddrinfo,
) )
from typing import (
TYPE_CHECKING,
Callable,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from zope.interface import implementer from zope.interface import implementer
from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import IHostnameResolver, IHostResolution from twisted.internet.interfaces import (
IAddress,
IHostnameResolver,
IHostResolution,
IReactorThreads,
IResolutionReceiver,
)
from twisted.internet.threads import deferToThreadPool from twisted.internet.threads import deferToThreadPool
if TYPE_CHECKING:
# The types below are copied from
# https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
# so that the type hints can match the interfaces.
from twisted.python.runtime import platform
if platform.supportsThreads():
from twisted.python.threadpool import ThreadPool
else:
ThreadPool = object # type: ignore[misc, assignment]
@implementer(IHostResolution) @implementer(IHostResolution)
class HostResolution: class HostResolution:
@ -27,13 +56,13 @@ class HostResolution:
The in-progress resolution of a given hostname. The in-progress resolution of a given hostname.
""" """
def __init__(self, name): def __init__(self, name: str):
""" """
Create a L{HostResolution} with the given name. Create a L{HostResolution} with the given name.
""" """
self.name = name self.name = name
def cancel(self): def cancel(self) -> NoReturn:
# IHostResolution.cancel # IHostResolution.cancel
raise NotImplementedError() raise NotImplementedError()
@ -62,6 +91,17 @@ _socktypeToType = {
} }
_GETADDRINFO_RESULT = List[
Tuple[
AddressFamily,
SocketKind,
int,
str,
Union[Tuple[str, int], Tuple[str, int, int, int]],
]
]
@implementer(IHostnameResolver) @implementer(IHostnameResolver)
class GAIResolver: class GAIResolver:
""" """
@ -69,7 +109,12 @@ class GAIResolver:
L{getaddrinfo} in a thread. L{getaddrinfo} in a thread.
""" """
def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo): def __init__(
self,
reactor: IReactorThreads,
getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
):
""" """
Create a L{GAIResolver}. Create a L{GAIResolver}.
@param reactor: the reactor to schedule result-delivery on @param reactor: the reactor to schedule result-delivery on
@ -89,14 +134,16 @@ class GAIResolver:
) )
self._getaddrinfo = getaddrinfo self._getaddrinfo = getaddrinfo
def resolveHostName( # The types on IHostnameResolver is incorrect in Twisted, see
# https://twistedmatrix.com/trac/ticket/10276
def resolveHostName( # type: ignore[override]
self, self,
resolutionReceiver, resolutionReceiver: IResolutionReceiver,
hostName, hostName: str,
portNumber=0, portNumber: int = 0,
addressTypes=None, addressTypes: Optional[Sequence[Type[IAddress]]] = None,
transportSemantics="TCP", transportSemantics: str = "TCP",
): ) -> IHostResolution:
""" """
See L{IHostnameResolver.resolveHostName} See L{IHostnameResolver.resolveHostName}
@param resolutionReceiver: see interface @param resolutionReceiver: see interface
@ -112,7 +159,7 @@ class GAIResolver:
] ]
socketType = _transportToSocket[transportSemantics] socketType = _transportToSocket[transportSemantics]
def get(): def get() -> _GETADDRINFO_RESULT:
try: try:
return self._getaddrinfo( return self._getaddrinfo(
hostName, portNumber, addressFamily, socketType hostName, portNumber, addressFamily, socketType
@ -125,7 +172,7 @@ class GAIResolver:
resolutionReceiver.resolutionBegan(resolution) resolutionReceiver.resolutionBegan(resolution)
@d.addCallback @d.addCallback
def deliverResults(result): def deliverResults(result: _GETADDRINFO_RESULT) -> None:
for family, socktype, _proto, _cannoname, sockaddr in result: for family, socktype, _proto, _cannoname, sockaddr in result:
addrType = _afToType[family] addrType = _afToType[family]
resolutionReceiver.addressResolved( resolutionReceiver.addressResolved(

View file

@ -64,6 +64,13 @@ in_flight = InFlightGauge(
sub_metrics=["real_time_max", "real_time_sum"], sub_metrics=["real_time_max", "real_time_sum"],
) )
# This is dynamically created in InFlightGauge.__init__.
class _InFlightMetric(Protocol):
real_time_max: float
real_time_sum: float
T = TypeVar("T", bound=Callable[..., Any]) T = TypeVar("T", bound=Callable[..., Any])
@ -180,7 +187,7 @@ class Measure:
""" """
return self._logging_context.get_resource_usage() return self._logging_context.get_resource_usage()
def _update_in_flight(self, metrics) -> None: def _update_in_flight(self, metrics: _InFlightMetric) -> None:
"""Gets called when processing in flight metrics""" """Gets called when processing in flight metrics"""
assert self.start is not None assert self.start is not None
duration = self.clock.time() - self.start duration = self.clock.time() - self.start