Add more missing type hints to tests. (#15028)

This commit is contained in:
Patrick Cloke 2023-02-08 16:29:49 -05:00 committed by GitHub
parent 4eed7b2ede
commit 30509a1010
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 124 additions and 111 deletions

1
changelog.d/15028.misc Normal file
View file

@ -0,0 +1 @@
Improve type hints.

View file

@ -60,24 +60,6 @@ disallow_untyped_defs = False
[mypy-synapse.storage.database] [mypy-synapse.storage.database]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-tests.scripts.test_new_matrix_user]
disallow_untyped_defs = False
[mypy-tests.server_notices.test_consent]
disallow_untyped_defs = False
[mypy-tests.server_notices.test_resource_limits_server_notices]
disallow_untyped_defs = False
[mypy-tests.test_federation]
disallow_untyped_defs = False
[mypy-tests.test_utils.*]
disallow_untyped_defs = False
[mypy-tests.test_visibility]
disallow_untyped_defs = False
[mypy-tests.unittest] [mypy-tests.unittest]
disallow_untyped_defs = False disallow_untyped_defs = False

View file

@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver()
self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
self.hs_patcher.start() self.hs_patcher.start() # type: ignore[attr-defined]
self.handler = hs.get_oidc_handler() self.handler = hs.get_oidc_handler()
self.provider = self.handler._providers["oidc"] self.provider = self.handler._providers["oidc"]
@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs return hs
def tearDown(self) -> None: def tearDown(self) -> None:
self.hs_patcher.stop() self.hs_patcher.stop() # type: ignore[attr-defined]
return super().tearDown() return super().tearDown()
def reset_mocks(self) -> None: def reset_mocks(self) -> None:

View file

@ -12,29 +12,33 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List from typing import List, Optional
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from synapse._scripts.register_new_matrix_user import request_registration from synapse._scripts.register_new_matrix_user import request_registration
from synapse.types import JsonDict
from tests.unittest import TestCase from tests.unittest import TestCase
class RegisterTestCase(TestCase): class RegisterTestCase(TestCase):
def test_success(self): def test_success(self) -> None:
""" """
The script will fetch a nonce, and then generate a MAC with it, and then The script will fetch a nonce, and then generate a MAC with it, and then
post that MAC. post that MAC.
""" """
def get(url, verify=None): def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock() r = Mock()
r.status_code = 200 r.status_code = 200
r.json = lambda: {"nonce": "a"} r.json = lambda: {"nonce": "a"}
return r return r
def post(url, json=None, verify=None): def post(
url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
) -> Mock:
# Make sure we are sent the correct info # Make sure we are sent the correct info
assert json is not None
self.assertEqual(json["username"], "user") self.assertEqual(json["username"], "user")
self.assertEqual(json["password"], "pass") self.assertEqual(json["password"], "pass")
self.assertEqual(json["nonce"], "a") self.assertEqual(json["nonce"], "a")
@ -70,12 +74,12 @@ class RegisterTestCase(TestCase):
# sys.exit shouldn't have been called. # sys.exit shouldn't have been called.
self.assertEqual(err_code, []) self.assertEqual(err_code, [])
def test_failure_nonce(self): def test_failure_nonce(self) -> None:
""" """
If the script fails to fetch a nonce, it throws an error and quits. If the script fails to fetch a nonce, it throws an error and quits.
""" """
def get(url, verify=None): def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock() r = Mock()
r.status_code = 404 r.status_code = 404
r.reason = "Not Found" r.reason = "Not Found"
@ -107,20 +111,23 @@ class RegisterTestCase(TestCase):
self.assertIn("ERROR! Received 404 Not Found", out) self.assertIn("ERROR! Received 404 Not Found", out)
self.assertNotIn("Success!", out) self.assertNotIn("Success!", out)
def test_failure_post(self): def test_failure_post(self) -> None:
""" """
The script will fetch a nonce, and then if the final POST fails, will The script will fetch a nonce, and then if the final POST fails, will
report an error and quit. report an error and quit.
""" """
def get(url, verify=None): def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock() r = Mock()
r.status_code = 200 r.status_code = 200
r.json = lambda: {"nonce": "a"} r.json = lambda: {"nonce": "a"}
return r return r
def post(url, json=None, verify=None): def post(
url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
) -> Mock:
# Make sure we are sent the correct info # Make sure we are sent the correct info
assert json is not None
self.assertEqual(json["username"], "user") self.assertEqual(json["username"], "user")
self.assertEqual(json["password"], "pass") self.assertEqual(json["password"], "pass")
self.assertEqual(json["nonce"], "a") self.assertEqual(json["nonce"], "a")

