Port SyncHandler to async/await

This commit is contained in:
Erik Johnston 2019-12-05 17:58:25 +00:00
parent d085a8a0a5
commit 8437e2383e
6 changed files with 182 additions and 191 deletions

View file

@ -16,8 +16,6 @@
import logging
import random
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
@ -50,9 +48,8 @@ class EventStreamHandler(BaseHandler):
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
@log_function
def get_stream(
async def get_stream(
self,
auth_user_id,
pagin_config,
@ -69,17 +66,17 @@ class EventStreamHandler(BaseHandler):
"""
if room_id:
blocked = yield self.store.is_room_blocked(room_id)
blocked = await self.store.is_room_blocked(room_id)
if blocked:
raise SynapseError(403, "This room has been blocked on this server")
# send any outstanding server notices to the user.
yield self._server_notices_sender.on_user_syncing(auth_user_id)
await self._server_notices_sender.on_user_syncing(auth_user_id)
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler()
context = yield presence_handler.user_syncing(
context = await presence_handler.user_syncing(
auth_user_id, affect_presence=affect_presence
)
with context:
@ -91,7 +88,7 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart.
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for(
events, tokens = await self.notifier.get_events_for(
auth_user,
pagin_config,
timeout,
@ -112,14 +109,14 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.state.get_current_users_in_room(
users = await self.state.get_current_users_in_room(
event.room_id
)
states = yield presence_handler.get_states(users, as_event=True)
states = await presence_handler.get_states(users, as_event=True)
to_add.extend(states)
else:
ev = yield presence_handler.get_state(
ev = await presence_handler.get_state(
UserID.from_string(event.state_key), as_event=True
)
to_add.append(ev)
@ -128,7 +125,7 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec()
chunks = yield self._event_serializer.serialize_events(
chunks = await self._event_serializer.serialize_events(
events,
time_now,
as_client_event=as_client_event,
@ -151,8 +148,7 @@ class EventHandler(BaseHandler):
super(EventHandler, self).__init__(hs)
self.storage = hs.get_storage()
@defer.inlineCallbacks
def get_event(self, user, room_id, event_id):
async def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event.
Args:
@ -167,15 +163,15 @@ class EventHandler(BaseHandler):
AuthError if the user does not have the rights to inspect this
event.
"""
event = yield self.store.get_event(event_id, check_room_id=room_id)
event = await self.store.get_event(event_id, check_room_id=room_id)
if not event:
return None
users = yield self.store.get_users_in_room(event.room_id)
users = await self.store.get_users_in_room(event.room_id)
is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client(
filtered = await filter_events_for_client(
self.storage, user.to_string(), [event], is_peeking=is_peeking
)

View file

@ -22,8 +22,6 @@ from six import iteritems, itervalues
from prometheus_client import Counter
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.logging.context import LoggingContext
from synapse.push.clientformat import format_push_rules_for_user
@ -241,8 +239,7 @@ class SyncHandler(object):
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
@defer.inlineCallbacks
def wait_for_sync_for_user(
async def wait_for_sync_for_user(
self, sync_config, since_token=None, timeout=0, full_state=False
):
"""Get the sync for a client if we have new data for it now. Otherwise
@ -255,9 +252,9 @@ class SyncHandler(object):
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
yield self.auth.check_auth_blocking(user_id)
await self.auth.check_auth_blocking(user_id)
res = yield self.response_cache.wrap(
res = await self.response_cache.wrap(
sync_config.request_key,
self._wait_for_sync_for_user,
sync_config,
@ -267,8 +264,9 @@ class SyncHandler(object):
)
return res
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state):
async def _wait_for_sync_for_user(
self, sync_config, since_token, timeout, full_state
):
if since_token is None:
sync_type = "initial_sync"
elif full_state:
@ -283,7 +281,7 @@ class SyncHandler(object):
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
result = yield self.current_sync_for_user(
result = await self.current_sync_for_user(
sync_config, since_token, full_state=full_state
)
else:
@ -291,7 +289,7 @@ class SyncHandler(object):
def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token)
result = yield self.notifier.wait_for_events(
result = await self.notifier.wait_for_events(
sync_config.user.to_string(),
timeout,
current_sync_callback,
@ -314,15 +312,13 @@ class SyncHandler(object):
"""
return self.generate_sync_result(sync_config, since_token, full_state)
@defer.inlineCallbacks
def push_rules_for_user(self, user):
async def push_rules_for_user(self, user):
user_id = user.to_string()
rules = yield self.store.get_push_rules_for_user(user_id)
rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
return rules
@defer.inlineCallbacks
def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in
Args:
sync_result_builder(SyncResultBuilder)
@ -343,7 +339,7 @@ class SyncHandler(object):
room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events(
typing, typing_key = typing_source.get_new_events(
user=sync_config.user,
from_key=typing_key,
limit=sync_config.filter_collection.ephemeral_limit(),
@ -365,7 +361,7 @@ class SyncHandler(object):
receipt_key = since_token.receipt_key if since_token else "0"
receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = yield receipt_source.get_new_events(
receipts, receipt_key = await receipt_source.get_new_events(
user=sync_config.user,
from_key=receipt_key,
limit=sync_config.filter_collection.ephemeral_limit(),
@ -382,8 +378,7 @@ class SyncHandler(object):
return now_token, ephemeral_by_room
@defer.inlineCallbacks
def _load_filtered_recents(
async def _load_filtered_recents(
self,
room_id,
sync_config,
@ -415,10 +410,10 @@ class SyncHandler(object):
# ensure that we always include current state in the timeline
current_state_ids = frozenset()
if any(e.is_state() for e in recents):
current_state_ids = yield self.state.get_current_state_ids(room_id)
current_state_ids = await self.state.get_current_state_ids(room_id)
current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client(
recents = await filter_events_for_client(
self.storage,
sync_config.user.to_string(),
recents,
@ -449,14 +444,14 @@ class SyncHandler(object):
# Otherwise, we want to return the last N events in the room
# in toplogical ordering.
if since_key:
events, end_key = yield self.store.get_room_events_stream_for_room(
events, end_key = await self.store.get_room_events_stream_for_room(
room_id,
limit=load_limit + 1,
from_key=since_key,
to_key=end_key,
)
else:
events, end_key = yield self.store.get_recent_events_for_room(
events, end_key = await self.store.get_recent_events_for_room(
room_id, limit=load_limit + 1, end_token=end_key
)
loaded_recents = sync_config.filter_collection.filter_room_timeline(
@ -468,10 +463,10 @@ class SyncHandler(object):
# ensure that we always include current state in the timeline
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
current_state_ids = yield self.state.get_current_state_ids(room_id)
current_state_ids = await self.state.get_current_state_ids(room_id)
current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client(
loaded_recents = await filter_events_for_client(
self.storage,
sync_config.user.to_string(),
loaded_recents,
@ -498,8 +493,7 @@ class SyncHandler(object):
limited=limited or newly_joined_room,
)
@defer.inlineCallbacks
def get_state_after_event(self, event, state_filter=StateFilter.all()):
async def get_state_after_event(self, event, state_filter=StateFilter.all()):
"""
Get the room state after the given event
@ -511,7 +505,7 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
state_ids = yield self.state_store.get_state_ids_for_event(
state_ids = await self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter
)
if event.is_state():
@ -519,8 +513,9 @@ class SyncHandler(object):
state_ids[(event.type, event.state_key)] = event.event_id
return state_ids
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
async def get_state_at(
self, room_id, stream_position, state_filter=StateFilter.all()
):
""" Get the room state at a particular stream position
Args:
@ -536,13 +531,13 @@ class SyncHandler(object):
# get_recent_events_for_room operates by topo ordering. This therefore
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room(
last_events, _ = await self.store.get_recent_events_for_room(
room_id, end_token=stream_position.room_key, limit=1
)
if last_events:
last_event = last_events[-1]
state = yield self.get_state_after_event(
state = await self.get_state_after_event(
last_event, state_filter=state_filter
)
@ -551,8 +546,7 @@ class SyncHandler(object):
state = {}
return state
@defer.inlineCallbacks
def compute_summary(self, room_id, sync_config, batch, state, now_token):
async def compute_summary(self, room_id, sync_config, batch, state, now_token):
""" Works out a room summary block for this room, summarising the number
of joined members in the room, and providing the 'hero' members if the
room has no name so clients can consistently name rooms. Also adds
@ -574,7 +568,7 @@ class SyncHandler(object):
# FIXME: we could/should get this from room_stats when matthew/stats lands
# FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
last_events, _ = yield self.store.get_recent_event_ids_for_room(
last_events, _ = await self.store.get_recent_event_ids_for_room(
room_id, end_token=now_token.room_key, limit=1
)
@ -582,7 +576,7 @@ class SyncHandler(object):
return None
last_event = last_events[-1]
state_ids = yield self.state_store.get_state_ids_for_event(
state_ids = await self.state_store.get_state_ids_for_event(
last_event.event_id,
state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@ -590,7 +584,7 @@ class SyncHandler(object):
)
# this is heavily cached, thus: fast.
details = yield self.store.get_room_summary(room_id)
details = await self.store.get_room_summary(room_id)
name_id = state_ids.get((EventTypes.Name, ""))
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
@ -608,12 +602,12 @@ class SyncHandler(object):
# calculating heroes. Empty strings are falsey, so we check
# for the "name" value and default to an empty string.
if name_id:
name = yield self.store.get_event(name_id, allow_none=True)
name = await self.store.get_event(name_id, allow_none=True)
if name and name.content.get("name"):
return summary
if canonical_alias_id:
canonical_alias = yield self.store.get_event(
canonical_alias = await self.store.get_event(
canonical_alias_id, allow_none=True
)
if canonical_alias and canonical_alias.content.get("alias"):
@ -678,7 +672,7 @@ class SyncHandler(object):
)
]
missing_hero_state = yield self.store.get_events(missing_hero_event_ids)
missing_hero_state = await self.store.get_events(missing_hero_event_ids)
missing_hero_state = missing_hero_state.values()
for s in missing_hero_state:
@ -697,8 +691,7 @@ class SyncHandler(object):
logger.debug("found LruCache for %r", cache_key)
return cache
@defer.inlineCallbacks
def compute_state_delta(
async def compute_state_delta(
self, room_id, batch, sync_config, since_token, now_token, full_state
):
""" Works out the difference in state between the start of the timeline
@ -759,16 +752,16 @@ class SyncHandler(object):
if full_state:
if batch:
current_state_ids = yield self.state_store.get_state_ids_for_event(
current_state_ids = await self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
state_ids = yield self.state_store.get_state_ids_for_event(
state_ids = await self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
else:
current_state_ids = yield self.get_state_at(
current_state_ids = await self.get_state_at(
room_id, stream_position=now_token, state_filter=state_filter
)
@ -783,13 +776,13 @@ class SyncHandler(object):
)
elif batch.limited:
if batch:
state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
state_at_timeline_start = await self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
else:
# We can get here if the user has ignored the senders of all
# the recent events.
state_at_timeline_start = yield self.get_state_at(
state_at_timeline_start = await self.get_state_at(
room_id, stream_position=now_token, state_filter=state_filter
)
@ -807,19 +800,19 @@ class SyncHandler(object):
# about them).
state_filter = StateFilter.all()
state_at_previous_sync = yield self.get_state_at(
state_at_previous_sync = await self.get_state_at(
room_id, stream_position=since_token, state_filter=state_filter
)
if batch:
current_state_ids = yield self.state_store.get_state_ids_for_event(
current_state_ids = await self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
else:
# Its not clear how we get here, but empirically we do
# (#5407). Logging has been added elsewhere to try and
# figure out where this state comes from.
current_state_ids = yield self.get_state_at(
current_state_ids = await self.get_state_at(
room_id, stream_position=now_token, state_filter=state_filter
)
@ -843,7 +836,7 @@ class SyncHandler(object):
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
state_ids = yield self.state_store.get_state_ids_for_event(
state_ids = await self.state_store.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
@ -883,7 +876,7 @@ class SyncHandler(object):
state = {}
if state_ids:
state = yield self.store.get_events(list(state_ids.values()))
state = await self.store.get_events(list(state_ids.values()))
return {
(e.type, e.state_key): e
@ -892,10 +885,9 @@ class SyncHandler(object):
)
}
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config):
async def unread_notifs_for_room_id(self, room_id, sync_config):
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
receipt_type="m.read",
@ -903,7 +895,7 @@ class SyncHandler(object):
notifs = []
if last_unread_event_id:
notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
return notifs
@ -912,8 +904,9 @@ class SyncHandler(object):
# count is whatever it was last time.
return None
@defer.inlineCallbacks
def generate_sync_result(self, sync_config, since_token=None, full_state=False):
async def generate_sync_result(
self, sync_config, since_token=None, full_state=False
):
"""Generates a sync result.
Args:
@ -928,7 +921,7 @@ class SyncHandler(object):
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
# Always use the `now_token` in `SyncResultBuilder`
now_token = yield self.event_sources.get_current_token()
now_token = await self.event_sources.get_current_token()
logger.info(
"Calculating sync response for %r between %s and %s",
@ -944,10 +937,9 @@ class SyncHandler(object):
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
joined_room_ids = yield self.get_rooms_for_user_at(
joined_room_ids = await self.get_rooms_for_user_at(
user_id, now_token.room_stream_id
)
sync_result_builder = SyncResultBuilder(
sync_config,
full_state,
@ -956,11 +948,11 @@ class SyncHandler(object):
joined_room_ids=joined_room_ids,
)
account_data_by_room = yield self._generate_sync_entry_for_account_data(
account_data_by_room = await self._generate_sync_entry_for_account_data(
sync_result_builder
)
res = yield self._generate_sync_entry_for_rooms(
res = await self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
@ -970,13 +962,13 @@ class SyncHandler(object):
since_token is None and sync_config.filter_collection.blocks_all_presence()
)
if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence(
await self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
)
yield self._generate_sync_entry_for_to_device(sync_result_builder)
await self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = yield self._generate_sync_entry_for_device_list(
device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_users=newly_joined_or_invited_users,
@ -987,11 +979,11 @@ class SyncHandler(object):
device_id = sync_config.device_id
one_time_key_counts = {}
if device_id:
one_time_key_counts = yield self.store.count_e2e_one_time_keys(
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
yield self._generate_sync_entry_for_groups(sync_result_builder)
await self._generate_sync_entry_for_groups(sync_result_builder)
# debug for https://github.com/matrix-org/synapse/issues/4422
for joined_room in sync_result_builder.joined:
@ -1015,18 +1007,17 @@ class SyncHandler(object):
)
@measure_func("_generate_sync_entry_for_groups")
@defer.inlineCallbacks
def _generate_sync_entry_for_groups(self, sync_result_builder):
async def _generate_sync_entry_for_groups(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
if since_token and since_token.groups_key:
results = yield self.store.get_groups_changes_for_user(
results = self.store.get_groups_changes_for_user(
user_id, since_token.groups_key, now_token.groups_key
)
else:
results = yield self.store.get_all_groups_for_user(
results = await self.store.get_all_groups_for_user(
user_id, now_token.groups_key
)
@ -1059,8 +1050,7 @@ class SyncHandler(object):
)
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(
async def _generate_sync_entry_for_device_list(
self,
sync_result_builder,
newly_joined_rooms,
@ -1108,32 +1098,32 @@ class SyncHandler(object):
# room with by looking at all users that have left a room plus users
# that were in a room we've left.
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
# Step 1a, check for changes in devices of users we share a room with
users_that_have_changed = yield self.store.get_users_whose_devices_changed(
users_that_have_changed = await self.store.get_users_whose_devices_changed(
since_token.device_list_key, users_who_share_room
)
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
joined_users = yield self.state.get_current_users_in_room(room_id)
joined_users = await self.state.get_current_users_in_room(room_id)
newly_joined_or_invited_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
users_that_have_changed.update(newly_joined_or_invited_users)
user_signatures_changed = yield self.store.get_users_whose_signatures_changed(
user_signatures_changed = await self.store.get_users_whose_signatures_changed(
user_id, since_token.device_list_key
)
users_that_have_changed.update(user_signatures_changed)
# Now find users that we no longer track
for room_id in newly_left_rooms:
left_users = yield self.state.get_current_users_in_room(room_id)
left_users = await self.state.get_current_users_in_room(room_id)
newly_left_users.update(left_users)
# Remove any users that we still share a room with.
@ -1143,8 +1133,7 @@ class SyncHandler(object):
else:
return DeviceLists(changed=[], left=[])
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
async def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates
`sync_result_builder` with the result.
@ -1165,14 +1154,14 @@ class SyncHandler(object):
# We only delete messages when a new message comes in, but that's
# fine so long as we delete them at some point.
deleted = yield self.store.delete_messages_for_device(
deleted = await self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
logger.debug(
"Deleted %d to-device messages up to %d", deleted, since_stream_id
)
messages, stream_id = yield self.store.get_new_messages_for_device(
messages, stream_id = await self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
)
@ -1190,8 +1179,7 @@ class SyncHandler(object):
else:
sync_result_builder.to_device = []
@defer.inlineCallbacks
def _generate_sync_entry_for_account_data(self, sync_result_builder):
async def _generate_sync_entry_for_account_data(self, sync_result_builder):
"""Generates the account data portion of the sync response. Populates
`sync_result_builder` with the result.
@ -1209,25 +1197,25 @@ class SyncHandler(object):
(
account_data,
account_data_by_room,
) = yield self.store.get_updated_account_data_for_user(
) = self.store.get_updated_account_data_for_user(
user_id, since_token.account_data_key
)
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
push_rules_changed = await self.store.have_push_rules_changed_for_user(
user_id, int(since_token.push_rules_key)
)
if push_rules_changed:
account_data["m.push_rules"] = yield self.push_rules_for_user(
account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
else:
(
account_data,
account_data_by_room,
) = yield self.store.get_account_data_for_user(sync_config.user.to_string())
) = await self.store.get_account_data_for_user(sync_config.user.to_string())
account_data["m.push_rules"] = yield self.push_rules_for_user(
account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
@ -1242,8 +1230,7 @@ class SyncHandler(object):
return account_data_by_room
@defer.inlineCallbacks
def _generate_sync_entry_for_presence(
async def _generate_sync_entry_for_presence(
self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
):
"""Generates the presence portion of the sync response. Populates the
@ -1271,7 +1258,7 @@ class SyncHandler(object):
presence_key = None
include_offline = False
presence, presence_key = yield presence_source.get_new_events(
presence, presence_key = await presence_source.get_new_events(
user=user,
from_key=presence_key,
is_guest=sync_config.is_guest,
@ -1283,12 +1270,12 @@ class SyncHandler(object):
extra_users_ids = set(newly_joined_or_invited_users)
for room_id in newly_joined_rooms:
users = yield self.state.get_current_users_in_room(room_id)
users = await self.state.get_current_users_in_room(room_id)
extra_users_ids.update(users)
extra_users_ids.discard(user.to_string())
if extra_users_ids:
states = yield self.presence_handler.get_states(extra_users_ids)
states = await self.presence_handler.get_states(extra_users_ids)
presence.extend(states)
# Deduplicate the presence entries so that there's at most one per user
@ -1298,8 +1285,9 @@ class SyncHandler(object):
sync_result_builder.presence = presence
@defer.inlineCallbacks
def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room):
async def _generate_sync_entry_for_rooms(
self, sync_result_builder, account_data_by_room
):
"""Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result.
@ -1321,7 +1309,7 @@ class SyncHandler(object):
if block_all_room_ephemeral:
ephemeral_by_room = {}
else:
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
now_token, ephemeral_by_room = await self.ephemeral_by_room(
sync_result_builder,
now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token,
@ -1333,16 +1321,16 @@ class SyncHandler(object):
since_token = sync_result_builder.since_token
if not sync_result_builder.full_state:
if since_token and not ephemeral_by_room and not account_data_by_room:
have_changed = yield self._have_rooms_changed(sync_result_builder)
have_changed = await self._have_rooms_changed(sync_result_builder)
if not have_changed:
tags_by_room = yield self.store.get_updated_tags(
tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key
)
if not tags_by_room:
logger.debug("no-oping sync")
return [], [], [], []
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id
)
@ -1352,18 +1340,18 @@ class SyncHandler(object):
ignored_users = frozenset()
if since_token:
res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
res = await self._get_rooms_changed(sync_result_builder, ignored_users)
room_entries, invited, newly_joined_rooms, newly_left_rooms = res
tags_by_room = yield self.store.get_updated_tags(
tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key
)
else:
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
res = await self._get_all_rooms(sync_result_builder, ignored_users)
room_entries, invited, newly_joined_rooms = res
newly_left_rooms = []
tags_by_room = yield self.store.get_tags_for_user(user_id)
tags_by_room = await self.store.get_tags_for_user(user_id)
def handle_room_entries(room_entry):
return self._generate_room_entry(
@ -1376,7 +1364,7 @@ class SyncHandler(object):
always_include=sync_result_builder.full_state,
)
yield concurrently_execute(handle_room_entries, room_entries, 10)
await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited)
@ -1410,8 +1398,7 @@ class SyncHandler(object):
newly_left_users,
)
@defer.inlineCallbacks
def _have_rooms_changed(self, sync_result_builder):
async def _have_rooms_changed(self, sync_result_builder):
"""Returns whether there may be any new events that should be sent down
the sync. Returns True if there are.
"""
@ -1422,7 +1409,7 @@ class SyncHandler(object):
assert since_token
# Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user(
rooms_changed = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
@ -1435,8 +1422,7 @@ class SyncHandler(object):
return True
return False
@defer.inlineCallbacks
def _get_rooms_changed(self, sync_result_builder, ignored_users):
async def _get_rooms_changed(self, sync_result_builder, ignored_users):
"""Gets the the changes that have happened since the last sync.
Args:
@ -1461,7 +1447,7 @@ class SyncHandler(object):
assert since_token
# Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user(
rooms_changed = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
@ -1499,11 +1485,11 @@ class SyncHandler(object):
continue
if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token)
old_state_ids = await self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
if old_mem_ev_id:
old_mem_ev = yield self.store.get_event(
old_mem_ev = await self.store.get_event(
old_mem_ev_id, allow_none=True
)
@ -1536,13 +1522,13 @@ class SyncHandler(object):
newly_left_rooms.append(room_id)
else:
if not old_state_ids:
old_state_ids = yield self.get_state_at(room_id, since_token)
old_state_ids = await self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get(
(EventTypes.Member, user_id), None
)
old_mem_ev = None
if old_mem_ev_id:
old_mem_ev = yield self.store.get_event(
old_mem_ev = await self.store.get_event(
old_mem_ev_id, allow_none=True
)
if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
@ -1566,7 +1552,7 @@ class SyncHandler(object):
if leave_events:
leave_event = leave_events[-1]
leave_stream_token = yield self.store.get_stream_token_for_event(
leave_stream_token = await self.store.get_stream_token_for_event(
leave_event.event_id
)
leave_token = since_token.copy_and_replace(
@ -1603,7 +1589,7 @@ class SyncHandler(object):
timeline_limit = sync_config.filter_collection.timeline_limit()
# Get all events for rooms we're currently joined to.
room_to_events = yield self.store.get_room_events_stream_for_rooms(
room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key,
to_key=now_token.room_key,
@ -1652,8 +1638,7 @@ class SyncHandler(object):
return room_entries, invited, newly_joined_rooms, newly_left_rooms
@defer.inlineCallbacks
def _get_all_rooms(self, sync_result_builder, ignored_users):
async def _get_all_rooms(self, sync_result_builder, ignored_users):
"""Returns entries for all rooms for the user.
Args:
@ -1677,7 +1662,7 @@ class SyncHandler(object):
Membership.BAN,
)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
room_list = await self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=membership_list
)
@ -1700,7 +1685,7 @@ class SyncHandler(object):
elif event.membership == Membership.INVITE:
if event.sender in ignored_users:
continue
invite = yield self.store.get_event(event.event_id)
invite = await self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
@ -1726,8 +1711,7 @@ class SyncHandler(object):
return room_entries, invited, []
@defer.inlineCallbacks
def _generate_room_entry(
async def _generate_room_entry(
self,
sync_result_builder,
ignored_users,
@ -1769,7 +1753,7 @@ class SyncHandler(object):
since_token = room_builder.since_token
upto_token = room_builder.upto_token
batch = yield self._load_filtered_recents(
batch = await self._load_filtered_recents(
room_id,
sync_config,
now_token=upto_token,
@ -1796,7 +1780,7 @@ class SyncHandler(object):
# tag was added by synapse e.g. for server notice rooms.
if full_state:
user_id = sync_result_builder.sync_config.user.to_string()
tags = yield self.store.get_tags_for_room(user_id, room_id)
tags = await self.store.get_tags_for_room(user_id, room_id)
# If there aren't any tags, don't send the empty tags list down
# sync
@ -1821,7 +1805,7 @@ class SyncHandler(object):
):
return
state = yield self.compute_state_delta(
state = await self.compute_state_delta(
room_id, batch, sync_config, since_token, now_token, full_state=full_state
)
@ -1844,7 +1828,7 @@ class SyncHandler(object):
)
or since_token is None
):
summary = yield self.compute_summary(
summary = await self.compute_summary(
room_id, sync_config, batch, state, now_token
)
@ -1861,7 +1845,7 @@ class SyncHandler(object):
)
if room_sync or always_include:
notifs = yield self.unread_notifs_for_room_id(room_id, sync_config)
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
if notifs is not None:
unread_notifications["notification_count"] = notifs["notify_count"]
@ -1887,8 +1871,7 @@ class SyncHandler(object):
else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
@defer.inlineCallbacks
def get_rooms_for_user_at(self, user_id, stream_ordering):
async def get_rooms_for_user_at(self, user_id, stream_ordering):
"""Get set of joined rooms for a user at the given stream ordering.
The stream ordering *must* be recent, otherwise this may throw an
@ -1903,7 +1886,7 @@ class SyncHandler(object):
Deferred[frozenset[str]]: Set of room_ids the user is in at given
stream_ordering.
"""
joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id)
joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id)
joined_room_ids = set()
@ -1921,10 +1904,10 @@ class SyncHandler(object):
logger.info("User joined room after current token: %s", room_id)
extrems = yield self.store.get_forward_extremeties_for_room(
extrems = await self.store.get_forward_extremeties_for_room(
room_id, stream_ordering
)
users_in_room = yield self.state.get_current_users_in_room(room_id, extrems)
users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
if user_id in users_in_room:
joined_room_ids.add(room_id)

View file

@ -304,8 +304,7 @@ class Notifier(object):
without waking up any of the normal user event streams"""
self.notify_replication()
@defer.inlineCallbacks
def wait_for_events(
async def wait_for_events(
self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
):
"""Wait until the callback returns a non empty response or the
@ -313,9 +312,9 @@ class Notifier(object):
"""
user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None:
current_token = yield self.event_sources.get_current_token()
current_token = await self.event_sources.get_current_token()
if room_ids is None:
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
user_stream = _NotifierUserStream(
user_id=user_id,
rooms=room_ids,
@ -344,11 +343,11 @@ class Notifier(object):
self.hs.get_reactor(),
)
with PreserveLoggingContext():
yield listener.deferred
await listener.deferred
current_token = user_stream.current_token
result = yield callback(prev_token, current_token)
result = await callback(prev_token, current_token)
if result:
break
@ -364,12 +363,11 @@ class Notifier(object):
# This happened if there was no timeout or if the timeout had
# already expired.
current_token = user_stream.current_token
result = yield callback(prev_token, current_token)
result = await callback(prev_token, current_token)
return result
@defer.inlineCallbacks
def get_events_for(
async def get_events_for(
self,
user,
pagination_config,
@ -391,15 +389,14 @@ class Notifier(object):
"""
from_token = pagination_config.from_token
if not from_token:
from_token = yield self.event_sources.get_current_token()
from_token = await self.event_sources.get_current_token()
limit = pagination_config.limit
room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
is_peeking = not is_joined
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
async def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
return EventStreamResult([], (from_token, from_token))
@ -415,7 +412,7 @@ class Notifier(object):
if only_keys and name not in only_keys:
continue
new_events, new_key = yield source.get_new_events(
new_events, new_key = await source.get_new_events(
user=user,
from_key=getattr(from_token, keyname),
limit=limit,
@ -425,7 +422,7 @@ class Notifier(object):
)
if name == "room":
new_events = yield filter_events_for_client(
new_events = await filter_events_for_client(
self.storage,
user.to_string(),
new_events,
@ -461,7 +458,7 @@ class Notifier(object):
user_id_for_stream,
)
result = yield self.wait_for_events(
result = await self.wait_for_events(
user_id_for_stream,
timeout,
check_for_updates,

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from functools import wraps
@ -64,12 +65,22 @@ def measure_func(name=None):
def wrapper(func):
block_name = func.__name__ if name is None else name
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
if inspect.iscoroutinefunction(func):
@wraps(func)
async def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r
else:
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
return measured_func

View file

@ -12,54 +12,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig, SyncHandler
from synapse.handlers.sync import SyncConfig
from synapse.types import UserID
import tests.unittest
import tests.utils
from tests.utils import setup_test_homeserver
class SyncTestCase(tests.unittest.TestCase):
class SyncTestCase(tests.unittest.HomeserverTestCase):
""" Tests Sync Handler. """
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
self.sync_handler = SyncHandler(self.hs)
def prepare(self, reactor, clock, hs):
self.hs = hs
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
# Check that the happy case does not throw errors
yield self.store.upsert_monthly_active_user(user_id1)
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works
self.hs.config.hs_disabled = True
with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.hs.config.hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(

View file

@ -18,6 +18,7 @@
import gc
import hashlib
import hmac
import inspect
import logging
import time
@ -25,7 +26,7 @@ from mock import Mock
from canonicaljson import json
from twisted.internet.defer import Deferred, succeed
from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
@ -415,6 +416,8 @@ class HomeserverTestCase(TestCase):
self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0):
if inspect.isawaitable(d):
d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump(by=by)
@ -424,6 +427,8 @@ class HomeserverTestCase(TestCase):
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
if inspect.isawaitable(d):
d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump()