Reduce the number of "untyped defs" (#12716)

David Robertson 2022-05-12 15:33:50 +01:00 committed by GitHub
parent de1e599b9d
commit 17e1eb7749
16 changed files with 142 additions and 69 deletions

changelog.d/12716.misc (new file)

@@ -0,0 +1 @@
+Add type annotations to increase the number of modules passing `disallow-untyped-defs`.

mypy.ini

@@ -119,9 +119,18 @@ disallow_untyped_defs = True
 [mypy-synapse.federation.transport.client]
 disallow_untyped_defs = False
 
+[mypy-synapse.groups.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.handlers.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.http.federation.*]
+disallow_untyped_defs = True
+
+[mypy-synapse.http.request_metrics]
+disallow_untyped_defs = True
+
 [mypy-synapse.http.server]
 disallow_untyped_defs = True

@@ -196,12 +205,27 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.state_deltas]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.stream]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.transactions]
 disallow_untyped_defs = True
 
 [mypy-synapse.storage.databases.main.user_erasure_store]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.prepare_database]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.persist_events]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.state]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.types]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.util.*]
 disallow_untyped_defs = True
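
For readers unfamiliar with the flag: in each module listed above, mypy now rejects any function whose signature is not fully annotated. A minimal sketch of the behaviour (hypothetical module, not part of this commit):

    # mypy --disallow-untyped-defs demo.py
    def before(user_id):  # error: Function is missing a type annotation  [no-untyped-def]
        return user_id.lower()

    def after(user_id: str) -> str:  # fully annotated, so accepted
        return user_id.lower()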

synapse/groups/groups_server.py

@@ -934,7 +934,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         # Before deleting the group lets kick everyone out of it
         users = await self.store.get_users_in_group(group_id, include_private=True)
 
-        async def _kick_user_from_group(user_id):
+        async def _kick_user_from_group(user_id: str) -> None:
             if self.hs.is_mine_id(user_id):
                 groups_local = self.hs.get_groups_local_handler()
                 assert isinstance(

synapse/http/client.py

@@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.interfaces import (
     IAddress,
+    IDelayedCall,
     IHostResolution,
     IReactorPluggableNameResolver,
+    IReactorTime,
     IResolutionReceiver,
     ITCPTransport,
 )

@@ -121,13 +123,15 @@ def check_against_blacklist(
 _EPSILON = 0.00000001
 
 
-def _make_scheduler(reactor):
+def _make_scheduler(
+    reactor: IReactorTime,
+) -> Callable[[Callable[[], object]], IDelayedCall]:
     """Makes a schedular suitable for a Cooperator using the given reactor.
 
     (This is effectively just a copy from `twisted.internet.task`)
     """
 
-    def _scheduler(x):
+    def _scheduler(x: Callable[[], object]) -> IDelayedCall:
         return reactor.callLater(_EPSILON, x)
 
     return _scheduler
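
Extracting `_make_scheduler` lets other modules (notably `synapse/http/matrixfederationclient.py`, further down) build a `Cooperator` without re-implementing the closure. A usage sketch under the same assumptions as the diff:

    from twisted.internet import reactor
    from twisted.internet.task import Cooperator

    # Each unit of cooperative work gets rescheduled via reactor.callLater(_EPSILON, ...)
    cooperator = Cooperator(scheduler=_make_scheduler(reactor))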
@@ -775,7 +779,7 @@ class SimpleHttpClient:
         )
 
 
-def _timeout_to_request_timed_out_error(f: Failure):
+def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
     if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
         # The TCP connection has its own timeout (set by the 'connectTimeout' param
         # on the Agent), which raises twisted_error.TimeoutError exception.

@@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
     def __init__(self, deferred: defer.Deferred):
         self.deferred = deferred
 
-    def _maybe_fail(self):
+    def _maybe_fail(self) -> None:
         """
         Report a max size exceed error and disconnect the first time this is called.
         """

@@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
     Do not use this since it allows an attacker to intercept your communications.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._context = SSL.Context(SSL.SSLv23_METHOD)
         self._context.set_verify(VERIFY_NONE, lambda *_: False)
 
     def getContext(self, hostname=None, port=None):
         return self._context
 
-    def creatorForNetloc(self, hostname, port):
+    def creatorForNetloc(self, hostname: bytes, port: int):
         return self

synapse/http/federation/matrix_federation_agent.py

@@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory:
         self._srv_resolver = srv_resolver
 
-    def endpointForURI(self, parsed_uri: URI):
+    def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint":
         return MatrixHostnameEndpoint(
             self._reactor,
             self._proxy_reactor,

synapse/http/federation/srv_resolver.py

@@ -16,7 +16,7 @@
 import logging
 import random
 import time
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List
 
 import attr

@@ -109,7 +109,7 @@ class SrvResolver:
     def __init__(
         self,
-        dns_client=client,
+        dns_client: Any = client,
         cache: Dict[bytes, List[Server]] = SERVER_CACHE,
         get_time: Callable[[], float] = time.time,
     ):

synapse/http/federation/well_known_resolver.py

@@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
 _had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class WellKnownLookupResult:
-    delegated_server = attr.ib()
+    delegated_server: Optional[bytes]
 
 
 class WellKnownResolver:

@@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
 class _FetchWellKnownFailure(Exception):
     # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
     # a temporary failure.
-    temporary = attr.ib()
+    temporary: bool = attr.ib()
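
The move to `auto_attribs=True` is what allows the `delegated_server: Optional[bytes]` annotation to replace the untyped `attr.ib()` assignment: mypy then sees a real attribute type instead of `Any`. The two styles are equivalent at runtime (illustrative sketch, not from this commit):

    import attr
    from typing import Optional

    @attr.s(slots=True, frozen=True)
    class Untyped:
        delegated_server = attr.ib()  # attribute exists, but mypy sees Any

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class Typed:
        delegated_server: Optional[bytes]  # the annotation itself defines the attribute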

synapse/http/matrixfederationclient.py

@@ -23,6 +23,8 @@ from http import HTTPStatus
 from io import BytesIO, StringIO
 from typing import (
     TYPE_CHECKING,
+    Any,
+    BinaryIO,
     Callable,
     Dict,
     Generic,

@@ -44,7 +46,7 @@ from typing_extensions import Literal
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IReactorTime
-from twisted.internet.task import _EPSILON, Cooperator
+from twisted.internet.task import Cooperator
 from twisted.web.client import ResponseFailed
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer, IResponse

@@ -58,11 +60,13 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http import QuieterFileBodyProducer
 from synapse.http.client import (
     BlacklistingAgentWrapper,
     BodyExceededMaxSize,
     ByteWriteable,
+    _make_scheduler,
     encode_query_args,
     read_body_with_max_size,
 )

@@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
     CONTENT_TYPE = "application/json"
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._buffer = StringIO()
         self._binary_wrapper = BinaryIOWrapper(self._buffer)

@@ -299,7 +303,9 @@ async def _handle_response(
 class BinaryIOWrapper:
     """A wrapper for a TextIO which converts from bytes on the fly."""
 
-    def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
+    def __init__(
+        self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
+    ):
         self.decoder = codecs.getincrementaldecoder(encoding)(errors)
         self.file = file

@@ -317,7 +323,11 @@ class MatrixFederationHttpClient:
     requests.
     """
 
-    def __init__(self, hs: "HomeServer", tls_client_options_factory):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+    ):
         self.hs = hs
         self.signing_key = hs.signing_key
         self.server_name = hs.hostname

@@ -348,10 +358,7 @@ class MatrixFederationHttpClient:
         self.version_string_bytes = hs.version_string.encode("ascii")
         self.default_timeout = 60
 
-        def schedule(x):
-            self.reactor.callLater(_EPSILON, x)
-
-        self._cooperator = Cooperator(scheduler=schedule)
+        self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))
 
         self._sleeper = AwakenableSleeper(self.reactor)

@@ -364,7 +371,7 @@
         self,
         request: MatrixFederationRequest,
         try_trailing_slash_on_400: bool = False,
-        **send_request_args,
+        **send_request_args: Any,
     ) -> IResponse:
         """Wrapper for _send_request which can optionally retry the request
         upon receiving a combination of a 400 HTTP response code and a

@@ -1159,7 +1166,7 @@
         self,
         destination: str,
         path: str,
-        output_stream,
+        output_stream: BinaryIO,
         args: Optional[QueryParams] = None,
         retry_on_dns_fail: bool = True,
         max_size: Optional[int] = None,

@@ -1250,10 +1257,10 @@
         return length, headers
 
 
-def _flatten_response_never_received(e):
+def _flatten_response_never_received(e: BaseException) -> str:
     if hasattr(e, "reasons"):
         reasons = ", ".join(
-            _flatten_response_never_received(f.value) for f in e.reasons
+            _flatten_response_never_received(f.value) for f in e.reasons  # type: ignore[attr-defined]
         )
 
         return "%s:[%s]" % (type(e).__name__, reasons)

synapse/http/request_metrics.py

@@ -162,7 +162,7 @@ class RequestMetrics:
         with _in_flight_requests_lock:
             _in_flight_requests.add(self)
 
-    def stop(self, time_sec, response_code, sent_bytes):
+    def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None:
         with _in_flight_requests_lock:
             _in_flight_requests.discard(self)

@@ -186,13 +186,13 @@
             )
             return
 
-        response_code = str(response_code)
+        response_code_str = str(response_code)
 
-        outgoing_responses_counter.labels(self.method, response_code).inc()
+        outgoing_responses_counter.labels(self.method, response_code_str).inc()
 
         response_count.labels(self.method, self.name, tag).inc()
 
-        response_timer.labels(self.method, self.name, tag, response_code).observe(
+        response_timer.labels(self.method, self.name, tag, response_code_str).observe(
             time_sec - self.start_ts
         )
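
The `response_code` to `response_code_str` rename exists because mypy pins a variable to the type of its first binding, so re-assigning the `int` parameter to a `str` fails to type-check. A minimal reproduction (illustrative, not from this commit):

    def stop(response_code: int) -> None:
        # error: Incompatible types in assignment (expression has type "str",
        # variable has type "int")  [assignment]
        response_code = str(response_code)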
@@ -221,7 +221,7 @@
         # flight.
         self.update_metrics()
 
-    def update_metrics(self):
+    def update_metrics(self) -> None:
         """Updates the in flight metrics with values from this request."""
         if not self.start_context:
             logger.error(

synapse/storage/database.py

@@ -31,6 +31,7 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Type,
     TypeVar,
     cast,
     overload,

@@ -41,6 +42,7 @@ from prometheus_client import Histogram
 from typing_extensions import Concatenate, Literal, ParamSpec
 
 from twisted.enterprise import adbapi
+from twisted.internet.interfaces import IReactorCore
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig

@@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 def make_pool(
-    reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    reactor: IReactorCore,
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
 ) -> adbapi.ConnectionPool:
     """Get the connection pool for the database."""

@@ -101,7 +105,7 @@ def make_pool(
     db_args = dict(db_config.config.get("args", {}))
     db_args.setdefault("cp_reconnect", True)
 
-    def _on_new_connection(conn):
+    def _on_new_connection(conn: Connection) -> None:
         # Ensure we have a logging context so we can correctly track queries,
         # etc.
         with LoggingContext("db.on_new_connection"):

@@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
     default_txn_name: str
 
     def cursor(
-        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+        self,
+        *,
+        txn_name: Optional[str] = None,
+        after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+        exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
     ) -> "LoggingTransaction":
         if not txn_name:
             txn_name = self.default_txn_name

@@ -183,11 +191,16 @@
         self.conn.__enter__()
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> Optional[bool]:
         return self.conn.__exit__(exc_type, exc_value, traceback)
 
     # Proxy through any unknown lookups to the DB conn class.
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         return getattr(self.conn, name)

@@ -391,17 +404,22 @@ class LoggingTransaction:
     def __enter__(self) -> "LoggingTransaction":
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
         self.close()
 
 
 class PerformanceCounters:
-    def __init__(self):
-        self.current_counters = {}
-        self.previous_counters = {}
+    def __init__(self) -> None:
+        self.current_counters: Dict[str, Tuple[int, float]] = {}
+        self.previous_counters: Dict[str, Tuple[int, float]] = {}
 
     def update(self, key: str, duration_secs: float) -> None:
-        count, cum_time = self.current_counters.get(key, (0, 0))
+        count, cum_time = self.current_counters.get(key, (0, 0.0))
         count += 1
         cum_time += duration_secs
         self.current_counters[key] = (count, cum_time)

@@ -527,7 +545,7 @@ class DatabasePool:
     def start_profiling(self) -> None:
         self._previous_loop_ts = monotonic_time()
 
-        def loop():
+        def loop() -> None:
             curr = self._current_txn_total_time
             prev = self._previous_txn_total_time
             self._previous_txn_total_time = curr

@@ -1186,7 +1204,7 @@
         if lock:
             self.engine.lock_table(txn, table)
 
-        def _getwhere(key):
+        def _getwhere(key: str) -> str:
             # If the value we're passing in is None (aka NULL), we need to use
             # IS, not =, as NULL = NULL equals NULL (False).
             if keyvalues[key] is None:

@@ -2258,7 +2276,7 @@
         term: Optional[str],
         col: str,
         retcols: Collection[str],
-        desc="simple_search_list",
+        desc: str = "simple_search_list",
     ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.

synapse/storage/databases/main/metrics.py

@@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
+from synapse.storage.types import Cursor
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer

@@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         self._last_user_visit_update = self._get_start_of_day()
 
     @wrap_as_background_process("read_forward_extremities")
-    async def _read_forward_extremities(self):
+    async def _read_forward_extremities(self) -> None:
         def fetch(txn):
             txn.execute(
                 """

@@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
-    async def count_daily_e2ee_messages(self):
+    async def count_daily_e2ee_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.

@@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
 
-    async def count_daily_sent_e2ee_messages(self):
+    async def count_daily_sent_e2ee_messages(self) -> int:
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.

@@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_sent_e2ee_messages", _count_messages
         )
 
-    async def count_daily_active_e2ee_rooms(self):
+    async def count_daily_active_e2ee_rooms(self) -> int:
         def _count(txn):
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events

@@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_active_e2ee_rooms", _count
         )
 
-    async def count_daily_messages(self):
+    async def count_daily_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.

@@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         return await self.db_pool.runInteraction("count_messages", _count_messages)
 
-    async def count_daily_sent_messages(self):
+    async def count_daily_sent_messages(self) -> int:
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.

@@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_sent_messages", _count_messages
         )
 
-    async def count_daily_active_rooms(self):
+    async def count_daily_active_rooms(self) -> int:
         def _count(txn):
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events

@@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
-    def _count_users(self, txn, time_from):
+    def _count_users(self, txn: Cursor, time_from: int) -> int:
         """
         Returns number of users seen in the past time_from period
         """

@@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             ) u
         """
         txn.execute(sql, (time_from,))
-        (count,) = txn.fetchone()
+        # Mypy knows that fetchone() might return None if there are no rows.
+        # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
+        # returns exactly one row.
+        (count,) = txn.fetchone()  # type: ignore[misc]
         return count
 
     async def count_r30_users(self) -> Dict[str, int]:

@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_r30v2_users", _count_r30v2_users
         )
 
-    def _get_start_of_day(self):
+    def _get_start_of_day(self) -> int:
         """
         Returns millisecond unixtime for start of UTC day.
         """

synapse/storage/databases/main/stream.py

@@ -798,9 +798,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         self,
         txn: LoggingTransaction,
         event_id: str,
-        allow_none=False,
-    ) -> int:
-        return self.db_pool.simple_select_one_onecol_txn(
+        allow_none: bool = False,
+    ) -> Optional[int]:
+        # Type ignore: we pass keyvalues a Dict[str, str]; the function wants
+        # Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
+        return self.db_pool.simple_select_one_onecol_txn(  # type: ignore[call-overload]
             txn=txn,
             table="events",
             keyvalues={"event_id": event_id},

synapse/storage/persist_events.py

@@ -25,6 +25,7 @@ from typing import (
     Collection,
     Deque,
     Dict,
+    Generator,
     Generic,
     Iterable,
     List,

@@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
         return res
 
-    def _handle_queue(self, room_id):
+    def _handle_queue(self, room_id: str) -> None:
         """Attempts to handle the queue for a room if not already being handled.
 
         The queue's callback will be invoked with for each item in the queue,

@@ -227,7 +228,7 @@
         self._currently_persisting_rooms.add(room_id)
 
-        async def handle_queue_loop():
+        async def handle_queue_loop() -> None:
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:

@@ -250,15 +251,17 @@
                         with PreserveLoggingContext():
                             item.deferred.callback(ret)
             finally:
-                queue = self._event_persist_queues.pop(room_id, None)
-                if queue:
-                    self._event_persist_queues[room_id] = queue
+                remaining_queue = self._event_persist_queues.pop(room_id, None)
+                if remaining_queue:
+                    self._event_persist_queues[room_id] = remaining_queue
 
                 self._currently_persisting_rooms.discard(room_id)
 
         # set handle_queue_loop off in the background
         run_as_background_process("persist_events", handle_queue_loop)
 
-    def _get_drainining_queue(self, room_id):
+    def _get_drainining_queue(
+        self, room_id: str
+    ) -> Generator[_EventPersistQueueItem, None, None]:
         queue = self._event_persist_queues.setdefault(room_id, deque())
 
         try:

@@ -317,7 +320,9 @@ class EventsPersistenceStorage:
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
-        async def enqueue(item):
+        async def enqueue(
+            item: Tuple[str, List[Tuple[EventBase, EventContext]]]
+        ) -> Dict[str, str]:
             room_id, evs_ctxs = item
             return await self._event_persist_queue.add_to_queue(
                 room_id, evs_ctxs, backfilled=backfilled

@@ -1102,7 +1107,7 @@
         return False
 
-    async def _handle_potentially_left_users(self, user_ids: Set[str]):
+    async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
         """Given a set of remote users check if the server still shares a room with
         them. If not then mark those users' device cache as stale.
         """

synapse/storage/prepare_database.py

@@ -85,7 +85,7 @@ def prepare_database(
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
     databases: Collection[str] = ("main", "state"),
-):
+) -> None:
     """Prepares a physical database for usage. Will either create all necessary tables
     or upgrade from an older schema version.

synapse/storage/state.py

@@ -62,7 +62,7 @@ class StateFilter:
     types: "frozendict[str, Optional[FrozenSet[str]]]"
     include_others: bool = False
 
-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:

@@ -138,7 +138,9 @@
     )
 
     @staticmethod
-    def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+    def freeze(
+        types: Mapping[str, Optional[Collection[str]]], include_others: bool
+    ) -> "StateFilter":
        """
        Returns a (frozen) StateFilter with the same contents as the parameters
        specified here, which can be made of mutable types.

synapse/storage/types.py

@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
+from types import TracebackType
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 from typing_extensions import Protocol

@@ -86,5 +87,10 @@ class Connection(Protocol):
     def __enter__(self) -> "Connection":
         ...
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> Optional[bool]:
         ...
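
For reference, a class satisfying the context-manager half of the newly typed `Connection` protocol spells `__exit__` the same way; a minimal sketch (hypothetical class, not from this commit):

    from types import TracebackType
    from typing import Optional, Type

    class FakeConnection:
        def __enter__(self) -> "FakeConnection":
            return self

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_value: Optional[BaseException],
            traceback: Optional[TracebackType],
        ) -> Optional[bool]:
            return None  # falsy: let any exception propagate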