View file

@ -14,8 +14,12 @@
import os import os
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.rest.client import login, room, sync from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -29,7 +33,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
tmpdir = self.mktemp() tmpdir = self.mktemp()
os.mkdir(tmpdir) os.mkdir(tmpdir)
@ -53,15 +57,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
"room_name": "Server Notices", "room_name": "Server Notices",
} }
hs = self.setup_test_homeserver(config=config) return self.setup_test_homeserver(config=config)
return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
def prepare(self, reactor, clock, hs):
self.user_id = self.register_user("bob", "abc123") self.user_id = self.register_user("bob", "abc123")
self.access_token = self.login("bob", "abc123") self.access_token = self.login("bob", "abc123")
def test_get_sync_message(self): def test_get_sync_message(self) -> None:
""" """
When user consent server notices are enabled, a sync will cause a notice When user consent server notices are enabled, a sync will cause a notice
to fire (in a room which the user is invited to). The notice contains to fire (in a room which the user is invited to). The notice contains

View file

@ -24,6 +24,7 @@ from synapse.server import HomeServer
from synapse.server_notices.resource_limits_server_notices import ( from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices, ResourceLimitsServerNotices,
) )
from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -33,7 +34,7 @@ from tests.utils import default_config
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def default_config(self): def default_config(self) -> JsonDict:
config = default_config("test") config = default_config("test")
config.update( config.update(
@ -86,18 +87,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment]
@override_config({"hs_disabled": True}) @override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self): def test_maybe_send_server_notice_disabled_hs(self) -> None:
"""If the HS is disabled, we should not send notices""" """If the HS is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
@override_config({"limit_usage_by_mau": False}) @override_config({"limit_usage_by_mau": False})
def test_maybe_send_server_notice_to_user_flag_off(self): def test_maybe_send_server_notice_to_user_flag_off(self) -> None:
"""If mau limiting is disabled, we should not send notices""" """If mau limiting is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
"""Test when user has blocked notice, but should have it removed""" """Test when user has blocked notice, but should have it removed"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( self._rlsn._auth_blocking.check_auth_blocking = Mock(
@ -114,7 +115,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
self._send_notice.assert_called_once() self._send_notice.assert_called_once()
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
""" """
Test when user has blocked notice, but notice ought to be there (NOOP) Test when user has blocked notice, but notice ought to be there (NOOP)
""" """
@ -134,7 +135,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
def test_maybe_send_server_notice_to_user_add_blocked_notice(self): def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None:
""" """
Test when user does not have blocked notice, but should have one Test when user does not have blocked notice, but should have one
""" """
@ -147,7 +148,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
# Would be better to check contents, but 2 calls == set blocking event # Would be better to check contents, but 2 calls == set blocking event
self.assertEqual(self._send_notice.call_count, 2) self.assertEqual(self._send_notice.call_count, 2)
def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None:
""" """
Test when user does not have blocked notice, nor should they (NOOP) Test when user does not have blocked notice, nor should they (NOOP)
""" """
@ -159,7 +160,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None:
""" """
Test when user is not part of the MAU cohort - this should not ever Test when user is not part of the MAU cohort - this should not ever
happen - but ... happen - but ...
@ -175,7 +176,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
@override_config({"mau_limit_alerting": False}) @override_config({"mau_limit_alerting": False})
def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(
self,
) -> None:
""" """
Test that when server is over MAU limit and alerting is suppressed, then Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room an alert message is not sent into the room
@ -191,7 +194,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 0) self.assertEqual(self._send_notice.call_count, 0)
@override_config({"mau_limit_alerting": False}) @override_config({"mau_limit_alerting": False})
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None:
""" """
Test that when a server is disabled, that MAU limit alerting is ignored. Test that when a server is disabled, that MAU limit alerting is ignored.
""" """
@ -207,7 +210,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 2) self.assertEqual(self._send_notice.call_count, 2)
@override_config({"mau_limit_alerting": False}) @override_config({"mau_limit_alerting": False})
def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(
self,
) -> None:
""" """
When the room is already in a blocked state, test that when alerting When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state. is suppressed that the room is returned to an unblocked state.
@ -242,7 +247,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def default_config(self): def default_config(self) -> JsonDict:
c = super().default_config() c = super().default_config()
c["server_notices"] = { c["server_notices"] = {
"system_mxid_localpart": "server", "system_mxid_localpart": "server",
@ -270,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test" self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self): def test_server_notice_only_sent_once(self) -> None:
self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = Mock(
@ -306,7 +311,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.assertEqual(count, 1) self.assertEqual(count, 1)
def test_no_invite_without_notice(self): def test_no_invite_without_notice(self) -> None:
"""Tests that a user doesn't get invited to a server notices room without a """Tests that a user doesn't get invited to a server notices room without a
server notice being sent. server notice being sent.
@ -328,7 +333,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
m.assert_called_once_with(user_id) m.assert_called_once_with(user_id)
def test_invite_with_notice(self): def test_invite_with_notice(self) -> None:
"""Tests that, if the MAU limit is hit, the server notices user invites each user """Tests that, if the MAU limit is hit, the server notices user invites each user
to a room in which it has sent a notice. to a room in which it has sent a notice.
""" """

