Add type hints for tests/unittest.py. (#12347)

In particular, add type hints for get_success and friends, which are then helpful in a bunch of places.
This commit is contained in:
Richard van der Hoff 2022-04-01 17:04:16 +01:00 committed by GitHub
parent 33ebee47e4
commit f0b03186d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 97 additions and 48 deletions

1
changelog.d/12347.misc Normal file
View file

@ -0,0 +1 @@
Add type annotations for `tests/unittest.py`.

View file

@ -83,7 +83,6 @@ exclude = (?x)
|tests/test_server.py
|tests/test_state.py
|tests/test_terms_auth.py
|tests/unittest.py
|tests/util/caches/test_cached_call.py
|tests/util/caches/test_deferred_cache.py
|tests/util/caches/test_descriptors.py

View file

@ -463,8 +463,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 400)
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
query_res = self.get_success(
self.handler.query_local_devices({local_user: None})
)
self.assertDictEqual(query_res, {local_user: {}})
def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded"""

View file

@ -375,7 +375,8 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
member_event.signatures = member_event_dict["signatures"]
# Add the new member_event to the StateMap
prev_state_map[
updated_state_map = dict(prev_state_map)
updated_state_map[
(member_event.type, member_event.state_key)
] = member_event.event_id
auth_events.append(member_event)
@ -399,7 +400,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
prev_event_ids=message_event_dict["prev_events"],
auth_event_ids=self._event_auth_handler.compute_auth_events(
builder,
prev_state_map,
updated_state_map,
for_verification=False,
),
depth=message_event_dict["depth"],

View file

@ -354,10 +354,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
req = Mock(spec=["cookies"])
req.cookies = []
url = self.get_success(
self.provider.handle_redirect_request(req, b"http://client/redirect")
url = urlparse(
self.get_success(
self.provider.handle_redirect_request(req, b"http://client/redirect")
)
)
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
self.assertEqual(url.scheme, auth_endpoint.scheme)

View file

@ -351,6 +351,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.handler.handle_local_profile_change(regular_user_id, profile_info)
)
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name)
def test_handle_local_profile_change_with_deactivated_user(self) -> None:
@ -369,6 +370,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name)
# deactivate user

View file

@ -702,6 +702,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
"""
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])
# quarantining
@ -715,6 +716,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["quarantined_by"])
# remove from quarantine
@ -728,6 +730,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])
def test_quarantine_protected_media(self) -> None:
@ -740,6 +743,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
# verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"])
# quarantining
@ -754,6 +758,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
# verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"])
@ -830,6 +835,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
"""
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"])
# protect
@ -843,6 +849,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"])
# unprotect
@ -856,6 +863,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"])

View file

@ -1590,10 +1590,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
pushers = self.get_success(
self.store.get_pushers_by({"user_name": "@bob:test"})
pushers = list(
self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual("@bob:test", pushers[0].user_name)
@ -1632,10 +1631,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
pushers = self.get_success(
self.store.get_pushers_by({"user_name": "@bob:test"})
pushers = list(
self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
)
pushers = list(pushers)
self.assertEqual(len(pushers), 0)
def test_set_password(self) -> None:
@ -2144,6 +2142,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# is in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
assert profile is not None
self.assertTrue(profile["display_name"] == "User")
# Deactivate user
@ -2711,6 +2710,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
user_tuple = self.get_success(
self.store.get_user_by_access_token(other_user_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@ -3676,6 +3676,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# The user starts off as not shadow-banned.
other_user_token = self.login("user", "pass")
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertFalse(result.shadow_banned)
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
@ -3684,6 +3685,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# Ensure the user is shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertTrue(result.shadow_banned)
# Un-shadow-ban the user.
@ -3695,6 +3697,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# Ensure the user is no longer shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertFalse(result.shadow_banned)

View file

@ -22,7 +22,6 @@ import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
AnyStr,
Callable,
Dict,
Iterable,
@ -86,6 +85,9 @@ from tests.utils import (
logger = logging.getLogger(__name__)
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
class TimedOutException(Exception):
"""
@ -260,7 +262,7 @@ def make_request(
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""

View file

@ -28,7 +28,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""
# First to acquire this lock, so it should complete
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock)
assert lock is not None
# Enter the context manager
self.get_success(lock.__aenter__())
@ -45,7 +45,7 @@ class LockTestCase(unittest.HomeserverTestCase):
# We can now acquire the lock again.
lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock3)
assert lock3 is not None
self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None))
@ -53,7 +53,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""Test that we don't time out locks while they're still active"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock)
assert lock is not None
self.get_success(lock.__aenter__())
@ -69,7 +69,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""Test that we time out locks if they're not updated for ages"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock)
assert lock is not None
self.get_success(lock.__aenter__())

