Port PresenceHandler to async/await (#6991)

This commit is contained in:
Erik Johnston 2020-02-26 15:33:26 +00:00 committed by GitHub
parent 7728d87fd7
commit 1f773eec91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 113 additions and 115 deletions

1
changelog.d/6991.misc Normal file
View file

@ -0,0 +1 @@
Port `synapse.handlers.presence` to async/await.

View file

@ -1016,11 +1016,10 @@ class EventCreationHandler(object):
# matters as sometimes presence code can take a while. # matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user) run_in_background(self._bump_active_time, requester.user)
@defer.inlineCallbacks async def _bump_active_time(self, user):
def _bump_active_time(self, user):
try: try:
presence = self.hs.get_presence_handler() presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user) await presence.bump_presence_active_time(user)
except Exception: except Exception:
logger.exception("Error bumping presence active time") logger.exception("Error bumping presence active time")

View file

@ -24,11 +24,12 @@ The methods that define policy are:
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, Set from typing import Dict, List, Set
from six import iteritems, itervalues from six import iteritems, itervalues
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import ContextManager
from twisted.internet import defer from twisted.internet import defer
@ -42,10 +43,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer from synapse.util.wheel_timer import WheelTimer
MYPY = False
if MYPY:
import synapse.server
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -97,7 +102,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object): class PresenceHandler(object):
def __init__(self, hs: "synapse.server.HomeServer"): def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs self.hs = hs
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname self.server_name = hs.hostname
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -150,7 +154,7 @@ class PresenceHandler(object):
# Set of users who have presence in the `user_to_current_state` that # Set of users who have presence in the `user_to_current_state` that
# have not yet been persisted # have not yet been persisted
self.unpersisted_users_changes = set() self.unpersisted_users_changes = set() # type: Set[str]
hs.get_reactor().addSystemEventTrigger( hs.get_reactor().addSystemEventTrigger(
"before", "before",
@ -160,12 +164,11 @@ class PresenceHandler(object):
self._on_shutdown, self._on_shutdown,
) )
self.serial_to_user = {}
self._next_serial = 1 self._next_serial = 1
# Keeps track of the number of *ongoing* syncs on this process. While # Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline. # this is non zero a user will never go offline.
self.user_to_num_current_syncs = {} self.user_to_num_current_syncs = {} # type: Dict[str, int]
# Keeps track of the number of *ongoing* syncs on other processes. # Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never # While any sync is ongoing on another process the user will never
@ -213,8 +216,7 @@ class PresenceHandler(object):
self._event_pos = self.store.get_current_events_token() self._event_pos = self.store.get_current_events_token()
self._event_processing = False self._event_processing = False
@defer.inlineCallbacks async def _on_shutdown(self):
def _on_shutdown(self):
"""Gets called when shutting down. This lets us persist any updates that """Gets called when shutting down. This lets us persist any updates that
we haven't yet persisted, e.g. updates that only changes some internal we haven't yet persisted, e.g. updates that only changes some internal
timers. This allows changes to persist across startup without having to timers. This allows changes to persist across startup without having to
@ -235,7 +237,7 @@ class PresenceHandler(object):
if self.unpersisted_users_changes: if self.unpersisted_users_changes:
yield self.store.update_presence( await self.store.update_presence(
[ [
self.user_to_current_state[user_id] self.user_to_current_state[user_id]
for user_id in self.unpersisted_users_changes for user_id in self.unpersisted_users_changes
@ -243,8 +245,7 @@ class PresenceHandler(object):
) )
logger.info("Finished _on_shutdown") logger.info("Finished _on_shutdown")
@defer.inlineCallbacks async def _persist_unpersisted_changes(self):
def _persist_unpersisted_changes(self):
"""We periodically persist the unpersisted changes, as otherwise they """We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times. may stack up and slow down shutdown times.
""" """
@ -253,12 +254,11 @@ class PresenceHandler(object):
if unpersisted: if unpersisted:
logger.info("Persisting %d unpersisted presence updates", len(unpersisted)) logger.info("Persisting %d unpersisted presence updates", len(unpersisted))
yield self.store.update_presence( await self.store.update_presence(
[self.user_to_current_state[user_id] for user_id in unpersisted] [self.user_to_current_state[user_id] for user_id in unpersisted]
) )
@defer.inlineCallbacks async def _update_states(self, new_states):
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes """Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state the notifier and federation if and only if the changed presence state
should be sent to clients/servers. should be sent to clients/servers.
@ -267,7 +267,7 @@ class PresenceHandler(object):
with Measure(self.clock, "presence_update_states"): with Measure(self.clock, "presence_update_states"):
# NOTE: We purposefully don't yield between now and when we've # NOTE: We purposefully don't await between now and when we've
# calculated what we want to do with the new states, to avoid races. # calculated what we want to do with the new states, to avoid races.
to_notify = {} # Changes we want to notify everyone about to_notify = {} # Changes we want to notify everyone about
@ -311,7 +311,7 @@ class PresenceHandler(object):
if to_notify: if to_notify:
notified_presence_counter.inc(len(to_notify)) notified_presence_counter.inc(len(to_notify))
yield self._persist_and_notify(list(to_notify.values())) await self._persist_and_notify(list(to_notify.values()))
self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes |= {s.user_id for s in new_states}
self.unpersisted_users_changes -= set(to_notify.keys()) self.unpersisted_users_changes -= set(to_notify.keys())
@ -326,7 +326,7 @@ class PresenceHandler(object):
self._push_to_remotes(to_federation_ping.values()) self._push_to_remotes(to_federation_ping.values())
def _handle_timeouts(self): async def _handle_timeouts(self):
"""Checks the presence of users that have timed out and updates as """Checks the presence of users that have timed out and updates as
appropriate. appropriate.
""" """
@ -368,10 +368,9 @@ class PresenceHandler(object):
now=now, now=now,
) )
return self._update_states(changes) return await self._update_states(changes)
@defer.inlineCallbacks async def bump_presence_active_time(self, user):
def bump_presence_active_time(self, user):
"""We've seen the user do something that indicates they're interacting """We've seen the user do something that indicates they're interacting
with the app. with the app.
""" """
@ -383,16 +382,17 @@ class PresenceHandler(object):
bump_active_time_counter.inc() bump_active_time_counter.inc()
prev_state = yield self.current_state_for_user(user_id) prev_state = await self.current_state_for_user(user_id)
new_fields = {"last_active_ts": self.clock.time_msec()} new_fields = {"last_active_ts": self.clock.time_msec()}
if prev_state.state == PresenceState.UNAVAILABLE: if prev_state.state == PresenceState.UNAVAILABLE:
new_fields["state"] = PresenceState.ONLINE new_fields["state"] = PresenceState.ONLINE
yield self._update_states([prev_state.copy_and_replace(**new_fields)]) await self._update_states([prev_state.copy_and_replace(**new_fields)])
@defer.inlineCallbacks async def user_syncing(
def user_syncing(self, user_id, affect_presence=True): self, user_id: str, affect_presence: bool = True
) -> ContextManager[None]:
"""Returns a context manager that should surround any stream requests """Returns a context manager that should surround any stream requests
from the user. from the user.
@ -415,11 +415,11 @@ class PresenceHandler(object):
curr_sync = self.user_to_num_current_syncs.get(user_id, 0) curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1 self.user_to_num_current_syncs[user_id] = curr_sync + 1
prev_state = yield self.current_state_for_user(user_id) prev_state = await self.current_state_for_user(user_id)
if prev_state.state == PresenceState.OFFLINE: if prev_state.state == PresenceState.OFFLINE:
# If they're currently offline then bring them online, otherwise # If they're currently offline then bring them online, otherwise
# just update the last sync times. # just update the last sync times.
yield self._update_states( await self._update_states(
[ [
prev_state.copy_and_replace( prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE,
@ -429,7 +429,7 @@ class PresenceHandler(object):
] ]
) )
else: else:
yield self._update_states( await self._update_states(
[ [
prev_state.copy_and_replace( prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec() last_user_sync_ts=self.clock.time_msec()
@ -437,13 +437,12 @@ class PresenceHandler(object):
] ]
) )
@defer.inlineCallbacks async def _end():
def _end():
try: try:
self.user_to_num_current_syncs[user_id] -= 1 self.user_to_num_current_syncs[user_id] -= 1
prev_state = yield self.current_state_for_user(user_id) prev_state = await self.current_state_for_user(user_id)
yield self._update_states( await self._update_states(
[ [
prev_state.copy_and_replace( prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec() last_user_sync_ts=self.clock.time_msec()
@ -480,8 +479,7 @@ class PresenceHandler(object):
else: else:
return set() return set()
@defer.inlineCallbacks async def update_external_syncs_row(
def update_external_syncs_row(
self, process_id, user_id, is_syncing, sync_time_msec self, process_id, user_id, is_syncing, sync_time_msec
): ):
"""Update the syncing users for an external process as a delta. """Update the syncing users for an external process as a delta.
@ -494,8 +492,8 @@ class PresenceHandler(object):
is_syncing (bool): Whether or not the user is now syncing is_syncing (bool): Whether or not the user is now syncing
sync_time_msec(int): Time in ms when the user was last syncing sync_time_msec(int): Time in ms when the user was last syncing
""" """
with (yield self.external_sync_linearizer.queue(process_id)): with (await self.external_sync_linearizer.queue(process_id)):
prev_state = yield self.current_state_for_user(user_id) prev_state = await self.current_state_for_user(user_id)
process_presence = self.external_process_to_current_syncs.setdefault( process_presence = self.external_process_to_current_syncs.setdefault(
process_id, set() process_id, set()
@ -525,25 +523,24 @@ class PresenceHandler(object):
process_presence.discard(user_id) process_presence.discard(user_id)
if updates: if updates:
yield self._update_states(updates) await self._update_states(updates)
self.external_process_last_updated_ms[process_id] = self.clock.time_msec() self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
@defer.inlineCallbacks async def update_external_syncs_clear(self, process_id):
def update_external_syncs_clear(self, process_id):
"""Marks all users that had been marked as syncing by a given process """Marks all users that had been marked as syncing by a given process
as offline. as offline.
Used when the process has stopped/disappeared. Used when the process has stopped/disappeared.
""" """
with (yield self.external_sync_linearizer.queue(process_id)): with (await self.external_sync_linearizer.queue(process_id)):
process_presence = self.external_process_to_current_syncs.pop( process_presence = self.external_process_to_current_syncs.pop(
process_id, set() process_id, set()
) )
prev_states = yield self.current_state_for_users(process_presence) prev_states = await self.current_state_for_users(process_presence)
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
yield self._update_states( await self._update_states(
[ [
prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
for prev_state in itervalues(prev_states) for prev_state in itervalues(prev_states)
@ -551,15 +548,13 @@ class PresenceHandler(object):
) )
self.external_process_last_updated_ms.pop(process_id, None) self.external_process_last_updated_ms.pop(process_id, None)
@defer.inlineCallbacks async def current_state_for_user(self, user_id):
def current_state_for_user(self, user_id):
"""Get the current presence state for a user. """Get the current presence state for a user.
""" """
res = yield self.current_state_for_users([user_id]) res = await self.current_state_for_users([user_id])
return res[user_id] return res[user_id]
@defer.inlineCallbacks async def current_state_for_users(self, user_ids):
def current_state_for_users(self, user_ids):
"""Get the current presence state for multiple users. """Get the current presence state for multiple users.
Returns: Returns:
@ -574,7 +569,7 @@ class PresenceHandler(object):
if missing: if missing:
# There are things not in our in memory cache. Lets pull them out of # There are things not in our in memory cache. Lets pull them out of
# the database. # the database.
res = yield self.store.get_presence_for_users(missing) res = await self.store.get_presence_for_users(missing)
states.update(res) states.update(res)
missing = [user_id for user_id, state in iteritems(states) if not state] missing = [user_id for user_id, state in iteritems(states) if not state]
@ -587,14 +582,13 @@ class PresenceHandler(object):
return states return states
@defer.inlineCallbacks async def _persist_and_notify(self, states):
def _persist_and_notify(self, states):
"""Persist states in the database, poke the notifier and send to """Persist states in the database, poke the notifier and send to
interested remote servers interested remote servers
""" """
stream_id, max_token = yield self.store.update_presence(states) stream_id, max_token = await self.store.update_presence(states)
parties = yield get_interested_parties(self.store, states) parties = await get_interested_parties(self.store, states)
room_ids_to_states, users_to_states = parties room_ids_to_states, users_to_states = parties
self.notifier.on_new_event( self.notifier.on_new_event(
@ -606,9 +600,8 @@ class PresenceHandler(object):
self._push_to_remotes(states) self._push_to_remotes(states)
@defer.inlineCallbacks async def notify_for_states(self, state, stream_id):
def notify_for_states(self, state, stream_id): parties = await get_interested_parties(self.store, [state])
parties = yield get_interested_parties(self.store, [state])
room_ids_to_states, users_to_states = parties room_ids_to_states, users_to_states = parties
self.notifier.on_new_event( self.notifier.on_new_event(
@ -626,8 +619,7 @@ class PresenceHandler(object):
""" """
self.federation.send_presence(states) self.federation.send_presence(states)
@defer.inlineCallbacks async def incoming_presence(self, origin, content):
def incoming_presence(self, origin, content):
"""Called when we receive a `m.presence` EDU from a remote server. """Called when we receive a `m.presence` EDU from a remote server.
""" """
now = self.clock.time_msec() now = self.clock.time_msec()
@ -670,21 +662,19 @@ class PresenceHandler(object):
new_fields["status_msg"] = push.get("status_msg", None) new_fields["status_msg"] = push.get("status_msg", None)
new_fields["currently_active"] = push.get("currently_active", False) new_fields["currently_active"] = push.get("currently_active", False)
prev_state = yield self.current_state_for_user(user_id) prev_state = await self.current_state_for_user(user_id)
updates.append(prev_state.copy_and_replace(**new_fields)) updates.append(prev_state.copy_and_replace(**new_fields))
if updates: if updates:
federation_presence_counter.inc(len(updates)) federation_presence_counter.inc(len(updates))
yield self._update_states(updates) await self._update_states(updates)
@defer.inlineCallbacks async def get_state(self, target_user, as_event=False):
def get_state(self, target_user, as_event=False): results = await self.get_states([target_user.to_string()], as_event=as_event)
results = yield self.get_states([target_user.to_string()], as_event=as_event)
return results[0] return results[0]
@defer.inlineCallbacks async def get_states(self, target_user_ids, as_event=False):
def get_states(self, target_user_ids, as_event=False):
"""Get the presence state for users. """Get the presence state for users.
Args: Args:
@ -695,7 +685,7 @@ class PresenceHandler(object):
list list
""" """
updates = yield self.current_state_for_users(target_user_ids) updates = await self.current_state_for_users(target_user_ids)
updates = list(updates.values()) updates = list(updates.values())
for user_id in set(target_user_ids) - {u.user_id for u in updates}: for user_id in set(target_user_ids) - {u.user_id for u in updates}:
@ -713,8 +703,7 @@ class PresenceHandler(object):
else: else:
return updates return updates
@defer.inlineCallbacks async def set_state(self, target_user, state, ignore_status_msg=False):
def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user. """Set the presence state of the user.
""" """
status_msg = state.get("status_msg", None) status_msg = state.get("status_msg", None)
@ -730,7 +719,7 @@ class PresenceHandler(object):
user_id = target_user.to_string() user_id = target_user.to_string()
prev_state = yield self.current_state_for_user(user_id) prev_state = await self.current_state_for_user(user_id)
new_fields = {"state": presence} new_fields = {"state": presence}
@ -741,16 +730,15 @@ class PresenceHandler(object):
if presence == PresenceState.ONLINE: if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec() new_fields["last_active_ts"] = self.clock.time_msec()
yield self._update_states([prev_state.copy_and_replace(**new_fields)]) await self._update_states([prev_state.copy_and_replace(**new_fields)])
@defer.inlineCallbacks async def is_visible(self, observed_user, observer_user):
def is_visible(self, observed_user, observer_user):
"""Returns whether a user can see another user's presence. """Returns whether a user can see another user's presence.
""" """
observer_room_ids = yield self.store.get_rooms_for_user( observer_room_ids = await self.store.get_rooms_for_user(
observer_user.to_string() observer_user.to_string()
) )
observed_room_ids = yield self.store.get_rooms_for_user( observed_room_ids = await self.store.get_rooms_for_user(
observed_user.to_string() observed_user.to_string()
) )
@ -759,8 +747,7 @@ class PresenceHandler(object):
return False return False
@defer.inlineCallbacks async def get_all_presence_updates(self, last_id, current_id):
def get_all_presence_updates(self, last_id, current_id):
""" """
Gets a list of presence update rows from between the given stream ids. Gets a list of presence update rows from between the given stream ids.
Each row has: Each row has:
@ -775,7 +762,7 @@ class PresenceHandler(object):
""" """
# TODO(markjh): replicate the unpersisted changes. # TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes. # This could use the in-memory stores for recent changes.
rows = yield self.store.get_all_presence_updates(last_id, current_id) rows = await self.store.get_all_presence_updates(last_id, current_id)
return rows return rows
def notify_new_event(self): def notify_new_event(self):
@ -786,20 +773,18 @@ class PresenceHandler(object):
if self._event_processing: if self._event_processing:
return return
@defer.inlineCallbacks async def _process_presence():
def _process_presence():
assert not self._event_processing assert not self._event_processing
self._event_processing = True self._event_processing = True
try: try:
yield self._unsafe_process() await self._unsafe_process()
finally: finally:
self._event_processing = False self._event_processing = False
run_as_background_process("presence.notify_new_event", _process_presence) run_as_background_process("presence.notify_new_event", _process_presence)
@defer.inlineCallbacks async def _unsafe_process(self):
def _unsafe_process(self):
# Loop round handling deltas until we're up to date # Loop round handling deltas until we're up to date
while True: while True:
with Measure(self.clock, "presence_delta"): with Measure(self.clock, "presence_delta"):
@ -812,10 +797,10 @@ class PresenceHandler(object):
self._event_pos, self._event_pos,
room_max_stream_ordering, room_max_stream_ordering,
) )
max_pos, deltas = yield self.store.get_current_state_deltas( max_pos, deltas = await self.store.get_current_state_deltas(
self._event_pos, room_max_stream_ordering self._event_pos, room_max_stream_ordering
) )
yield self._handle_state_delta(deltas) await self._handle_state_delta(deltas)
self._event_pos = max_pos self._event_pos = max_pos
@ -824,8 +809,7 @@ class PresenceHandler(object):
max_pos max_pos
) )
@defer.inlineCallbacks async def _handle_state_delta(self, deltas):
def _handle_state_delta(self, deltas):
"""Process current state deltas to find new joins that need to be """Process current state deltas to find new joins that need to be
handled. handled.
""" """
@ -846,13 +830,13 @@ class PresenceHandler(object):
# joins. # joins.
continue continue
event = yield self.store.get_event(event_id, allow_none=True) event = await self.store.get_event(event_id, allow_none=True)
if not event or event.content.get("membership") != Membership.JOIN: if not event or event.content.get("membership") != Membership.JOIN:
# We only care about joins # We only care about joins
continue continue
if prev_event_id: if prev_event_id:
prev_event = yield self.store.get_event(prev_event_id, allow_none=True) prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if ( if (
prev_event prev_event
and prev_event.content.get("membership") == Membership.JOIN and prev_event.content.get("membership") == Membership.JOIN
@ -860,10 +844,9 @@ class PresenceHandler(object):
# Ignore changes to join events. # Ignore changes to join events.
continue continue
yield self._on_user_joined_room(room_id, state_key) await self._on_user_joined_room(room_id, state_key)
@defer.inlineCallbacks async def _on_user_joined_room(self, room_id, user_id):
def _on_user_joined_room(self, room_id, user_id):
"""Called when we detect a user joining the room via the current state """Called when we detect a user joining the room via the current state
delta stream. delta stream.
@ -882,8 +865,8 @@ class PresenceHandler(object):
# TODO: We should be able to filter the hosts down to those that # TODO: We should be able to filter the hosts down to those that
# haven't previously seen the user # haven't previously seen the user
state = yield self.current_state_for_user(user_id) state = await self.current_state_for_user(user_id)
hosts = yield self.state.get_current_hosts_in_room(room_id) hosts = await self.state.get_current_hosts_in_room(room_id)
# Filter out ourselves. # Filter out ourselves.
hosts = {host for host in hosts if host != self.server_name} hosts = {host for host in hosts if host != self.server_name}
@ -903,10 +886,10 @@ class PresenceHandler(object):
# TODO: Check that this is actually a new server joining the # TODO: Check that this is actually a new server joining the
# room. # room.
user_ids = yield self.state.get_current_users_in_room(room_id) user_ids = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids)) user_ids = list(filter(self.is_mine_id, user_ids))
states = yield self.current_state_for_users(user_ids) states = await self.current_state_for_users(user_ids)
# Filter out old presence, i.e. offline presence states where # Filter out old presence, i.e. offline presence states where
# the user hasn't been active for a week. We can change this # the user hasn't been active for a week. We can change this
@ -996,9 +979,8 @@ class PresenceEventSource(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@defer.inlineCallbacks
@log_function @log_function
def get_new_events( async def get_new_events(
self, self,
user, user,
from_key, from_key,
@ -1045,7 +1027,7 @@ class PresenceEventSource(object):
presence = self.get_presence_handler() presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache stream_change_cache = self.store.presence_stream_cache
users_interested_in = yield self._get_interested_in(user, explicit_room_id) users_interested_in = await self._get_interested_in(user, explicit_room_id)
user_ids_changed = set() user_ids_changed = set()
changed = None changed = None
@ -1071,7 +1053,7 @@ class PresenceEventSource(object):
else: else:
user_ids_changed = users_interested_in user_ids_changed = users_interested_in
updates = yield presence.current_state_for_users(user_ids_changed) updates = await presence.current_state_for_users(user_ids_changed)
if include_offline: if include_offline:
return (list(updates.values()), max_token) return (list(updates.values()), max_token)
@ -1084,11 +1066,11 @@ class PresenceEventSource(object):
def get_current_key(self): def get_current_key(self):
return self.store.get_current_presence_token() return self.store.get_current_presence_token()
def get_pagination_rows(self, user, pagination_config, key): async def get_pagination_rows(self, user, pagination_config, key):
return self.get_new_events(user, from_key=None, include_offline=False) return await self.get_new_events(user, from_key=None, include_offline=False)
@cachedInlineCallbacks(num_args=2, cache_context=True) @cached(num_args=2, cache_context=True)
def _get_interested_in(self, user, explicit_room_id, cache_context): async def _get_interested_in(self, user, explicit_room_id, cache_context):
"""Returns the set of users that the given user should see presence """Returns the set of users that the given user should see presence
updates for updates for
""" """
@ -1096,13 +1078,13 @@ class PresenceEventSource(object):
users_interested_in = set() users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user( users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate user_id, on_invalidate=cache_context.invalidate
) )
users_interested_in.update(users_who_share_room) users_interested_in.update(users_who_share_room)
if explicit_room_id: if explicit_room_id:
user_ids = yield self.store.get_users_in_room( user_ids = await self.store.get_users_in_room(
explicit_room_id, on_invalidate=cache_context.invalidate explicit_room_id, on_invalidate=cache_context.invalidate
) )
users_interested_in.update(user_ids) users_interested_in.update(user_ids)
@ -1277,8 +1259,8 @@ def get_interested_parties(store, states):
2-tuple: `(room_ids_to_states, users_to_states)`, 2-tuple: `(room_ids_to_states, users_to_states)`,
with each item being a dict of `entity_name` -> `[UserPresenceState]` with each item being a dict of `entity_name` -> `[UserPresenceState]`
""" """
room_ids_to_states = {} room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
users_to_states = {} users_to_states = {} # type: Dict[str, List[UserPresenceState]]
for state in states: for state in states:
room_ids = yield store.get_rooms_for_user(state.user_id) room_ids = yield store.get_rooms_for_user(state.user_id)
for room_id in room_ids: for room_id in room_ids:

View file

@ -323,7 +323,11 @@ class ReplicationStreamer(object):
# We need to tell the presence handler that the connection has been # We need to tell the presence handler that the connection has been
# lost so that it can handle any ongoing syncs on that connection. # lost so that it can handle any ongoing syncs on that connection.
self.presence_handler.update_external_syncs_clear(connection.conn_id) run_as_background_process(
"update_external_syncs_clear",
self.presence_handler.update_external_syncs_clear,
connection.conn_id,
)
def _batch_updates(updates): def _batch_updates(updates):

View file

@ -3,6 +3,7 @@ import twisted.internet
import synapse.api.auth import synapse.api.auth
import synapse.config.homeserver import synapse.config.homeserver
import synapse.crypto.keyring import synapse.crypto.keyring
import synapse.federation.federation_server
import synapse.federation.sender import synapse.federation.sender
import synapse.federation.transport.client import synapse.federation.transport.client
import synapse.handlers import synapse.handlers
@ -107,5 +108,9 @@ class HomeServer(object):
self, self,
) -> synapse.replication.tcp.client.ReplicationClientHandler: ) -> synapse.replication.tcp.client.ReplicationClientHandler:
pass pass
def get_federation_registry(
self,
) -> synapse.federation.federation_server.FederationHandlerRegistry:
pass
def is_mine_id(self, domain_id: str) -> bool: def is_mine_id(self, domain_id: str) -> bool:
pass pass