View file

@ -12,53 +12,48 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional, Union
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import FederationError from synapse.api.errors import FederationError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json from synapse.federation.federation_base import event_from_pdu_json
from synapse.http.types import QueryParams
from synapse.logging.context import LoggingContext from synapse.logging.context import LoggingContext
from synapse.types import UserID, create_requester from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from tests import unittest from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase): class MessageAcceptTests(unittest.HomeserverTestCase):
def setUp(self): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock() self.http_client = Mock()
self.reactor = ThreadedMemoryReactorClock() return self.setup_test_homeserver(federation_http_client=self.http_client)
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
self.addCleanup,
federation_http_client=self.http_client,
clock=self.hs_clock,
reactor=self.reactor,
)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
user_id = UserID("us", "test") user_id = UserID("us", "test")
our_user = create_requester(user_id) our_user = create_requester(user_id)
room_creator = self.homeserver.get_room_creation_handler() room_creator = self.hs.get_room_creation_handler()
self.room_id = self.get_success( self.room_id = self.get_success(
room_creator.create_room( room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False our_user, room_creator._presets_dict["public_chat"], ratelimit=False
) )
)[0]["room_id"] )[0]["room_id"]
self.store = self.homeserver.get_datastores().main self.store = self.hs.get_datastores().main
# Figure out what the most recent event is # Figure out what the most recent event is
most_recent = self.get_success( most_recent = self.get_success(
self.homeserver.get_datastores().main.get_latest_event_ids_in_room( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
self.room_id
)
)[0] )[0]
join_event = make_event_from_dict( join_event = make_event_from_dict(
@ -78,14 +73,16 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
} }
) )
self.handler = self.homeserver.get_federation_handler() self.handler = self.hs.get_federation_handler()
federation_event_handler = self.homeserver.get_federation_event_handler() federation_event_handler = self.hs.get_federation_event_handler()
async def _check_event_auth(origin, event, context): async def _check_event_auth(
origin: Optional[str], event: EventBase, context: EventContext
) -> None:
pass pass
federation_event_handler._check_event_auth = _check_event_auth federation_event_handler._check_event_auth = _check_event_auth
self.client = self.homeserver.get_federation_client() self.client = self.hs.get_federation_client()
self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
lambda dest, pdus, **k: succeed(pdus) lambda dest, pdus, **k: succeed(pdus)
) )
@ -104,16 +101,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
"$join:test.serv", "$join:test.serv",
) )
def test_cant_hide_direct_ancestors(self): def test_cant_hide_direct_ancestors(self) -> None:
""" """
If you send a message, you must be able to provide the direct If you send a message, you must be able to provide the direct
prev_events that said event references. prev_events that said event references.
""" """
async def post_json(destination, path, data, headers=None, timeout=0): async def post_json(
destination: str,
path: str,
data: Optional[JsonDict] = None,
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
# If it asks us for new missing events, give them NOTHING # If it asks us for new missing events, give them NOTHING
if path.startswith("/_matrix/federation/v1/get_missing_events/"): if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {"events": []} return {"events": []}
return {}
self.http_client.post_json = post_json self.http_client.post_json = post_json
@ -138,7 +144,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
} }
) )
federation_event_handler = self.homeserver.get_federation_event_handler() federation_event_handler = self.hs.get_federation_event_handler()
with LoggingContext("test-context"): with LoggingContext("test-context"):
failure = self.get_failure( failure = self.get_failure(
federation_event_handler.on_receive_pdu("test.serv", lying_event), federation_event_handler.on_receive_pdu("test.serv", lying_event),
@ -158,7 +164,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
self.assertEqual(extrem[0], "$join:test.serv") self.assertEqual(extrem[0], "$join:test.serv")
def test_retry_device_list_resync(self): def test_retry_device_list_resync(self) -> None:
"""Tests that device lists are marked as stale if they couldn't be synced, and """Tests that device lists are marked as stale if they couldn't be synced, and
that stale device lists are retried periodically. that stale device lists are retried periodically.
""" """
@ -171,24 +177,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# When this function is called, increment the number of resync attempts (only if # When this function is called, increment the number of resync attempts (only if
# we're querying devices for the right user ID), then raise a # we're querying devices for the right user ID), then raise a
# NotRetryingDestination error to fail the resync gracefully. # NotRetryingDestination error to fail the resync gracefully.
def query_user_devices(destination, user_id): def query_user_devices(
destination: str, user_id: str, timeout: int = 30000
) -> JsonDict:
if user_id == remote_user_id: if user_id == remote_user_id:
self.resync_attempts += 1 self.resync_attempts += 1
raise NotRetryingDestination(0, 0, destination) raise NotRetryingDestination(0, 0, destination)
# Register the mock on the federation client. # Register the mock on the federation client.
federation_client = self.homeserver.get_federation_client() federation_client = self.hs.get_federation_client()
federation_client.query_user_devices = Mock(side_effect=query_user_devices) federation_client.query_user_devices = Mock(side_effect=query_user_devices)
# Register a mock on the store so that the incoming update doesn't fail because # Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user. # we don't share a room with the user.
store = self.homeserver.get_datastores().main store = self.hs.get_datastores().main
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at # Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried. # least one prev_id so that the user's device list will need to be retried.
device_list_updater = self.homeserver.get_device_handler().device_list_updater device_list_updater = self.hs.get_device_handler().device_list_updater
self.get_success( self.get_success(
device_list_updater.incoming_device_list_update( device_list_updater.incoming_device_list_update(
origin=remote_origin, origin=remote_origin,
@ -218,7 +226,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.reactor.advance(30) self.reactor.advance(30)
self.assertEqual(self.resync_attempts, 2) self.assertEqual(self.resync_attempts, 2)
def test_cross_signing_keys_retry(self): def test_cross_signing_keys_retry(self) -> None:
"""Tests that resyncing a device list correctly processes cross-signing keys from """Tests that resyncing a device list correctly processes cross-signing keys from
the remote server. the remote server.
""" """
@ -227,7 +235,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
# Register mock device list retrieval on the federation client. # Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client() federation_client = self.hs.get_federation_client()
federation_client.query_user_devices = Mock( federation_client.query_user_devices = Mock(
return_value=make_awaitable( return_value=make_awaitable(
{ {
@ -252,7 +260,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
) )
# Resync the device list. # Resync the device list.
device_handler = self.homeserver.get_device_handler() device_handler = self.hs.get_device_handler()
self.get_success( self.get_success(
device_handler.device_list_updater.user_device_resync(remote_user_id), device_handler.device_list_updater.user_device_resync(remote_user_id),
) )
@ -279,7 +287,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
class StripUnsignedFromEventsTestCase(unittest.TestCase): class StripUnsignedFromEventsTestCase(unittest.TestCase):
def test_strip_unauthorized_unsigned_values(self): def test_strip_unauthorized_unsigned_values(self) -> None:
event1 = { event1 = {
"sender": "@baduser:test.serv", "sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv", "state_key": "@baduser:test.serv",
@ -296,7 +304,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
# Make sure unauthorized fields are stripped from unsigned # Make sure unauthorized fields are stripped from unsigned
self.assertNotIn("more warez", filtered_event.unsigned) self.assertNotIn("more warez", filtered_event.unsigned)
def test_strip_event_maintains_allowed_fields(self): def test_strip_event_maintains_allowed_fields(self) -> None:
event2 = { event2 = {
"sender": "@baduser:test.serv", "sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv", "state_key": "@baduser:test.serv",
@ -323,7 +331,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
self.assertIn("invite_room_state", filtered_event2.unsigned) self.assertIn("invite_room_state", filtered_event2.unsigned)
self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
def test_strip_event_removes_fields_based_on_event_type(self): def test_strip_event_removes_fields_based_on_event_type(self) -> None:
event3 = { event3 = {
"sender": "@baduser:test.serv", "sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv", "state_key": "@baduser:test.serv",

View file

@ -20,12 +20,13 @@ import sys
import warnings import warnings
from asyncio import Future from asyncio import Future
from binascii import unhexlify from binascii import unhexlify
from typing import Awaitable, Callable, Tuple, TypeVar from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock from unittest.mock import Mock
import attr import attr
import zope.interface import zope.interface
from twisted.internet.interfaces import IProtocol
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.client import ResponseDone from twisted.web.client import ResponseDone
from twisted.web.http import RESPONSES from twisted.web.http import RESPONSES
@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
from synapse.types import JsonDict from synapse.types import JsonDict
if TYPE_CHECKING:
from sys import UnraisableHookArgs
TV = TypeVar("TV") TV = TypeVar("TV")
@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
unraisable_exceptions = [] unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook orig_unraisablehook = sys.unraisablehook
def unraisablehook(unraisable): def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
unraisable_exceptions.append(unraisable.exc_value) unraisable_exceptions.append(unraisable.exc_value)
def cleanup(): def cleanup() -> None:
""" """
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
""" """
sys.unraisablehook = orig_unraisablehook sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions: if unraisable_exceptions:
raise unraisable_exceptions.pop() exc = unraisable_exceptions.pop()
assert exc is not None
raise exc
sys.unraisablehook = unraisablehook sys.unraisablehook = unraisablehook
return cleanup return cleanup
def simple_async_mock(return_value=None, raises=None) -> Mock: def simple_async_mock(
return_value: Optional[TV] = None, raises: Optional[Exception] = None
) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour # AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs): async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
if raises: if raises:
raise raises raise raises
return return_value return return_value
@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc]
headers: Headers = attr.Factory(Headers) headers: Headers = attr.Factory(Headers)
@property @property
def phrase(self): def phrase(self) -> bytes:
return RESPONSES.get(self.code, b"Unknown Status") return RESPONSES.get(self.code, b"Unknown Status")
@property @property
def length(self): def length(self) -> int:
return len(self.body) return len(self.body)
def deliverBody(self, protocol): def deliverBody(self, protocol: IProtocol) -> None:
protocol.dataReceived(self.body) protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone())) protocol.connectionLost(Failure(ResponseDone()))

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional, Tuple from typing import Any, List, Optional, Tuple
import synapse.server import synapse.server
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -32,7 +32,7 @@ async def inject_member_event(
membership: str, membership: str,
target: Optional[str] = None, target: Optional[str] = None,
extra_content: Optional[dict] = None, extra_content: Optional[dict] = None,
**kwargs, **kwargs: Any,
) -> EventBase: ) -> EventBase:
"""Inject a membership event into a room.""" """Inject a membership event into a room."""
if target is None: if target is None:
@ -57,7 +57,7 @@ async def inject_event(
hs: synapse.server.HomeServer, hs: synapse.server.HomeServer,
room_version: Optional[str] = None, room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
**kwargs, **kwargs: Any,
) -> EventBase: ) -> EventBase:
"""Inject a generic event into a room """Inject a generic event into a room
@ -82,7 +82,7 @@ async def create_event(
hs: synapse.server.HomeServer, hs: synapse.server.HomeServer,
room_version: Optional[str] = None, room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
**kwargs, **kwargs: Any,
) -> Tuple[EventBase, EventContext]: ) -> Tuple[EventBase, EventContext]:
if room_version is None: if room_version is None:
room_version = await hs.get_datastores().main.get_room_version_id( room_version = await hs.get_datastores().main.get_room_version_id(

View file

@ -13,13 +13,13 @@
# limitations under the License. # limitations under the License.
from html.parser import HTMLParser from html.parser import HTMLParser
from typing import Dict, Iterable, List, Optional, Tuple from typing import Dict, Iterable, List, NoReturn, Optional, Tuple
class TestHtmlParser(HTMLParser): class TestHtmlParser(HTMLParser):
"""A generic HTML page parser which extracts useful things from the HTML""" """A generic HTML page parser which extracts useful things from the HTML"""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
# a list of links found in the doc # a list of links found in the doc
@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser):
assert input_name assert input_name
self.hiddens[input_name] = attr_dict["value"] self.hiddens[input_name] = attr_dict["value"]
def error(_, message): def error(self, message: str) -> NoReturn:
raise AssertionError(message) raise AssertionError(message)

View file

@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler):
tx_log = twisted.logger.Logger() tx_log = twisted.logger.Logger()
def emit(self, record): def emit(self, record: logging.LogRecord) -> None:
log_entry = self.format(record) log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn") log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit( self.tx_log.emit(
@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler):
) )
def setup_logging(): def setup_logging() -> None:
"""Configure the python logging appropriately for the tests. """Configure the python logging appropriately for the tests.
(Logs will end up in _trial_temp.) (Logs will end up in _trial_temp.)

View file

@ -14,7 +14,7 @@
import json import json
from typing import Any, Dict, List, Optional, Tuple from typing import Any, ContextManager, Dict, List, Optional, Tuple
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from urllib.parse import parse_qs from urllib.parse import parse_qs
@ -77,14 +77,14 @@ class FakeOidcServer:
self._id_token_overrides: Dict[str, Any] = {} self._id_token_overrides: Dict[str, Any] = {}
def reset_mocks(self): def reset_mocks(self) -> None:
self.request.reset_mock() self.request.reset_mock()
self.get_jwks_handler.reset_mock() self.get_jwks_handler.reset_mock()
self.get_metadata_handler.reset_mock() self.get_metadata_handler.reset_mock()
self.get_userinfo_handler.reset_mock() self.get_userinfo_handler.reset_mock()
self.post_token_handler.reset_mock() self.post_token_handler.reset_mock()
def patch_homeserver(self, hs: HomeServer): def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
"""Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
This patch should be used whenever the HS is expected to perform request to the This patch should be used whenever the HS is expected to perform request to the
@ -188,7 +188,7 @@ class FakeOidcServer:
return self._sign(logout_token) return self._sign(logout_token)
def id_token_override(self, overrides: dict): def id_token_override(self, overrides: dict) -> ContextManager[dict]:
"""Temporarily patch the ID token generated by the token endpoint.""" """Temporarily patch the ID token generated by the token endpoint."""
return patch.object(self, "_id_token_overrides", overrides) return patch.object(self, "_id_token_overrides", overrides)
@ -247,7 +247,7 @@ class FakeOidcServer:
metadata: bool = False, metadata: bool = False,
token: bool = False, token: bool = False,
userinfo: bool = False, userinfo: bool = False,
): ) -> ContextManager[Dict[str, Mock]]:
"""A context which makes a set of endpoints return a 500 error. """A context which makes a set of endpoints return a 500 error.
Args: Args:

View file

@ -258,7 +258,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
def test_out_of_band_invite_rejection(self): def test_out_of_band_invite_rejection(self) -> None:
# this is where we have received an invite event over federation, and then # this is where we have received an invite event over federation, and then
# rejected it. # rejected it.
invite_pdu = { invite_pdu = {

View file

@ -315,7 +315,7 @@ class HomeserverTestCase(TestCase):
# This has to be a function and not just a Mock, because # This has to be a function and not just a Mock, because
# `self.helper.auth_user_id` is temporarily reassigned in some tests # `self.helper.auth_user_id` is temporarily reassigned in some tests
async def get_requester(*args, **kwargs) -> Requester: async def get_requester(*args: Any, **kwargs: Any) -> Requester:
assert self.helper.auth_user_id is not None assert self.helper.auth_user_id is not None
return create_requester( return create_requester(
user_id=UserID.from_string(self.helper.auth_user_id), user_id=UserID.from_string(self.helper.auth_user_id),