Adds misc missing type hints (#11953)

Patrick Cloke 2022-02-11 07:20:16 -05:00 committed by GitHub
parent c3db7a0b59
commit a121507cfe
GPG key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
9 changed files with 48 additions and 41 deletions

changelog.d/11953.misc (new file)

@@ -0,0 +1 @@
+Add missing type hints.

@@ -142,6 +142,9 @@ disallow_untyped_defs = True
 [mypy-synapse.crypto.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.event_auth]
+disallow_untyped_defs = True
+
 [mypy-synapse.events.*]
 disallow_untyped_defs = True
@@ -166,6 +169,9 @@ disallow_untyped_defs = True
 [mypy-synapse.module_api.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.notifier]
+disallow_untyped_defs = True
+
 [mypy-synapse.push.*]
 disallow_untyped_defs = True

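The two configuration hunks above (presumably mypy.ini) opt `synapse.event_auth` and `synapse.notifier` into `disallow_untyped_defs`, under which mypy rejects any function in those modules that lacks parameter or return annotations. A minimal illustration of that behaviour, using a made-up function rather than Synapse code:

```python
# Made-up example: what `disallow_untyped_defs = True` does and does not accept.

def handle_event(event):  # mypy: error: Function is missing a type annotation
    return event is not None


def handle_event_annotated(event: object) -> bool:  # fully annotated: accepted
    return event is not None
```
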
@@ -763,7 +763,9 @@ def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
     return default


-def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]):
+def _verify_third_party_invite(
+    event: EventBase, auth_events: StateMap[EventBase]
+) -> bool:
     """
     Validates that the invite event is authorized by a previous third-party invite.

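The hunk above (apparently the event-auth module, judging by the function names) reflows `_verify_third_party_invite` across several lines and declares its `-> bool` return type. Declaring the return type is what lets mypy catch branches that fall off the end of the function; a simplified sketch with stand-in types:

```python
# Simplified stand-ins for EventBase / StateMap; illustrative only.
from typing import Mapping


class Event:
    def __init__(self, content: Mapping[str, object]) -> None:
        self.content = content


def verify_invite(event: Event, invites: Mapping[str, Event]) -> bool:
    """Return True only if the event references a known invite token."""
    token = event.content.get("token")
    if not isinstance(token, str):
        return False
    return token in invites
    # Because the function is declared `-> bool`, mypy reports
    # "Missing return statement" if a new branch forgets to return.
```
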
@@ -544,9 +544,9 @@ class OidcProvider:
         """
         metadata = await self.load_metadata()
         token_endpoint = metadata.get("token_endpoint")
-        raw_headers = {
+        raw_headers: Dict[str, str] = {
             "Content-Type": "application/x-www-form-urlencoded",
-            "User-Agent": self._http_client.user_agent,
+            "User-Agent": self._http_client.user_agent.decode("ascii"),
             "Accept": "application/json",
         }

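Above, in `OidcProvider`, `raw_headers` gains an explicit `Dict[str, str]` annotation, and the user agent, which the HTTP client now stores as `bytes` (next hunk), is decoded back to `str` before going into the header map. A rough sketch of the mismatch the annotation makes visible, with a fake client object standing in for Synapse's `SimpleHttpClient`:

```python
# Sketch with a fake client; the real attribute lives on Synapse's SimpleHttpClient.
from typing import Dict


class FakeClient:
    user_agent: bytes = b"Synapse/1.52"


client = FakeClient()

# With the annotation in place, mypy rejects putting the raw bytes value into
# a Dict[str, str], so the decode below is required rather than optional.
raw_headers: Dict[str, str] = {
    "Content-Type": "application/x-www-form-urlencoded",
    "User-Agent": client.user_agent.decode("ascii"),
    "Accept": "application/json",
}
print(raw_headers["User-Agent"])  # Synapse/1.52
```
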
@@ -322,21 +322,20 @@ class SimpleHttpClient:
         self._ip_whitelist = ip_whitelist
         self._ip_blacklist = ip_blacklist
         self._extra_treq_args = treq_args or {}
-        self.user_agent = hs.version_string
         self.clock = hs.get_clock()
+        user_agent = hs.version_string
         if hs.config.server.user_agent_suffix:
-            self.user_agent = "%s %s" % (
-                self.user_agent,
+            user_agent = "%s %s" % (
+                user_agent,
                 hs.config.server.user_agent_suffix,
             )
+        self.user_agent = user_agent.encode("ascii")

         # We use this for our body producers to ensure that they use the correct
         # reactor.
         self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))

-        self.user_agent = self.user_agent.encode("ascii")
-
         if self._ip_blacklist:
             # If we have an IP blacklist, we need to use a DNS resolver which
             # filters out blacklisted IP addresses, to prevent DNS rebinding.

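The `SimpleHttpClient` constructor above now builds the user agent as a local `str`, appends any configured suffix, and assigns `self.user_agent` exactly once as ASCII bytes, so the attribute has a single type for its whole lifetime instead of switching from `str` to `bytes` partway through `__init__`. The same build-then-encode-once pattern in isolation (function and argument names are illustrative, not Synapse's API):

```python
# Illustrative helper; not Synapse's API.
from typing import Optional


def build_user_agent(version_string: str, suffix: Optional[str] = None) -> bytes:
    # Do all the string work on str values...
    user_agent = version_string
    if suffix:
        user_agent = "%s %s" % (user_agent, suffix)
    # ...and encode exactly once at the boundary, so callers always see bytes.
    return user_agent.encode("ascii")


assert build_user_agent("Synapse/1.52") == b"Synapse/1.52"
assert build_user_agent("Synapse/1.52", "(example.org)") == b"Synapse/1.52 (example.org)"
```
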
@@ -334,12 +334,11 @@ class MatrixFederationHttpClient:
         user_agent = hs.version_string
         if hs.config.server.user_agent_suffix:
             user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix)
-        user_agent = user_agent.encode("ascii")

         federation_agent = MatrixFederationAgent(
             self.reactor,
             tls_client_options_factory,
-            user_agent,
+            user_agent.encode("ascii"),
             hs.config.server.federation_ip_range_whitelist,
             hs.config.server.federation_ip_range_blacklist,
         )

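`MatrixFederationHttpClient` gets the mirror-image treatment: instead of rebinding the local `user_agent` from `str` to `bytes`, the encode now happens at the one call site that needs bytes. Rebinding a name to a different type is something mypy flags whenever it can infer the variable's type; a tiny illustration (not Synapse code):

```python
# Tiny illustration of why converting at the point of use is friendlier to mypy.

def send(user_agent: bytes) -> int:
    return len(user_agent)


def build_and_send(version: str) -> int:
    user_agent = version                # inferred as str
    # user_agent = user_agent.encode("ascii")
    #   ^ mypy: Incompatible types in assignment (expression has type "bytes",
    #     variable has type "str")
    return send(user_agent.encode("ascii"))  # convert where bytes are needed


print(build_and_send("Synapse/1.52"))  # 12
```
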
@@ -14,6 +14,7 @@
 import logging
 from typing import (
+    TYPE_CHECKING,
     Awaitable,
     Callable,
     Collection,
@@ -32,7 +33,6 @@ from prometheus_client import Counter
 from twisted.internet import defer

-import synapse.server
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.api.errors import AuthError
 from synapse.events import EventBase
@@ -53,6 +53,9 @@ from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
 from synapse.util.metrics import Measure
 from synapse.visibility import filter_events_for_client

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)

 notified_events_counter = Counter("synapse_notifier_notified_events", "")
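
Previously the notifier did `import synapse.server` at module load purely so the annotation `hs: "synapse.server.HomeServer"` could resolve; the hunks above swap that for an import guarded by `TYPE_CHECKING` plus the shorter forward reference `"HomeServer"`, which keeps the type information without a runtime import cycle. The general pattern, sketched with hypothetical module names:

```python
# Hypothetical modules `myapp.server` / `myapp.notifier`; only the shape matters.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while a type checker is analysing the file; at runtime the
    # block is skipped, so no circular import occurs.
    from myapp.server import HomeServer


class Notifier:
    def __init__(self, hs: "HomeServer") -> None:
        # The quotes make this a forward reference, evaluated lazily, so the
        # class does not need to be importable when this module is loaded.
        self.hs = hs
```
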
@@ -82,7 +85,7 @@ class _NotificationListener:
     __slots__ = ["deferred"]

-    def __init__(self, deferred):
+    def __init__(self, deferred: "defer.Deferred"):
         self.deferred = deferred
@@ -124,7 +127,7 @@
         stream_key: str,
         stream_id: Union[int, RoomStreamToken],
         time_now_ms: int,
-    ):
+    ) -> None:
         """Notify any listeners for this user of a new event from an
         event source.
         Args:
@@ -152,7 +155,7 @@
         self.notify_deferred = ObservableDeferred(defer.Deferred())
         noify_deferred.callback(self.current_token)

-    def remove(self, notifier: "Notifier"):
+    def remove(self, notifier: "Notifier") -> None:
         """Remove this listener from all the indexes in the Notifier
         it knows about.
         """
@@ -188,7 +191,7 @@ class EventStreamResult:
     start_token: StreamToken
     end_token: StreamToken

-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.events)
@@ -212,7 +215,7 @@ class Notifier:
     UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000

-    def __init__(self, hs: "synapse.server.HomeServer"):
+    def __init__(self, hs: "HomeServer"):
         self.user_to_user_stream: Dict[str, _NotifierUserStream] = {}
         self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
@@ -248,7 +251,7 @@ class Notifier:
         # This is not a very cheap test to perform, but it's only executed
         # when rendering the metrics page, which is likely once per minute at
         # most when scraping it.
-        def count_listeners():
+        def count_listeners() -> int:
             all_user_streams: Set[_NotifierUserStream] = set()

             for streams in list(self.room_to_user_streams.values()):
@@ -270,7 +273,7 @@ class Notifier:
             "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream)
         )

-    def add_replication_callback(self, cb: Callable[[], None]):
+    def add_replication_callback(self, cb: Callable[[], None]) -> None:
         """Add a callback that will be called when some new data is available.
         Callback is not given any arguments. It should *not* return a Deferred - if
         it needs to do any asynchronous work, a background thread should be started and
@@ -284,7 +287,7 @@ class Notifier:
         event_pos: PersistedEventPosition,
         max_room_stream_token: RoomStreamToken,
         extra_users: Optional[Collection[UserID]] = None,
-    ):
+    ) -> None:
         """Unwraps event and calls `on_new_room_event_args`."""
         await self.on_new_room_event_args(
             event_pos=event_pos,
@@ -307,7 +310,7 @@ class Notifier:
         event_pos: PersistedEventPosition,
         max_room_stream_token: RoomStreamToken,
         extra_users: Optional[Collection[UserID]] = None,
-    ):
+    ) -> None:
         """Used by handlers to inform the notifier something has happened
         in the room, room event wise.
@@ -338,7 +341,9 @@ class Notifier:
         self.notify_replication()

-    def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
+    def _notify_pending_new_room_events(
+        self, max_room_stream_token: RoomStreamToken
+    ) -> None:
         """Notify for the room events that were queued waiting for a previous
         event to be persisted.
         Args:
@@ -374,7 +379,7 @@ class Notifier:
         )
         self._on_updated_room_token(max_room_stream_token)

-    def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
+    def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken) -> None:
         """Poke services that might care that the room position has been
         updated.
         """
@@ -386,13 +391,13 @@ class Notifier:
         if self.federation_sender:
             self.federation_sender.notify_new_events(max_room_stream_token)

-    def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
+    def _notify_app_services(self, max_room_stream_token: RoomStreamToken) -> None:
         try:
             self.appservice_handler.notify_interested_services(max_room_stream_token)
         except Exception:
             logger.exception("Error notifying application services of event")

-    def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
+    def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken) -> None:
         try:
             self._pusher_pool.on_new_notifications(max_room_stream_token)
         except Exception:
@@ -475,8 +480,8 @@ class Notifier:
         user_id: str,
         timeout: int,
         callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
-        room_ids=None,
-        from_token=StreamToken.START,
+        room_ids: Optional[Collection[str]] = None,
+        from_token: StreamToken = StreamToken.START,
     ) -> T:
         """Wait until the callback returns a non empty response or the
         timeout fires.
@@ -700,14 +705,14 @@ class Notifier:
         for expired_stream in expired_streams:
             expired_stream.remove(self)

-    def _register_with_keys(self, user_stream: _NotifierUserStream):
+    def _register_with_keys(self, user_stream: _NotifierUserStream) -> None:
         self.user_to_user_stream[user_stream.user_id] = user_stream

         for room in user_stream.rooms:
             s = self.room_to_user_streams.setdefault(room, set())
             s.add(user_stream)

-    def _user_joined_room(self, user_id: str, room_id: str):
+    def _user_joined_room(self, user_id: str, room_id: str) -> None:
         new_user_stream = self.user_to_user_stream.get(user_id)
         if new_user_stream is not None:
             room_streams = self.room_to_user_streams.setdefault(room_id, set())
@@ -719,7 +724,7 @@ class Notifier:
         for cb in self.replication_callbacks:
             cb()

-    def notify_remote_server_up(self, server: str):
+    def notify_remote_server_up(self, server: str) -> None:
         """Notify any replication that a remote server has come back up"""
         # We call federation_sender directly rather than registering as a
         # callback as a) we already have a reference to it and b) it introduces

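Most of the notifier hunks simply add return types (`-> None`, `-> bool`, `-> int`), but the `@@ -475` hunk also annotates two defaulted keyword arguments: `room_ids: Optional[Collection[str]] = None` and `from_token: StreamToken = StreamToken.START`. A small stand-alone sketch of why a `None` default wants `Optional[...]` (the function below is invented, not the Synapse signature):

```python
# Invented wait-style API; only the annotation pattern matters.
from typing import Collection, Optional


def wait_for_updates(
    user_id: str,
    timeout_ms: int,
    room_ids: Optional[Collection[str]] = None,  # None means "no room filter"
) -> int:
    if room_ids is None:
        return 0
    # After the None check, mypy narrows room_ids to Collection[str],
    # so len() and iteration are known to be safe.
    return len(room_ids)


assert wait_for_updates("@alice:example.org", 30_000) == 0
assert wait_for_updates("@alice:example.org", 30_000, ["!room:example.org"]) == 1
```
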
@@ -233,8 +233,8 @@ class HomeServer(metaclass=abc.ABCMeta):
         self,
         hostname: str,
         config: HomeServerConfig,
-        reactor=None,
-        version_string="Synapse",
+        reactor: Optional[ISynapseReactor] = None,
+        version_string: str = "Synapse",
     ):
         """
         Args:
@@ -244,7 +244,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         if not reactor:
             from twisted.internet import reactor as _reactor

-            reactor = _reactor
+            reactor = cast(ISynapseReactor, _reactor)

         self._reactor = reactor
         self.hostname = hostname
@@ -264,7 +264,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         self._module_web_resources: Dict[str, Resource] = {}
         self._module_web_resources_consumed = False

-    def register_module_web_resource(self, path: str, resource: Resource):
+    def register_module_web_resource(self, path: str, resource: Resource) -> None:
         """Allows a module to register a web resource to be served at the given path.

         If multiple modules register a resource for the same path, the module that

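The `HomeServer` constructor now declares `reactor: Optional[ISynapseReactor] = None` and, when falling back to Twisted's global reactor, wraps it in `cast(ISynapseReactor, _reactor)`, because the module-level `twisted.internet.reactor` is typed too loosely for mypy to prove it satisfies the narrower interface. A generic sketch of that fallback-plus-cast pattern (the `Reactor` protocol here is made up for the example):

```python
# Generic sketch of "optional dependency with a cast fallback".
# `Reactor` is a made-up protocol standing in for ISynapseReactor.
from typing import Optional, Protocol, cast


class Reactor(Protocol):
    def run(self) -> None: ...


class DummyReactor:
    def run(self) -> None:
        print("running")


def make_server(reactor: Optional[Reactor] = None) -> Reactor:
    if reactor is None:
        # In Synapse this is the global `twisted.internet.reactor`, whose
        # static type is too loose, hence the cast. Here we use a dummy.
        reactor = cast(Reactor, DummyReactor())
    return reactor


make_server().run()  # prints "running"
```
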
@@ -155,7 +155,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.http_client = Mock(spec=["get_json"])
         self.http_client.get_json.side_effect = get_json
-        self.http_client.user_agent = "Synapse Test"
+        self.http_client.user_agent = b"Synapse Test"

         hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
@@ -438,12 +438,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         state = "state"
         nonce = "nonce"
         client_redirect_url = "http://client/redirect"
-        user_agent = "Browser"
         ip_address = "10.0.0.1"

         session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
-        request = _build_callback_request(
-            code, state, session, user_agent=user_agent, ip_address=ip_address
-        )
+        request = _build_callback_request(code, state, session, ip_address=ip_address)

         self.get_success(self.handler.handle_oidc_callback(request))
@@ -1274,7 +1271,6 @@ def _build_callback_request(
     code: str,
     state: str,
     session: str,
-    user_agent: str = "Browser",
     ip_address: str = "10.0.0.1",
 ):
     """Builds a fake SynapseRequest to mock the browser callback
@@ -1289,7 +1285,6 @@ def _build_callback_request(
             query param. Should be the same as was embedded in the session in
             _build_oidc_session.
         session: the "session" which would have been passed around in the cookie.
-        user_agent: the user-agent to present
         ip_address: the IP address to pretend the request came from
     """
     request = Mock(
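
Finally, the test changes mirror the production change: the mocked HTTP client's `user_agent` becomes `bytes`, and `_build_callback_request` drops its `user_agent` parameter together with the corresponding docstring entry. A hedged sketch of keeping a test double in line with the new bytes-typed attribute (the names below are illustrative, not the Synapse test helpers):

```python
# Illustrative sketch: a test double whose `user_agent` matches the
# bytes-typed attribute now exposed by the real client.
from unittest.mock import Mock

http_client = Mock(spec=["get_json"])
http_client.user_agent = b"Synapse Test"

# Code under test can rely on the attribute being bytes and decode it when a
# str header value is required, as the OIDC provider hunk above does.
headers = {"User-Agent": http_client.user_agent.decode("ascii")}
assert headers["User-Agent"] == "Synapse Test"
```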