Merge pull request #2842 from matrix-org/rav/state_resolution_handler

Factor out resolve_state_groups to a separate handler
This commit is contained in:
Richard van der Hoff 2018-02-02 15:27:35 +01:00 committed by GitHub
commit 18eae413af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 119 additions and 66 deletions

View file

@ -808,13 +808,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys()) event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn(
self.state_handler.resolve_state_groups_for_events
)
states = yield logcontext.make_deferred_yieldable(defer.gatherResults( states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [resolve(room_id, [e]) for e in event_ids],
logcontext.preserve_fn(self.state_handler.resolve_state_groups)( consumeErrors=True,
room_id, [e]
)
for e in event_ids
], consumeErrors=True,
)) ))
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))

View file

@ -66,7 +66,7 @@ from synapse.rest.media.v1.media_repository import (
MediaRepository, MediaRepository,
MediaRepositoryResource, MediaRepositoryResource,
) )
from synapse.state import StateHandler from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.util import Clock from synapse.util import Clock
@ -102,6 +102,7 @@ class HomeServer(object):
'v1auth', 'v1auth',
'auth', 'auth',
'state_handler', 'state_handler',
'state_resolution_handler',
'presence_handler', 'presence_handler',
'sync_handler', 'sync_handler',
'typing_handler', 'typing_handler',
@ -224,6 +225,9 @@ class HomeServer(object):
def build_state_handler(self): def build_state_handler(self):
return StateHandler(self) return StateHandler(self)
def build_state_resolution_handler(self):
return StateResolutionHandler(self)
def build_presence_handler(self): def build_presence_handler(self):
return PresenceHandler(self) return PresenceHandler(self)

View file

@ -34,6 +34,9 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler: def get_state_handler(self) -> synapse.state.StateHandler:
pass pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass pass

View file

@ -85,31 +85,19 @@ class _StateCacheEntry(object):
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """Fetches bits of state from the stores, and does state resolution
where necessary
""" """
def __init__(self, hs): def __init__(self, hs):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self): def start_caching(self):
logger.debug("start_caching") # TODO: remove this shim
self._state_resolution_handler.start_caching()
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
self._state_cache.start()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key="", def get_current_state(self, room_id, event_type=None, state_key="",
@ -131,7 +119,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state") logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state state = ret.state
if event_type: if event_type:
@ -168,7 +156,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids") logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state state = ret.state
defer.returnValue(state) defer.returnValue(state)
@ -178,7 +166,7 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room") logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry) joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users) defer.returnValue(joined_users)
@ -187,7 +175,7 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room") logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry) joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
defer.returnValue(joined_hosts) defer.returnValue(joined_hosts)
@ -245,7 +233,7 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups( entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
@ -287,8 +275,7 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function def resolve_state_groups_for_events(self, room_id, event_ids):
def resolve_state_groups(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
@ -308,13 +295,7 @@ class StateHandler(object):
room_id, event_ids room_id, event_ids
) )
logger.debug( if len(state_groups_ids) == 1:
"resolve_state_groups state_groups %s",
state_groups_ids.keys()
)
group_names = frozenset(state_groups_ids.keys())
if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = state_groups_ids.items().pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name) prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@ -326,6 +307,92 @@ class StateHandler(object):
delta_ids=delta_ids, delta_ids=delta_ids,
)) ))
result = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups_ids, self._state_map_factory,
)
defer.returnValue(result)
def _state_map_factory(self, ev_ids):
return self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
class StateResolutionHandler(object):
"""Responsible for doing state conflict resolution.
Note that the storage layer depends on this handler, so all functions must
be storage-independent.
"""
def __init__(self, hs):
self.clock = hs.get_clock()
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self):
logger.debug("start_caching")
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
self._state_cache.start()
@defer.inlineCallbacks
@log_function
def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
"""Resolves conflicts between a set of state groups
Always generates a new state group (unless we hit the cache), so should
not be called for a single state group
Args:
room_id (str): room we are resolving for (used for logging)
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
Returns:
Deferred[_StateCacheEntry]: resolved state
"""
logger.debug(
"resolve_state_groups state_groups %s",
state_groups_ids.keys()
)
group_names = frozenset(state_groups_ids.keys())
with (yield self.resolve_linearizer.queue(group_names)): with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None: if self._state_cache is not None:
cache = self._state_cache.get(group_names, None) cache = self._state_cache.get(group_names, None)
@ -356,15 +423,17 @@ class StateHandler(object):
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_factory(
state_groups_ids.values(), state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events( state_map_factory=state_map_factory,
ev_ids, get_prev_content=False, check_redacted=False,
),
) )
else: else:
new_state = { new_state = {
key: e_ids.pop() for key, e_ids in state.items() key: e_ids.pop() for key, e_ids in state.items()
} }
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
# which will be used as a cache key for future resolutions, but
# not get persisted.
state_group = None state_group = None
new_state_event_ids = frozenset(new_state.values()) new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups_ids.items(): for sg, events in state_groups_ids.items():
@ -401,30 +470,6 @@ class StateHandler(object):
defer.returnValue(cache) defer.returnValue(cache)
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
def _ordered_events(events): def _ordered_events(events):
def key_func(e): def key_func(e):

View file

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.state import StateHandler from synapse.state import StateHandler, StateResolutionHandler
from .utils import MockClock from .utils import MockClock
@ -148,11 +148,13 @@ class StateTestCase(unittest.TestCase):
) )
hs = Mock(spec_set=[ hs = Mock(spec_set=[
"get_datastore", "get_auth", "get_state_handler", "get_clock", "get_datastore", "get_auth", "get_state_handler", "get_clock",
"get_state_resolution_handler",
]) ])
hs.get_datastore.return_value = self.store hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
self.store.get_next_state_group.side_effect = Mock self.store.get_next_state_group.side_effect = Mock
self.store.get_state_group_delta.return_value = (None, None) self.store.get_state_group_delta.return_value = (None, None)