Type hints for tests.appservice (#14990)

* Accept a Sequence of events in synapse.appservice

This avoids some casts/ignores in the tests I'm about to fixup. It seems
that `List[Mock]` is not a subtype of `List[EventBase]`, but
`Sequence[Mock]` is a subtype of `Sequence[EventBase]`. So presumably
`Mock` is considered a subtype of anything, much like `Any`.

* make tests.appservice.test_scheduler pass mypy

* Extra hints in tests.appservice.test_scheduler

* Extra hints in tests.appservice.test_api

* Extra hints in tests.appservice.test_appservice

* Disallow untyped defs

* Changelog
This commit is contained in:
David Robertson 2023-02-06 12:49:06 +00:00 committed by GitHub
parent 3e37ff1a7e
commit e8269ed391
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 132 additions and 59 deletions

1
changelog.d/14990.misc Normal file
View file

@ -0,0 +1 @@
Improve type hints.

View file

@ -32,7 +32,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/schema/ |synapse/storage/schema/
|tests/appservice/test_scheduler.py
|tests/federation/test_federation_catch_up.py |tests/federation/test_federation_catch_up.py
|tests/federation/test_federation_sender.py |tests/federation/test_federation_sender.py
|tests/http/federation/test_matrix_federation_agent.py |tests/http/federation/test_matrix_federation_agent.py
@ -78,6 +77,9 @@ disallow_untyped_defs = True
[mypy-tests.app.*] [mypy-tests.app.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.appservice.*]
disallow_untyped_defs = True
[mypy-tests.config.*] [mypy-tests.config.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@ -16,7 +16,7 @@
import logging import logging
import re import re
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern, Sequence
import attr import attr
from netaddr import IPSet from netaddr import IPSet
@ -377,7 +377,7 @@ class AppServiceTransaction:
self, self,
service: ApplicationService, service: ApplicationService,
id: int, id: int,
events: List[EventBase], events: Sequence[EventBase],
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
to_device_messages: List[JsonDict], to_device_messages: List[JsonDict],
one_time_keys_count: TransactionOneTimeKeysCount, one_time_keys_count: TransactionOneTimeKeysCount,

View file

@ -14,7 +14,17 @@
# limitations under the License. # limitations under the License.
import logging import logging
import urllib.parse import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
)
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import TypeGuard from typing_extensions import TypeGuard
@ -259,7 +269,7 @@ class ApplicationServiceApi(SimpleHttpClient):
async def push_bulk( async def push_bulk(
self, self,
service: "ApplicationService", service: "ApplicationService",
events: List[EventBase], events: Sequence[EventBase],
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
to_device_messages: List[JsonDict], to_device_messages: List[JsonDict],
one_time_keys_count: TransactionOneTimeKeysCount, one_time_keys_count: TransactionOneTimeKeysCount,

View file

@ -57,6 +57,7 @@ from typing import (
Iterable, Iterable,
List, List,
Optional, Optional,
Sequence,
Set, Set,
Tuple, Tuple,
) )
@ -364,7 +365,7 @@ class _TransactionController:
async def send( async def send(
self, self,
service: ApplicationService, service: ApplicationService,
events: List[EventBase], events: Sequence[EventBase],
ephemeral: Optional[List[JsonDict]] = None, ephemeral: Optional[List[JsonDict]] = None,
to_device_messages: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None,
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None,

View file

@ -14,7 +14,17 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Pattern,
Sequence,
Tuple,
cast,
)
from synapse.appservice import ( from synapse.appservice import (
ApplicationService, ApplicationService,
@ -257,7 +267,7 @@ class ApplicationServiceTransactionWorkerStore(
async def create_appservice_txn( async def create_appservice_txn(
self, self,
service: ApplicationService, service: ApplicationService,
events: List[EventBase], events: Sequence[EventBase],
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
to_device_messages: List[JsonDict], to_device_messages: List[JsonDict],
one_time_keys_count: TransactionOneTimeKeysCount, one_time_keys_count: TransactionOneTimeKeysCount,

View file

@ -29,7 +29,7 @@ URL = "http://mytestservice"
class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.api = hs.get_application_service_api() self.api = hs.get_application_service_api()
self.service = ApplicationService( self.service = ApplicationService(
id="unique_identifier", id="unique_identifier",
@ -39,7 +39,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
hs_token=TOKEN, hs_token=TOKEN,
) )
def test_query_3pe_authenticates_token(self): def test_query_3pe_authenticates_token(self) -> None:
""" """
Tests that 3pe queries to the appservice are authenticated Tests that 3pe queries to the appservice are authenticated
with the appservice's token. with the appservice's token.

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re import re
from typing import Generator
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
@ -27,7 +28,7 @@ def _regex(regex: str, exclusive: bool = True) -> Namespace:
class ApplicationServiceTestCase(unittest.TestCase): class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.service = ApplicationService( self.service = ApplicationService(
id="unique_identifier", id="unique_identifier",
sender="@as:test", sender="@as:test",
@ -46,7 +47,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.store.get_local_users_in_room = simple_async_mock([]) self.store.get_local_users_in_room = simple_async_mock([])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_match(self): def test_regex_user_id_prefix_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue( self.assertTrue(
@ -60,7 +63,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self): def test_regex_user_id_prefix_no_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.assertFalse( self.assertFalse(
@ -74,7 +79,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_member_is_checked(self): def test_regex_room_member_is_checked(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member" self.event.type = "m.room.member"
@ -90,7 +97,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_id_match(self): def test_regex_room_id_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
@ -106,7 +115,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_id_no_match(self): def test_regex_room_id_no_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
@ -122,7 +133,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_alias_match(self): def test_regex_alias_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
@ -140,44 +153,46 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
) )
def test_non_exclusive_alias(self): def test_non_exclusive_alias(self) -> None:
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org", exclusive=False) _regex("#irc_.*:matrix.org", exclusive=False)
) )
self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org"))
def test_non_exclusive_room(self): def test_non_exclusive_room(self) -> None:
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!irc_.*:matrix.org", exclusive=False) _regex("!irc_.*:matrix.org", exclusive=False)
) )
self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org")) self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org"))
def test_non_exclusive_user(self): def test_non_exclusive_user(self) -> None:
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*:matrix.org", exclusive=False) _regex("@irc_.*:matrix.org", exclusive=False)
) )
self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org")) self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org"))
def test_exclusive_alias(self): def test_exclusive_alias(self) -> None:
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org", exclusive=True) _regex("#irc_.*:matrix.org", exclusive=True)
) )
self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org"))
def test_exclusive_user(self): def test_exclusive_user(self) -> None:
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*:matrix.org", exclusive=True) _regex("@irc_.*:matrix.org", exclusive=True)
) )
self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org")) self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org"))
def test_exclusive_room(self): def test_exclusive_room(self) -> None:
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!irc_.*:matrix.org", exclusive=True) _regex("!irc_.*:matrix.org", exclusive=True)
) )
self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org")) self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org"))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_alias_no_match(self): def test_regex_alias_no_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
@ -196,7 +211,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_multiple_matches(self): def test_regex_multiple_matches(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
@ -215,7 +232,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_interested_in_self(self): def test_interested_in_self(
self,
) -> Generator["defer.Deferred[object]", object, None]:
# make sure invites get through # make sure invites get through
self.service.sender = "@appservice:name" self.service.sender = "@appservice:name"
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
@ -233,7 +252,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_member_list_match(self): def test_member_list_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user. # Note that @irc_fo:here is the AS user.
self.store.get_local_users_in_room = simple_async_mock( self.store.get_local_users_in_room = simple_async_mock(

View file

@ -11,20 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock from unittest.mock import Mock
from typing_extensions import TypeAlias
from twisted.internet import defer from twisted.internet import defer
from synapse.appservice import ApplicationServiceState from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.appservice.scheduler import ( from synapse.appservice.scheduler import (
ApplicationServiceScheduler, ApplicationServiceScheduler,
_Recoverer, _Recoverer,
_TransactionController, _TransactionController,
) )
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import DeviceListUpdates from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -37,18 +45,18 @@ if TYPE_CHECKING:
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.clock = MockClock() self.clock = MockClock()
self.store = Mock() self.store = Mock()
self.as_api = Mock() self.as_api = Mock()
self.recoverer = Mock() self.recoverer = Mock()
self.recoverer_fn = Mock(return_value=self.recoverer) self.recoverer_fn = Mock(return_value=self.recoverer)
self.txnctrl = _TransactionController( self.txnctrl = _TransactionController(
clock=self.clock, store=self.store, as_api=self.as_api clock=cast(Clock, self.clock), store=self.store, as_api=self.as_api
) )
self.txnctrl.RECOVERER_CLASS = self.recoverer_fn self.txnctrl.RECOVERER_CLASS = self.recoverer_fn
def test_single_service_up_txn_sent(self): def test_single_service_up_txn_sent(self) -> None:
# Test: The AS is up and the txn is successfully sent. # Test: The AS is up and the txn is successfully sent.
service = Mock() service = Mock()
events = [Mock(), Mock()] events = [Mock(), Mock()]
@ -76,7 +84,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed txn.complete.assert_called_once_with(self.store) # txn completed
def test_single_service_down(self): def test_single_service_down(self) -> None:
# Test: The AS is down so it shouldn't push; Recoverers will do it. # Test: The AS is down so it shouldn't push; Recoverers will do it.
# It should still make a transaction though. # It should still make a transaction though.
service = Mock() service = Mock()
@ -103,7 +111,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.assertEqual(0, txn.send.call_count) # txn not sent though self.assertEqual(0, txn.send.call_count) # txn not sent though
self.assertEqual(0, txn.complete.call_count) # or completed self.assertEqual(0, txn.complete.call_count) # or completed
def test_single_service_up_txn_not_sent(self): def test_single_service_up_txn_not_sent(self) -> None:
# Test: The AS is up and the txn is not sent. A Recoverer is made and # Test: The AS is up and the txn is not sent. A Recoverer is made and
# started. # started.
service = Mock() service = Mock()
@ -139,26 +147,28 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.clock = MockClock() self.clock = MockClock()
self.as_api = Mock() self.as_api = Mock()
self.store = Mock() self.store = Mock()
self.service = Mock() self.service = Mock()
self.callback = simple_async_mock() self.callback = simple_async_mock()
self.recoverer = _Recoverer( self.recoverer = _Recoverer(
clock=self.clock, clock=cast(Clock, self.clock),
as_api=self.as_api, as_api=self.as_api,
store=self.store, store=self.store,
service=self.service, service=self.service,
callback=self.callback, callback=self.callback,
) )
def test_recover_single_txn(self): def test_recover_single_txn(self) -> None:
txn = Mock() txn = Mock()
# return one txn to send, then no more old txns # return one txn to send, then no more old txns
txns = [txn, None] txns = [txn, None]
def take_txn(*args, **kwargs): def take_txn(
*args: object, **kwargs: object
) -> "defer.Deferred[Optional[Mock]]":
return defer.succeed(txns.pop(0)) return defer.succeed(txns.pop(0))
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
@ -177,12 +187,14 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.callback.assert_called_once_with(self.recoverer) self.callback.assert_called_once_with(self.recoverer)
self.assertEqual(self.recoverer.service, self.service) self.assertEqual(self.recoverer.service, self.service)
def test_recover_retry_txn(self): def test_recover_retry_txn(self) -> None:
txn = Mock() txn = Mock()
txns = [txn, None] txns = [txn, None]
pop_txn = False pop_txn = False
def take_txn(*args, **kwargs): def take_txn(
*args: object, **kwargs: object
) -> "defer.Deferred[Optional[Mock]]":
if pop_txn: if pop_txn:
return defer.succeed(txns.pop(0)) return defer.succeed(txns.pop(0))
else: else:
@ -214,8 +226,24 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.callback.assert_called_once_with(self.recoverer) self.callback.assert_called_once_with(self.recoverer)
# Corresponds to synapse.appservice.scheduler._TransactionController.send
TxnCtrlArgs: TypeAlias = """
defer.Deferred[
Tuple[
ApplicationService,
Sequence[EventBase],
Optional[List[JsonDict]],
Optional[List[JsonDict]],
Optional[TransactionOneTimeKeysCount],
Optional[TransactionUnusedFallbackKeys],
Optional[DeviceListUpdates],
]
]
"""
class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer): def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None:
self.scheduler = ApplicationServiceScheduler(hs) self.scheduler = ApplicationServiceScheduler(hs)
self.txn_ctrl = Mock() self.txn_ctrl = Mock()
self.txn_ctrl.send = simple_async_mock() self.txn_ctrl.send = simple_async_mock()
@ -224,7 +252,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
self.scheduler.txn_ctrl = self.txn_ctrl self.scheduler.txn_ctrl = self.txn_ctrl
self.scheduler.queuer.txn_ctrl = self.txn_ctrl self.scheduler.queuer.txn_ctrl = self.txn_ctrl
def test_send_single_event_no_queue(self): def test_send_single_event_no_queue(self) -> None:
# Expect the event to be sent immediately. # Expect the event to be sent immediately.
service = Mock(id=4) service = Mock(id=4)
event = Mock() event = Mock()
@ -233,8 +261,8 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service, [event], [], [], None, None, DeviceListUpdates() service, [event], [], [], None, None, DeviceListUpdates()
) )
def test_send_single_event_with_queue(self): def test_send_single_event_with_queue(self) -> None:
d = defer.Deferred() d: TxnCtrlArgs = defer.Deferred()
self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
service = Mock(id=4) service = Mock(id=4)
event = Mock(event_id="first") event = Mock(event_id="first")
@ -257,22 +285,22 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(2, self.txn_ctrl.send.call_count) self.assertEqual(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self): def test_multiple_service_queues(self) -> None:
# Tests that each service has its own queue, and that they don't block # Tests that each service has its own queue, and that they don't block
# on each other. # on each other.
srv1 = Mock(id=4) srv1 = Mock(id=4)
srv_1_defer = defer.Deferred() srv_1_defer: "defer.Deferred[EventBase]" = defer.Deferred()
srv_1_event = Mock(event_id="srv1a") srv_1_event = Mock(event_id="srv1a")
srv_1_event2 = Mock(event_id="srv1b") srv_1_event2 = Mock(event_id="srv1b")
srv2 = Mock(id=6) srv2 = Mock(id=6)
srv_2_defer = defer.Deferred() srv_2_defer: "defer.Deferred[EventBase]" = defer.Deferred()
srv_2_event = Mock(event_id="srv2a") srv_2_event = Mock(event_id="srv2a")
srv_2_event2 = Mock(event_id="srv2b") srv_2_event2 = Mock(event_id="srv2b")
send_return_list = [srv_1_defer, srv_2_defer] send_return_list = [srv_1_defer, srv_2_defer]
def do_send(*args, **kwargs): def do_send(*args: object, **kwargs: object) -> "defer.Deferred[EventBase]":
return make_deferred_yieldable(send_return_list.pop(0)) return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send) self.txn_ctrl.send = Mock(side_effect=do_send)
@ -297,12 +325,12 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(3, self.txn_ctrl.send.call_count) self.assertEqual(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self): def test_send_large_txns(self) -> None:
srv_1_defer = defer.Deferred() srv_1_defer: "defer.Deferred[EventBase]" = defer.Deferred()
srv_2_defer = defer.Deferred() srv_2_defer: "defer.Deferred[EventBase]" = defer.Deferred()
send_return_list = [srv_1_defer, srv_2_defer] send_return_list = [srv_1_defer, srv_2_defer]
def do_send(*args, **kwargs): def do_send(*args: object, **kwargs: object) -> "defer.Deferred[EventBase]":
return make_deferred_yieldable(send_return_list.pop(0)) return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send) self.txn_ctrl.send = Mock(side_effect=do_send)
@ -328,7 +356,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(3, self.txn_ctrl.send.call_count) self.assertEqual(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self): def test_send_single_ephemeral_no_queue(self) -> None:
# Expect the event to be sent immediately. # Expect the event to be sent immediately.
service = Mock(id=4, name="service") service = Mock(id=4, name="service")
event_list = [Mock(name="event")] event_list = [Mock(name="event")]
@ -337,7 +365,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service, [], event_list, [], None, None, DeviceListUpdates() service, [], event_list, [], None, None, DeviceListUpdates()
) )
def test_send_multiple_ephemeral_no_queue(self): def test_send_multiple_ephemeral_no_queue(self) -> None:
# Expect the event to be sent immediately. # Expect the event to be sent immediately.
service = Mock(id=4, name="service") service = Mock(id=4, name="service")
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
@ -346,8 +374,8 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service, [], event_list, [], None, None, DeviceListUpdates() service, [], event_list, [], None, None, DeviceListUpdates()
) )
def test_send_single_ephemeral_with_queue(self): def test_send_single_ephemeral_with_queue(self) -> None:
d = defer.Deferred() d: TxnCtrlArgs = defer.Deferred()
self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
service = Mock(id=4) service = Mock(id=4)
event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")] event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
@ -377,8 +405,8 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(2, self.txn_ctrl.send.call_count) self.assertEqual(2, self.txn_ctrl.send.call_count)
def test_send_large_txns_ephemeral(self): def test_send_large_txns_ephemeral(self) -> None:
d = defer.Deferred() d: TxnCtrlArgs = defer.Deferred()
self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
# Expect the event to be sent immediately. # Expect the event to be sent immediately.
service = Mock(id=4, name="service") service = Mock(id=4, name="service")