Handle the fact that workers can't generate state groups

This commit is contained in:
Erik Johnston 2016-08-31 15:53:19 +01:00
parent f51888530d
commit ed7a703d4c
2 changed files with 60 additions and 27 deletions

View file

@ -281,11 +281,13 @@ class Auth(object):
with Measure(self.clock, "check_host_in_room"): with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, curr_state_ids = yield self.state.resolve_state_groups( entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids room_id, latest_event_ids
) )
ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids) ret = yield self.store.is_host_joined(
room_id, host, entry.state_group, entry.state
)
defer.returnValue(ret) defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events): def check_event_sender_in_room(self, event, auth_events):

View file

@ -43,11 +43,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60 EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1
def _gen_state_id():
global _NEXT_STATE_ID
s = "X%d" % (_NEXT_STATE_ID,)
_NEXT_STATE_ID += 1
return s
class _StateCacheEntry(object): class _StateCacheEntry(object):
def __init__(self, state, state_group, ts): __slots__ = ["state", "state_group", "state_id"]
def __init__(self, state, state_group):
self.state = state self.state = state
self.state_group = state_group self.state_group = state_group
# The `state_id` is a unique ID we generate that can be used as ID for
# this collection of state. Usually this would be the same as the
# state group, but on worker instances we can't generate a new state
# group each time we resolve state, so we generate a separate one that
# isn't persisted and is used solely for caches.
# `state_id` is either a state_group (and so an int) or a string. This
# ensures we don't accidentally persist a state_id as a stateg_group
if state_group:
self.state_id = state_group
else:
self.state_id = _gen_state_id()
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """ Responsible for doing state conflict resolution.
@ -93,7 +117,8 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type: if event_type:
event_id = state.get((event_type, state_key)) event_id = state.get((event_type, state_key))
@ -116,7 +141,8 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type: if event_type:
defer.returnValue(state.get((event_type, state_key))) defer.returnValue(state.get((event_type, state_key)))
@ -127,9 +153,9 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_user_in_room(self, room_id): def get_current_user_in_room(self, room_id):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state( joined_users = yield self.store.get_joined_users_from_state(
room_id, group, state_ids room_id, entry.state_id, entry.state
) )
defer.returnValue(joined_users) defer.returnValue(joined_users)
@ -191,23 +217,26 @@ class StateHandler(object):
defer.returnValue(context) defer.returnValue(context)
if event.is_state(): if event.is_state():
ret = yield self.resolve_state_groups( entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
event_type=event.type, event_type=event.type,
state_key=event.state_key, state_key=event.state_key,
) )
else: else:
ret = yield self.resolve_state_groups( entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
group, curr_state = ret curr_state = entry.state
context.prev_state_ids = curr_state context.prev_state_ids = curr_state
if event.is_state() or group is None: if event.is_state():
context.state_group = self.store.get_next_state_group() context.state_group = self.store.get_next_state_group()
else: else:
context.state_group = group if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
@ -249,16 +278,15 @@ class StateHandler(object):
if len(group_names) == 1: if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = state_groups_ids.items().pop()
defer.returnValue((name, state_list,)) defer.returnValue(_StateCacheEntry(
state=state_list,
state_group=name,
))
if self._state_cache is not None: if self._state_cache is not None:
cache = self._state_cache.get(group_names, None) cache = self._state_cache.get(group_names, None)
if cache: if cache:
cache.ts = self.clock.time_msec() defer.returnValue(cache)
defer.returnValue(
(cache.state_group, cache.state,)
)
logger.info( logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids) "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
@ -302,19 +330,22 @@ class StateHandler(object):
if new_state_event_ids == frozenset(e_id for e_id in events): if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg state_group = sg
break break
if not state_group: if state_group is None:
state_group = self.store.get_next_state_group() # Worker instances don't have access to this method, but we want
# to set the state_group on the main instance to increase cache
# hits.
if hasattr(self.store, "get_next_state_group"):
state_group = self.store.get_next_state_group()
cache = _StateCacheEntry(
state=new_state,
state_group=state_group,
)
if self._state_cache is not None: if self._state_cache is not None:
cache = _StateCacheEntry(
state=new_state,
state_group=state_group,
ts=self.clock.time_msec()
)
self._state_cache[group_names] = cache self._state_cache[group_names] = cache
defer.returnValue((state_group, new_state,)) defer.returnValue(cache)
def resolve_events(self, state_sets, event): def resolve_events(self, state_sets, event):
logger.info( logger.info(