Convert misc database code to async (#8087)

This commit is contained in:
Patrick Cloke 2020-08-14 07:24:26 -04:00 committed by GitHub
parent 7bdf9828d5
commit 894dae74fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 39 additions and 64 deletions

1
changelog.d/8087.misc Normal file
View file

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@@ -18,8 +18,6 @@ from typing import Optional
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines from . import engines
@@ -308,9 +306,8 @@ class BackgroundUpdater(object):
update_name (str): Name of update update_name (str): Name of update
""" """
@defer.inlineCallbacks async def noop_update(progress, batch_size):
def noop_update(progress, batch_size): await self._end_background_update(update_name)
yield self._end_background_update(update_name)
return 1 return 1
self.register_background_update_handler(update_name, noop_update) self.register_background_update_handler(update_name, noop_update)
@@ -409,12 +406,11 @@ class BackgroundUpdater(object):
else: else:
runner = create_index_sqlite runner = create_index_sqlite
@defer.inlineCallbacks async def updater(progress, batch_size):
def updater(progress, batch_size):
if runner is not None: if runner is not None:
logger.info("Adding index %s to %s", index_name, table) logger.info("Adding index %s to %s", index_name, table)
yield self.db_pool.runWithConnection(runner) await self.db_pool.runWithConnection(runner)
yield self._end_background_update(update_name) await self._end_background_update(update_name)
return 1 return 1
self.register_background_update_handler(update_name, updater) self.register_background_update_handler(update_name, updater)

View file

@@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedList( @cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote", cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", list_name="user_ids",
inlineCallbacks=True,
) )
def get_device_list_last_stream_id_for_remotes(self, user_ids: str): async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,

View file

@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_delay = 3 self._rotate_delay = 3
self._rotate_count = 10000 self._rotate_count = 10000
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) @cached(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user( async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id self, room_id, user_id, last_read_event_id
): ):
ret = yield self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room", "get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn, self._get_unread_counts_by_receipt_txn,
room_id, room_id,
user_id, user_id,
last_read_event_id, last_read_event_id,
) )
return ret
def _get_unread_counts_by_receipt_txn( def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id self, txn, room_id, user_id, last_read_event_id

View file

@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError() raise NotImplementedError()
@cachedList( @cachedList(
cached_method_name="_get_presence_for_user", cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
) )
def get_presence_for_users(self, user_ids): async def get_presence_for_users(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="presence_stream", table="presence_stream",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,

View file

@@ -170,18 +170,15 @@ class PushRulesWorkerStore(
) )
@cachedList( @cachedList(
cached_method_name="get_push_rules_for_user", cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
) )
def bulk_get_push_rules(self, user_ids): async def bulk_get_push_rules(self, user_ids):
if not user_ids: if not user_ids:
return {} return {}
results = {user_id: [] for user_id in user_ids} results = {user_id: [] for user_id in user_ids}
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="push_rules", table="push_rules",
column="user_name", column="user_name",
iterable=user_ids, iterable=user_ids,
@@ -194,7 +191,7 @@ class PushRulesWorkerStore(
for row in rows: for row in rows:
results.setdefault(row["user_name"], []).append(row) results.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items(): for user_id, rules in results.items():
use_new_defaults = user_id in self._users_new_default_push_rules use_new_defaults = user_id in self._users_new_default_push_rules
@@ -260,15 +257,14 @@ class PushRulesWorkerStore(
cached_method_name="get_push_rules_enabled_for_user", cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", list_name="user_ids",
num_args=1, num_args=1,
inlineCallbacks=True,
) )
def bulk_get_push_rules_enabled(self, user_ids): async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids: if not user_ids:
return {} return {}
results = {user_id: {} for user_id in user_ids} results = {user_id: {} for user_id in user_ids}
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable", table="push_rules_enable",
column="user_name", column="user_name",
iterable=user_ids, iterable=user_ids,

View file

@@ -170,13 +170,10 @@ class PusherWorkerStore(SQLBaseStore):
raise NotImplementedError() raise NotImplementedError()
@cachedList( @cachedList(
cached_method_name="get_if_user_has_pusher", cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
) )
def get_if_users_have_pushers(self, user_ids): async def get_if_users_have_pushers(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="pushers", table="pushers",
column="user_name", column="user_name",
iterable=user_ids, iterable=user_ids,

View file

@@ -212,9 +212,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
cached_method_name="_get_linearized_receipts_for_room", cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids", list_name="room_ids",
num_args=3, num_args=3,
inlineCallbacks=True,
) )
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids: if not room_ids:
return {} return {}
@@ -243,7 +242,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
txn_results = yield self.db_pool.runInteraction( txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f "_get_linearized_receipts_for_rooms", f
) )

View file

@@ -17,8 +17,6 @@
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count, lambda: self._known_servers_count,
) )
@defer.inlineCallbacks async def _count_known_servers(self):
def _count_known_servers(self):
""" """
Count the servers that this server knows about. Count the servers that this server knows about.
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query) txn.execute(query)
return list(txn)[0][0] return list(txn)[0][0]
count = yield self.db_pool.runInteraction("get_known_servers", _transact) count = await self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in # We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new). # room_memberships (for example, the server is new).
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError() raise NotImplementedError()
@cachedList( @cachedList(
cached_method_name="_get_joined_profile_from_event_id", cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
list_name="event_ids",
inlineCallbacks=True,
) )
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join """For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info. event and if so return the associated user and profile info.
@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup event_ids: The member event IDs to lookup
Returns: Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event). to `user_id` and ProfileInfo (or None if not join event).
""" """
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="room_memberships", table="room_memberships",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,

View file

@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cached_method_name="_get_state_group_for_event", cached_method_name="_get_state_group_for_event",
list_name="event_ids", list_name="event_ids",
num_args=1, num_args=1,
inlineCallbacks=True,
) )
def _get_state_group_for_events(self, event_ids): async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group """Returns mapping event_id -> state_group
""" """
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups", table="event_to_state_groups",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,

View file

@@ -38,10 +38,8 @@ class UserErasureWorkerStore(SQLBaseStore):
desc="is_user_erased", desc="is_user_erased",
).addCallback(operator.truth) ).addCallback(operator.truth)
@cachedList( @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True async def are_users_erased(self, user_ids):
)
def are_users_erased(self, user_ids):
""" """
Checks which users in a list have requested erasure Checks which users in a list have requested erasure
@@ -49,14 +47,14 @@ class UserErasureWorkerStore(SQLBaseStore):
user_ids (iterable[str]): full user id to check user_ids (iterable[str]): full user id to check
Returns: Returns:
Deferred[dict[str, bool]]: dict[str, bool]:
for each user, whether the user has requested erasure. for each user, whether the user has requested erasure.
""" """
# this serves the dual purpose of (a) making sure we can do len and # this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates. # iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids)) user_ids = tuple(set(user_ids))
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="erased_users", table="erased_users",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,
@@ -65,8 +63,7 @@ class UserErasureWorkerStore(SQLBaseStore):
) )
erased_users = {row["user_id"] for row in rows} erased_users = {row["user_id"] for row in rows}
res = {u: u in erased_users for u in user_ids} return {u: u in erased_users for u in user_ids}
return res
class UserErasureStore(UserErasureWorkerStore): class UserErasureStore(UserErasureWorkerStore):