Allow for make_awaitable's return value to be re-used. (#8261)

This commit is contained in:
Patrick Cloke 2020-09-08 07:26:55 -04:00 committed by GitHub
parent 68cdb3708e
commit cef00211c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 56 additions and 70 deletions

1
changelog.d/8261.misc Normal file
View file

@ -0,0 +1 @@
Simplify tests that mock asynchronous functions.

View file

@ -77,11 +77,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock( fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
d = handler._remote_join( d = handler._remote_join(
@ -110,11 +108,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock( fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
d = handler._remote_join( d = handler._remote_join(
@ -150,11 +146,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock( fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
side_effect=lambda *args, **kwargs: make_awaitable(None)
)
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
# Artificially raise the complexity # Artificially raise the complexity
@ -208,11 +202,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock( fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
d = handler._remote_join( d = handler._remote_join(
@ -240,11 +232,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client() fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock( fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
)
handler.federation_handler.do_invite_join = Mock( handler.federation_handler.do_invite_join = Mock(
side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) return_value=make_awaitable(("", 1))
) )
d = handler._remote_join( d = handler._remote_join(

View file

@ -34,7 +34,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# Ensure a new Awaitable is created for each call. # Ensure a new Awaitable is created for each call.
mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable( mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"] ["test", "host2"]
) )
return self.setup_test_homeserver( return self.setup_test_homeserver(

View file

@ -143,7 +143,7 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_exceeded_large(self): def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.large_number_of_users) return_value=make_awaitable(self.large_number_of_users)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.large_number_of_users) return_value=make_awaitable(self.large_number_of_users)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -169,7 +169,7 @@ class AuthTestCase(unittest.TestCase):
# If not in monthly active cohort # If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) return_value=make_awaitable(self.auth_blocking._max_mau_value)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -179,7 +179,7 @@ class AuthTestCase(unittest.TestCase):
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) return_value=make_awaitable(self.auth_blocking._max_mau_value)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -189,10 +189,10 @@ class AuthTestCase(unittest.TestCase):
) )
# If in monthly active cohort # If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock( self.hs.get_datastore().user_last_seen_monthly_active = Mock(
side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) return_value=make_awaitable(self.hs.get_clock().time_msec())
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) return_value=make_awaitable(self.auth_blocking._max_mau_value)
) )
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
@ -200,10 +200,10 @@ class AuthTestCase(unittest.TestCase):
) )
) )
self.hs.get_datastore().user_last_seen_monthly_active = Mock( self.hs.get_datastore().user_last_seen_monthly_active = Mock(
side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) return_value=make_awaitable(self.hs.get_clock().time_msec())
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) return_value=make_awaitable(self.auth_blocking._max_mau_value)
) )
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self.auth_handler.validate_short_term_login_token_and_get_user_id(
@ -216,7 +216,7 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.small_number_of_users) return_value=make_awaitable(self.small_number_of_users)
) )
# Ensure does not raise exception # Ensure does not raise exception
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -226,7 +226,7 @@ class AuthTestCase(unittest.TestCase):
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.small_number_of_users) return_value=make_awaitable(self.small_number_of_users)
) )
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self.auth_handler.validate_short_term_login_token_and_get_user_id(

View file

@ -100,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_not_blocked(self): def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock( self.store.count_monthly_users = Mock(
side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1) return_value=make_awaitable(self.hs.config.max_mau_value - 1)
) )
# Ensure does not throw exception # Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User")) self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@ -108,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_blocked(self): def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.lots_of_users) return_value=make_awaitable(self.lots_of_users)
) )
self.get_failure( self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"), self.get_or_create_user(self.requester, "b", "display_name"),
@ -116,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
) )
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) return_value=make_awaitable(self.hs.config.max_mau_value)
) )
self.get_failure( self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"), self.get_or_create_user(self.requester, "b", "display_name"),
@ -126,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_mau_blocked(self): def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.lots_of_users) return_value=make_awaitable(self.lots_of_users)
) )
self.get_failure( self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError self.handler.register_user(localpart="local_part"), ResourceLimitError
) )
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) return_value=make_awaitable(self.hs.config.max_mau_value)
) )
self.get_failure( self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError self.handler.register_user(localpart="local_part"), ResourceLimitError

View file

@ -116,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res retry_timings_res
) )
self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable( self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
(0, []) (0, [])
) )

View file