View file

@ -358,6 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(service.id, 12, other_events))
txn = self.get_success(self.store.get_oldest_unsent_txn(service))
assert txn is not None
self.assertEqual(service, txn.service)
self.assertEqual(10, txn.id)
self.assertEqual(events, txn.events)

View file

@ -22,10 +22,11 @@ import secrets
import time
from typing import (
Any,
AnyStr,
Awaitable,
Callable,
ClassVar,
Dict,
Generic,
Iterable,
List,
Optional,
@ -39,6 +40,7 @@ from unittest.mock import Mock, patch
import canonicaljson
import signedjson.key
import unpaddedbase64
from typing_extensions import Protocol
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
@ -49,7 +51,7 @@ from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
@ -70,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
from tests.server import (
CustomHeaderType,
FakeChannel,
get_clock,
make_request,
setup_test_homeserver,
)
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@ -78,6 +86,17 @@ from tests.utils import default_config, setupdb
setupdb()
setup_logging()
TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type."""
@property
def value(self) -> _ExcType:
...
def around(target):
"""A CLOS-style 'around' modifier, which wraps the original method of the
@ -276,6 +295,7 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
assert self.helper.auth_user_id is not None
# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
@ -288,6 +308,7 @@ class HomeserverTestCase(TestCase):
)
async def get_user_by_access_token(token=None, allow_guest=False):
assert self.helper.auth_user_id is not None
return {
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": token_id,
@ -295,6 +316,7 @@ class HomeserverTestCase(TestCase):
}
async def get_user_by_req(request, allow_guest=False, rights="access"):
assert self.helper.auth_user_id is not None
return create_requester(
UserID.from_string(self.helper.auth_user_id),
token_id,
@ -311,7 +333,7 @@ class HomeserverTestCase(TestCase):
)
if self.needs_threadpool:
self.reactor.threadpool = ThreadPool()
self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
self.addCleanup(self.reactor.threadpool.stop)
self.reactor.threadpool.start()
@ -426,7 +448,7 @@ class HomeserverTestCase(TestCase):
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
@ -511,30 +533,36 @@ class HomeserverTestCase(TestCase):
return hs
def pump(self, by=0.0):
def pump(self, by: float = 0.0) -> None:
"""
Pump the reactor enough that Deferreds will fire.
"""
self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0):
deferred: Deferred[TV] = ensureDeferred(d)
def get_success(
self,
d: Awaitable[TV],
by: float = 0.0,
) -> TV:
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
self.pump(by=by)
return self.successResultOf(deferred)
def get_failure(self, d, exc):
def get_failure(
self, d: Awaitable[Any], exc: Type[_ExcType]
) -> _TypedFailure[_ExcType]:
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
deferred: Deferred[Any] = ensureDeferred(d)
deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type]
self.pump()
return self.failureResultOf(deferred, exc)
def get_success_or_raise(self, d, by=0.0):
def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
"""Drive deferred to completion and return result or raise exception
on failure.
"""
deferred: Deferred[TV] = ensureDeferred(d)
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
results: list = []
deferred.addBoth(results.append)
@ -642,11 +670,11 @@ class HomeserverTestCase(TestCase):
def login(
self,
username,
password,
device_id=None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
username: str,
password: str,
device_id: Optional[str] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
) -> str:
"""
Log in a user, and get an access token. Requires the Login API be
registered.
@ -668,18 +696,22 @@ class HomeserverTestCase(TestCase):
return access_token
def create_and_send_event(
self, room_id, user, soft_failed=False, prev_event_ids=None
):
self,
room_id: str,
user: UserID,
soft_failed: bool = False,
prev_event_ids: Optional[List[str]] = None,
) -> str:
"""
Create and send an event.
Args:
soft_failed (bool): Whether to create a soft failed event or not
prev_event_ids (list[str]|None): Explicitly set the prev events,
soft_failed: Whether to create a soft failed event or not
prev_event_ids: Explicitly set the prev events,
or if None just use the default
Returns:
str: The new event's ID.
The new event's ID.
"""
event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user)
@ -706,7 +738,7 @@ class HomeserverTestCase(TestCase):
return event.event_id
def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
def inject_room_member(self, room: str, user: str, membership: str) -> None:
"""
Inject a membership event into a room.
@ -766,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
path: str,
content: Optional[JsonDict] = None,
await_result: bool = True,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""Make an inbound signed federation request to this server
@ -799,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
self.site,
method=method,
path=path,
content=content,
content=content or "",
shorthand=False,
await_result=await_result,
custom_headers=custom_headers,
@ -878,9 +910,6 @@ def override_config(extra_config):
return decorator
TV = TypeVar("TV")
def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
"""A test decorator which will skip the decorated test unless a condition is set