diff --git a/synapse/state.py b/synapse/state.py index 32125c95df..033f55d967 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -471,69 +471,39 @@ class StateResolutionHandler(object): "Resolving state for %s with %d groups", room_id, len(state_groups_ids) ) - # build a map from state key to the event_ids which set that state. - # dict[(str, str), set[str]) - state = {} + # start by assuming we won't have any conflicted state, and build up the new + # state map by iterating through the state groups. If we discover a conflict, + # we give up and instead use `resolve_events_with_factory`. + # + # XXX: is this actually worthwhile, or should we just let + # resolve_events_with_factory do it? + new_state = {} + conflicted_state = False for st in itervalues(state_groups_ids): for key, e_id in iteritems(st): - state.setdefault(key, set()).add(e_id) - - # build a map from state key to the event_ids which set that state, - # including only those where there are state keys in conflict. - conflicted_state = { - k: list(v) - for k, v in iteritems(state) - if len(v) > 1 - } + if key in new_state: + conflicted_state = True + break + new_state[key] = e_id + if conflicted_state: + break if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_factory( - list(state_groups_ids.values()), + list(itervalues(state_groups_ids)), event_map=event_map, state_map_factory=state_map_factory, ) - else: - new_state = { - key: e_ids.pop() for key, e_ids in iteritems(state) - } + + # if the new state matches any of the input state groups, we can + # use that state group again. Otherwise we will generate a state_id + # which will be used as a cache key for future resolutions, but + # not get persisted. with Measure(self.clock, "state.create_group_ids"): - # if the new state matches any of the input state groups, we can - # use that state group again. Otherwise we will generate a state_id - # which will be used as a cache key for future resolutions, but - # not get persisted. - state_group = None - new_state_event_ids = frozenset(itervalues(new_state)) - for sg, events in iteritems(state_groups_ids): - if new_state_event_ids == frozenset(e_id for e_id in events): - state_group = sg - break - - # TODO: We want to create a state group for this set of events, to - # increase cache hits, but we need to make sure that it doesn't - # end up as a prev_group without being added to the database - - prev_group = None - delta_ids = None - for old_group, old_ids in iteritems(state_groups_ids): - if not set(new_state) - set(old_ids): - n_delta_ids = { - k: v - for k, v in iteritems(new_state) - if old_ids.get(k) != v - } - if not delta_ids or len(n_delta_ids) < len(delta_ids): - prev_group = old_group - delta_ids = n_delta_ids - - cache = _StateCacheEntry( - state=new_state, - state_group=state_group, - prev_group=prev_group, - delta_ids=delta_ids, - ) + cache = _make_state_cache_entry(new_state, state_groups_ids) if self._state_cache is not None: self._state_cache[group_names] = cache @@ -541,6 +511,70 @@ class StateResolutionHandler(object): defer.returnValue(cache) +def _make_state_cache_entry( + new_state, + state_groups_ids, +): + """Given a resolved state, and a set of input state groups, pick one to base + a new state group on (if any), and return an appropriately-constructed + _StateCacheEntry. + + Args: + new_state (dict[(str, str), str]): resolved state map (mapping from + (type, state_key) to event_id) + + state_groups_ids (dict[int, dict[(str, str), str]]): + map from state group id to the state in that state group + (where 'state' is a map from state key to event id) + + Returns: + _StateCacheEntry + """ + # if the new state matches any of the input state groups, we can + # use that state group again. Otherwise we will generate a state_id + # which will be used as a cache key for future resolutions, but + # not get persisted. + + # first look for exact matches + new_state_event_ids = set(itervalues(new_state)) + for sg, state in iteritems(state_groups_ids): + if len(new_state_event_ids) != len(state): + continue + + old_state_event_ids = set(itervalues(state)) + if new_state_event_ids == old_state_event_ids: + # got an exact match. + return _StateCacheEntry( + state=new_state, + state_group=sg, + ) + + # TODO: We want to create a state group for this set of events, to + # increase cache hits, but we need to make sure that it doesn't + # end up as a prev_group without being added to the database + + # failing that, look for the closest match. + prev_group = None + delta_ids = None + + for old_group, old_state in iteritems(state_groups_ids): + n_delta_ids = { + k: v + for k, v in iteritems(new_state) + if old_state.get(k) != v + } + if not delta_ids or len(n_delta_ids) < len(delta_ids): + prev_group = old_group + delta_ids = n_delta_ids + + return _StateCacheEntry( + state=new_state, + state_group=None, + prev_group=prev_group, + delta_ids=delta_ids, + ) + + def _ordered_events(events): def key_func(e): return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest() @@ -582,7 +616,7 @@ def _seperate(state_sets): with them in different state sets. Args: - state_sets(list[dict[(str, str), str]]): + state_sets(iterable[dict[(str, str), str]]): List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. @@ -596,10 +630,11 @@ def _seperate(state_sets): conflicted_state is a dict mapping (type, state_key) to a set of event ids for conflicted state keys. """ - unconflicted_state = dict(state_sets[0]) + state_set_iterator = iter(state_sets) + unconflicted_state = dict(next(state_set_iterator)) conflicted_state = {} - for state_set in state_sets[1:]: + for state_set in state_set_iterator: for key, value in iteritems(state_set): # Check if there is an unconflicted entry for the state key. unconflicted_value = unconflicted_state.get(key)