@ -45,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new event. new event.
""" """
mock_client = Mock(spec=["put_json"]) mock_client = Mock(spec=["put_json"])
mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({}) mock_client.put_json.return_value = make_awaitable({})
self.make_worker_hs( self.make_worker_hs(
"synapse.app.federation_sender", "synapse.app.federation_sender",
@ -73,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new events. new events.
""" """
mock_client1 = Mock(spec=["put_json"]) mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs( self.make_worker_hs(
"synapse.app.federation_sender", "synapse.app.federation_sender",
{ {
@ -85,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
) )
mock_client2 = Mock(spec=["put_json"]) mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs( self.make_worker_hs(
"synapse.app.federation_sender", "synapse.app.federation_sender",
{ {
@ -136,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new typing EDUs. new typing EDUs.
""" """
mock_client1 = Mock(spec=["put_json"]) mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs( self.make_worker_hs(
"synapse.app.federation_sender", "synapse.app.federation_sender",
{ {
@ -148,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
) )
mock_client2 = Mock(spec=["put_json"]) mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs( self.make_worker_hs(
"synapse.app.federation_sender", "synapse.app.federation_sender",
{ {

View file

@ -337,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit # Set monthly active users to the limit
store.get_monthly_active_count = Mock( store.get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) return_value=make_awaitable(self.hs.config.max_mau_value)
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit
@ -591,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit # Set monthly active users to the limit
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) return_value=make_awaitable(self.hs.config.max_mau_value)
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit
@ -631,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit # Set monthly active users to the limit
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) return_value=make_awaitable(self.hs.config.max_mau_value)
) )
# Check that the blocking of monthly active users is working as expected # Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit # The registration of a new user fails due to the limit

View file

@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
raise Exception("Failed to find reference to ResourceLimitsServerNotices") raise Exception("Failed to find reference to ResourceLimitsServerNotices")
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
side_effect=lambda user_id: make_awaitable(1000) return_value=make_awaitable(1000)
) )
self._rlsn._server_notices_manager.send_notice = Mock( self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock()) return_value=defer.succeed(Mock())
@ -80,9 +80,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=defer.succeed("!something:localhost") return_value=defer.succeed("!something:localhost")
) )
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock( self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
side_effect=lambda user_id, room_id: make_awaitable({})
)
@override_config({"hs_disabled": True}) @override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self): def test_maybe_send_server_notice_disabled_hs(self):
@ -158,7 +156,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
side_effect=lambda user_id: make_awaitable(None) return_value=make_awaitable(None)
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -261,12 +259,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test" self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self): def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
side_effect=lambda: make_awaitable(1000)
)
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = Mock(
side_effect=lambda user_id: make_awaitable(1000) return_value=make_awaitable(1000)
) )
# Call the function multiple times to ensure we only send the notice once # Call the function multiple times to ensure we only send the notice once

View file

@ -154,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
user_id = "@user:server" user_id = "@user:server"
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
side_effect=lambda: make_awaitable(lots_of_users) return_value=make_awaitable(lots_of_users)
) )
self.get_success( self.get_success(
self.store.insert_client_ip( self.store.insert_client_ip(

View file

@ -231,9 +231,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
) )
self.get_success(d) self.get_success(d)
self.store.upsert_monthly_active_user = Mock( self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
side_effect=lambda user_id: make_awaitable(None)
)
d = self.store.populate_monthly_active_users(user_id) d = self.store.populate_monthly_active_users(user_id)
self.get_success(d) self.get_success(d)
@ -241,9 +239,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called() self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self): def test_populate_monthly_users_should_update(self):
self.store.upsert_monthly_active_user = Mock( self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
side_effect=lambda user_id: make_awaitable(None)
)
self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.is_trial_user = Mock(return_value=defer.succeed(False))
@ -256,9 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once() self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self): def test_populate_monthly_users_should_not_update(self):
self.store.upsert_monthly_active_user = Mock( self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
side_effect=lambda user_id: make_awaitable(None)
)
self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = Mock(
@ -344,9 +338,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self): def test_no_users_when_not_tracking(self):
self.store.upsert_monthly_active_user = Mock( self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
side_effect=lambda user_id: make_awaitable(None)
)
self.get_success(self.store.populate_monthly_active_users("@user:sever")) self.get_success(self.store.populate_monthly_active_users("@user:sever"))

View file

@ -17,6 +17,7 @@
""" """
Utilities for running the unit tests Utilities for running the unit tests
""" """
from asyncio import Future
from typing import Any, Awaitable, TypeVar from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV") TV = TypeVar("TV")
@ -38,6 +39,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed") raise Exception("awaitable has not yet completed")
async def make_awaitable(result: Any): def make_awaitable(result: Any) -> Awaitable[Any]:
"""Create an awaitable that just returns a result.""" """
return result Makes an awaitable, suitable for mocking an `async` function.
This uses Futures as they can be awaited multiple times so can be returned
to multiple callers.
"""
future = Future() # type: ignore
future.set_result(result)
return future