Merge pull request #2248 from matrix-org/erikj/state_fixup

Faster cache for get_joined_hosts
This commit is contained in:
Erik Johnston 2017-06-07 14:01:06 +01:00 committed by GitHub
commit a053ff3979
7 changed files with 156 additions and 56 deletions

View file

@ -187,6 +187,7 @@ class TransactionQueue(object):
prev_id for prev_id, _ in event.prev_events prev_id for prev_id, _ in event.prev_events
], ],
) )
destinations = set(destinations)
if send_on_behalf_of is not None: if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server # If we are sending the event on behalf of another server

View file

@ -89,7 +89,7 @@ class TypingHandler(object):
until = self._member_typing_until.get(member, None) until = self._member_typing_until.get(member, None)
if not until or until <= now: if not until or until <= now:
logger.info("Timing out typing for: %s", member.user_id) logger.info("Timing out typing for: %s", member.user_id)
preserve_fn(self._stopped_typing)(member) self._stopped_typing(member)
continue continue
# Check if we need to resend a keep alive over federation for this # Check if we need to resend a keep alive over federation for this
@ -147,7 +147,7 @@ class TypingHandler(object):
# No point sending another notification # No point sending another notification
defer.returnValue(None) defer.returnValue(None)
yield self._push_update( self._push_update(
member=member, member=member,
typing=True, typing=True,
) )
@ -171,7 +171,7 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=target_user_id) member = RoomMember(room_id=room_id, user_id=target_user_id)
yield self._stopped_typing(member) self._stopped_typing(member)
@defer.inlineCallbacks @defer.inlineCallbacks
def user_left_room(self, user, room_id): def user_left_room(self, user, room_id):
@ -180,7 +180,6 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=user_id) member = RoomMember(room_id=room_id, user_id=user_id)
yield self._stopped_typing(member) yield self._stopped_typing(member)
@defer.inlineCallbacks
def _stopped_typing(self, member): def _stopped_typing(self, member):
if member.user_id not in self._room_typing.get(member.room_id, set()): if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point # No point
@ -189,16 +188,15 @@ class TypingHandler(object):
self._member_typing_until.pop(member, None) self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None) self._member_last_federation_poke.pop(member, None)
yield self._push_update( self._push_update(
member=member, member=member,
typing=False, typing=False,
) )
@defer.inlineCallbacks
def _push_update(self, member, typing): def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id): if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users. # Only send updates for changes to our own users.
yield self._push_remote(member, typing) preserve_fn(self._push_remote)(member, typing)
self._push_update_local( self._push_update_local(
member=member, member=member,

View file

@ -108,6 +108,8 @@ class SlavedEventStore(BaseSlavedStore):
get_current_state_ids = ( get_current_state_ids = (
StateStore.__dict__["get_current_state_ids"] StateStore.__dict__["get_current_state_ids"]
) )
get_state_group_delta = DataStore.get_state_group_delta.__func__
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__ has_room_changed_since = DataStore.has_room_changed_since.__func__
get_unread_push_actions_for_user_in_range_for_http = ( get_unread_push_actions_for_user_in_range_for_http = (

View file

@ -170,9 +170,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room") logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state( joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
room_id, entry.state_id, entry.state
)
defer.returnValue(joined_users) defer.returnValue(joined_users)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -181,9 +179,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room") logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts( joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
room_id, entry.state_id, entry.state
)
defer.returnValue(joined_hosts) defer.returnValue(joined_hosts)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -206,12 +202,12 @@ class StateHandler(object):
Returns: Returns:
synapse.events.snapshot.EventContext: synapse.events.snapshot.EventContext:
""" """
context = EventContext()
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():
# If this is an outlier, then we know it shouldn't have any current # If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and # state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group. # persisting the event won't store the state group.
context = EventContext()
if old_state: if old_state:
context.prev_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
@ -230,6 +226,7 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
if old_state: if old_state:
context = EventContext()
context.prev_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
@ -250,19 +247,13 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
if event.is_state(): entry = yield self.resolve_state_groups(
entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events],
event.room_id, [e for e, _ in event.prev_events], )
event_type=event.type,
state_key=event.state_key,
)
else:
entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
)
curr_state = entry.state curr_state = entry.state
context = EventContext()
context.prev_state_ids = curr_state context.prev_state_ids = curr_state
if event.is_state(): if event.is_state():
context.state_group = self.store.get_next_state_group() context.state_group = self.store.get_next_state_group()
@ -275,11 +266,14 @@ class StateHandler(object):
context.current_state_ids = dict(context.prev_state_ids) context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id context.current_state_ids[key] = event.event_id
context.prev_group = entry.prev_group if entry.state_group:
context.delta_ids = entry.delta_ids context.prev_group = entry.state_group
if context.delta_ids is not None: context.delta_ids = {
context.delta_ids = dict(context.delta_ids) key: event.event_id
context.delta_ids[key] = event.event_id }
elif entry.prev_group:
context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids
else: else:
if entry.state_group is None: if entry.state_group is None:
entry.state_group = self.store.get_next_state_group() entry.state_group = self.store.get_next_state_group()
@ -295,7 +289,7 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""): def resolve_state_groups(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
@ -320,11 +314,13 @@ class StateHandler(object):
if len(group_names) == 1: if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = state_groups_ids.items().pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
defer.returnValue(_StateCacheEntry( defer.returnValue(_StateCacheEntry(
state=state_list, state=state_list,
state_group=name, state_group=name,
prev_group=name, prev_group=prev_group,
delta_ids={}, delta_ids=delta_ids,
)) ))
with (yield self.resolve_linearizer.queue(group_names)): with (yield self.resolve_linearizer.queue(group_names)):
@ -377,11 +373,11 @@ class StateHandler(object):
prev_group = None prev_group = None
delta_ids = None delta_ids = None
for old_group, old_ids in state_groups_ids.items(): for old_group, old_ids in state_groups_ids.iteritems():
if not set(new_state.iterkeys()) - set(old_ids.iterkeys()): if not set(new_state) - set(old_ids):
n_delta_ids = { n_delta_ids = {
k: v k: v
for k, v in new_state.items() for k, v in new_state.iteritems()
if old_ids.get(k) != v if old_ids.get(k) != v
} }
if not delta_ids or len(n_delta_ids) < len(delta_ids): if not delta_ids or len(n_delta_ids) < len(delta_ids):

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from collections import namedtuple from collections import namedtuple
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.async import Linearizer
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii from synapse.util.stringutils import to_ascii
@ -392,7 +393,8 @@ class RoomMemberStore(SQLBaseStore):
context=context, context=context,
) )
def get_joined_users_from_state(self, room_id, state_group, state_ids): def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group # state group, i.e. we need to make sure that calls with a state_group
@ -401,7 +403,7 @@ class RoomMemberStore(SQLBaseStore):
state_group = object() state_group = object()
return self._get_joined_users_from_context( return self._get_joined_users_from_context(
room_id, state_group, state_ids, room_id, state_group, state_entry.state, context=state_entry,
) )
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
@ -534,7 +536,8 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(False) defer.returnValue(False)
def get_joined_hosts(self, room_id, state_group, state_ids): def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group # state group, i.e. we need to make sure that calls with a state_group
@ -543,33 +546,20 @@ class RoomMemberStore(SQLBaseStore):
state_group = object() state_group = object()
return self._get_joined_hosts( return self._get_joined_hosts(
room_id, state_group, state_ids room_id, state_group, state_entry.state, state_entry=state_entry,
) )
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
def _get_joined_hosts(self, room_id, state_group, current_state_ids): # @defer.inlineCallbacks
def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
# We don't use `state_group`, its there so that we can cache based # We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's # on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different. # with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this. # See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None assert state_group is not None
joined_hosts = set() cache = self._get_joined_hosts_cache(room_id)
for etype, state_key in current_state_ids: joined_hosts = yield cache.get_destinations(state_entry)
if etype == EventTypes.Member:
try:
host = get_domain_from_id(state_key)
except:
logger.warn("state_key not user_id: %s", state_key)
continue
if host in joined_hosts:
continue
event_id = current_state_ids[(etype, state_key)]
event = yield self.get_event(event_id, allow_none=True)
if event and event.content["membership"] == Membership.JOIN:
joined_hosts.add(intern_string(host))
defer.returnValue(joined_hosts) defer.returnValue(joined_hosts)
@ -647,3 +637,75 @@ class RoomMemberStore(SQLBaseStore):
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
defer.returnValue(result) defer.returnValue(result)
@cached(max_entries=10000, iterable=True)
def _get_joined_hosts_cache(self, room_id):
return _JoinedHostsCache(self, room_id)
class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates
via state deltas.
"""
def __init__(self, store, room_id):
self.store = store
self.room_id = room_id
self.hosts_to_joined_users = {}
self.state_group = object()
self.linearizer = Linearizer("_JoinedHostsCache")
self._len = 0
@defer.inlineCallbacks
def get_destinations(self, state_entry):
"""Get set of destinations for a state entry
Args:
state_entry(synapse.state._StateCacheEntry)
"""
if state_entry.state_group == self.state_group:
defer.returnValue(frozenset(self.hosts_to_joined_users))
with (yield self.linearizer.queue(())):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
if typ != EventTypes.Member:
continue
host = intern_string(get_domain_from_id(state_key))
user_id = state_key
known_joins = self.hosts_to_joined_users.setdefault(host, set())
event = yield self.store.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
known_joins.discard(user_id)
if not known_joins:
self.hosts_to_joined_users.pop(host, None)
else:
joined_users = yield self.store.get_joined_users_from_state(
self.room_id, state_entry,
)
self.hosts_to_joined_users = {}
for user_id in joined_users:
host = intern_string(get_domain_from_id(user_id))
self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
if state_entry.state_group:
self.state_group = state_entry.state_group
else:
self.state_group = object()
self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
defer.returnValue(frozenset(self.hosts_to_joined_users))
def __len__(self):
return self._len

View file

@ -98,6 +98,45 @@ class StateStore(SQLBaseStore):
_get_current_state_ids_txn, _get_current_state_ids_txn,
) )
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
(prev_group, delta_ids), where both may be None.
"""
def _get_state_group_delta_txn(txn):
prev_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={
"state_group": state_group,
},
retcol="prev_state_group",
allow_none=True,
)
if not prev_group:
return None, None
delta_ids = self._simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={
"state_group": state_group,
},
retcols=("type", "state_key", "event_id",)
)
return prev_group, {
(row["type"], row["state_key"]): row["event_id"]
for row in delta_ids
}
return self.runInteraction(
"get_state_group_delta",
_get_state_group_delta_txn,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids): def get_state_groups_ids(self, room_id, event_ids):
if not event_ids: if not event_ids:

View file

@ -143,6 +143,7 @@ class StateTestCase(unittest.TestCase):
"add_event_hashes", "add_event_hashes",
"get_events", "get_events",
"get_next_state_group", "get_next_state_group",
"get_state_group_delta",
] ]
) )
hs = Mock(spec_set=[ hs = Mock(spec_set=[
@ -154,6 +155,7 @@ class StateTestCase(unittest.TestCase):
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
self.store.get_next_state_group.side_effect = Mock self.store.get_next_state_group.side_effect = Mock
self.store.get_state_group_delta.return_value = (None, None)
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0