Convert groups and visibility code to async / await. (#7951)

Patrick Cloke, 2020-07-27 12:32:08 -04:00 (committed by GitHub)
parent 8144bc26a7
commit 5f65e62681
4 changed files with 31 additions and 37 deletions

changelog.d/7951.misc (new file, 1 addition)

@@ -0,0 +1 @@
+Convert groups and visibility code to async / await.
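
The change is mechanical throughout: each @defer.inlineCallbacks generator becomes an async def coroutine function, and each "yield deferred" becomes "await deferred". This works because Twisted's Deferred implements __await__, so native coroutines can await Deferred-returning APIs directly. A minimal before/after sketch of the pattern (the store.fetch_thing helper is invented for illustration, not part of this commit):

    from twisted.internet import defer

    # Before: a generator-based coroutine; yield suspends on a Deferred.
    @defer.inlineCallbacks
    def get_thing_old(store):
        thing = yield store.fetch_thing()
        return thing

    # After: a native coroutine; await suspends on the same Deferred.
    async def get_thing_new(store):
        thing = await store.fetch_thing()
        return thing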

synapse/groups/attestations.py

@@ -41,8 +41,6 @@ from typing import Tuple

 from signedjson.sign import sign_json

-from twisted.internet import defer
-
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import get_domain_from_id

@@ -72,8 +70,9 @@ class GroupAttestationSigning(object):
         self.server_name = hs.hostname
         self.signing_key = hs.signing_key

-    @defer.inlineCallbacks
-    def verify_attestation(self, attestation, group_id, user_id, server_name=None):
+    async def verify_attestation(
+        self, attestation, group_id, user_id, server_name=None
+    ):
         """Verifies that the given attestation matches the given parameters.

         An optional server_name can be supplied to explicitly set which server's

@@ -102,7 +101,7 @@ class GroupAttestationSigning(object):
         if valid_until_ms < now:
             raise SynapseError(400, "Attestation expired")

-        yield self.keyring.verify_json_for_server(
+        await self.keyring.verify_json_for_server(
             server_name, attestation, now, "Group attestation"
         )

@@ -142,8 +141,7 @@ class GroupAttestionRenewer(object):
             self._start_renew_attestations, 30 * 60 * 1000
         )

-    @defer.inlineCallbacks
-    def on_renew_attestation(self, group_id, user_id, content):
+    async def on_renew_attestation(self, group_id, user_id, content):
         """When a remote updates an attestation
         """

         attestation = content["attestation"]

@@ -151,11 +149,11 @@ class GroupAttestionRenewer(object):
         if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
             raise SynapseError(400, "Neither user not group are on this server")

-        yield self.attestations.verify_attestation(
+        await self.attestations.verify_attestation(
             attestation, user_id=user_id, group_id=group_id
         )

-        yield self.store.update_remote_attestion(group_id, user_id, attestation)
+        await self.store.update_remote_attestion(group_id, user_id, attestation)

         return {}

@@ -172,8 +170,7 @@ class GroupAttestionRenewer(object):
             now + UPDATE_ATTESTATION_TIME_MS
         )

-        @defer.inlineCallbacks
-        def _renew_attestation(group_user: Tuple[str, str]):
+        async def _renew_attestation(group_user: Tuple[str, str]):
             group_id, user_id = group_user
             try:
                 if not self.is_mine_id(group_id):

@@ -186,16 +183,16 @@ class GroupAttestionRenewer(object):
                         user_id,
                         group_id,
                     )
-                    yield self.store.remove_attestation_renewal(group_id, user_id)
+                    await self.store.remove_attestation_renewal(group_id, user_id)
                     return

                 attestation = self.attestations.create_attestation(group_id, user_id)

-                yield self.transport_client.renew_group_attestation(
+                await self.transport_client.renew_group_attestation(
                     destination, group_id, user_id, content={"attestation": attestation}
                 )

-                yield self.store.update_attestation_renewal(
+                await self.store.update_attestation_renewal(
                     group_id, user_id, attestation
                 )
             except (RequestSendFailed, HttpResponseException) as e:
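
One consequence worth noting: a converted coroutine such as on_renew_attestation can still await storage methods that return plain Deferreds, and non-async code can still drive the coroutine by wrapping it with defer.ensureDeferred. A runnable sketch of both directions, using an invented FakeStore rather than Synapse's real storage layer:

    from twisted.internet import defer

    class FakeStore:
        def update_remote_attestion(self, group_id, user_id, attestation):
            # Stand-in storage method: like Synapse's storage layer here,
            # it returns a Deferred rather than a coroutine.
            return defer.succeed(None)

    async def on_renew_attestation(store, group_id, user_id, attestation):
        # A coroutine can await a Deferred directly.
        await store.update_remote_attestion(group_id, user_id, attestation)
        return {}

    # Non-async code drives the coroutine by wrapping it in a Deferred.
    d = defer.ensureDeferred(
        on_renew_attestation(FakeStore(), "+group:example.org", "@user:example.org", {})
    )
    d.addCallback(print)  # fires synchronously here and prints {}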

synapse/visibility.py

@@ -16,8 +16,6 @@
 import logging
 import operator

-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.utils import prune_event
 from synapse.storage import Storage

@@ -39,8 +37,7 @@ MEMBERSHIP_PRIORITY = (
 )


-@defer.inlineCallbacks
-def filter_events_for_client(
+async def filter_events_for_client(
     storage: Storage,
     user_id,
     events,

@@ -67,19 +64,19 @@ def filter_events_for_client(
         also be called to check whether a user can see the state at a given point.

     Returns:
-        Deferred[list[synapse.events.EventBase]]
+        list[synapse.events.EventBase]
     """
     # Filter out events that have been soft failed so that we don't relay them
     # to clients.
     events = [e for e in events if not e.internal_metadata.is_soft_failed()]

     types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
-    event_id_to_state = yield storage.state.get_state_for_events(
+    event_id_to_state = await storage.state.get_state_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(types),
     )

-    ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
+    ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
         "m.ignored_user_list", user_id
     )

@@ -90,7 +87,7 @@ def filter_events_for_client(
         else []
     )

-    erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
+    erased_senders = await storage.main.are_users_erased((e.sender for e in events))

     if filter_send_to_client:
         room_ids = {e.room_id for e in events}

@@ -99,7 +96,7 @@ def filter_events_for_client(
         for room_id in room_ids:
             retention_policies[
                 room_id
-            ] = yield storage.main.get_retention_policy_for_room(room_id)
+            ] = await storage.main.get_retention_policy_for_room(room_id)

     def allowed(event):
         """

@@ -254,8 +251,7 @@ def filter_events_for_client(
     return list(filtered_events)


-@defer.inlineCallbacks
-def filter_events_for_server(
+async def filter_events_for_server(
     storage: Storage,
     server_name,
     events,

@@ -277,7 +273,7 @@ def filter_events_for_server(
         backfill or not.

     Returns
-        Deferred[list[FrozenEvent]]
+        list[FrozenEvent]
     """

     def is_sender_erased(event, erased_senders):

@@ -321,7 +317,7 @@ def filter_events_for_server(
     # Lets check to see if all the events have a history visibility
     # of "shared" or "world_readable". If that's the case then we don't
     # need to check membership (as we know the server is in the room).
-    event_to_state_ids = yield storage.state.get_state_ids_for_events(
+    event_to_state_ids = await storage.state.get_state_ids_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(
             types=((EventTypes.RoomHistoryVisibility, ""),)

@@ -339,14 +335,14 @@ def filter_events_for_server(
     if not visibility_ids:
         all_open = True
     else:
-        event_map = yield storage.main.get_events(visibility_ids)
+        event_map = await storage.main.get_events(visibility_ids)
         all_open = all(
             e.content.get("history_visibility") in (None, "shared", "world_readable")
             for e in event_map.values()
         )

     if not check_history_visibility_only:
-        erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
+        erased_senders = await storage.main.are_users_erased((e.sender for e in events))
     else:
         # We don't want to check whether users are erased, which is equivalent
         # to no users having been erased.

@@ -375,7 +371,7 @@ def filter_events_for_server(
     # first, for each event we're wanting to return, get the event_ids
     # of the history vis and membership state at those events.
-    event_to_state_ids = yield storage.state.get_state_ids_for_events(
+    event_to_state_ids = await storage.state.get_state_ids_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(
             types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))

@@ -405,7 +401,7 @@ def filter_events_for_server(
             return False
         return state_key[idx + 1 :] == server_name

-    event_map = yield storage.main.get_events(
+    event_map = await storage.main.get_events(
         [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
     )
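
As the docstring updates above indicate, both functions now resolve to plain lists instead of Deferred[list[...]]. Async call sites therefore await them directly; a hypothetical caller (not part of this diff) would look like:

    async def get_visible_events(storage, server_name, events):
        # filter_events_for_server is now a coroutine function: await it
        # and get back a plain list of events, not a Deferred.
        filtered = await filter_events_for_server(storage, server_name, events)
        return filtered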

tests/test_visibility.py

@@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             evt = yield self.inject_room_member(user, extra_content={"a": "b"})
             events_to_filter.append(evt)

-        filtered = yield filter_events_for_server(
-            self.storage, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(self.storage, "test_server", events_to_filter)
         )

         # the result should be 5 redacted events, and 5 unredacted events.

@@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")

         # ... and the filtering happens.
-        filtered = yield filter_events_for_server(
-            self.storage, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(self.storage, "test_server", events_to_filter)
         )

         for i in range(0, len(events_to_filter)):

@@ -265,8 +265,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         storage.main = test_store
         storage.state = test_store

-        filtered = yield filter_events_for_server(
-            test_store, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(test_store, "test_server", events_to_filter)
         )

         logger.info("Filtering took %f seconds", time.time() - start)