View file

@ -494,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id, "@test2:server") self.helper.join(room_id, "@test2:server")
# Mark test2 as online, test will be offline with a last_active of 0 # Mark test2 as online, test will be offline with a last_active of 0
self.presence_handler.set_state( self.get_success(
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} self.presence_handler.set_state(
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
)
) )
self.reactor.pump([0]) # Wait for presence updates to be handled self.reactor.pump([0]) # Wait for presence updates to be handled
@ -543,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
# Mark test as online # Mark test as online
self.presence_handler.set_state( self.get_success(
UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} self.presence_handler.set_state(
UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
)
) )
# Mark test2 as online, test will be offline with a last_active of 0. # Mark test2 as online, test will be offline with a last_active of 0.
# Note we don't join them to the room yet # Note we don't join them to the room yet
self.presence_handler.set_state( self.get_success(
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} self.presence_handler.set_state(
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
)
) )
# Add servers to the room # Add servers to the room

View file

@ -183,6 +183,7 @@ commands = mypy \
synapse/events/spamcheck.py \ synapse/events/spamcheck.py \
synapse/federation/sender \ synapse/federation/sender \
synapse/federation/transport \ synapse/federation/transport \
synapse/handlers/presence.py \
synapse/handlers/sync.py \ synapse/handlers/sync.py \
synapse/handlers/ui_auth \ synapse/handlers/ui_auth \
synapse/logging/ \ synapse/logging/ \