diff --git a/synapse/state.py b/synapse/state.py index 66e1a685e8..5028b0ac49 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,6 +16,7 @@ from twisted.internet import defer +from synapse import event_auth from synapse.util.logutils import log_function from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure @@ -339,9 +340,10 @@ class StateHandler(object): [state_map[e_id] for key, e_id in st.items() if e_id in state_map] for st in state_groups_ids.values() ] - new_state, _ = self._resolve_events( - state_sets, event_type, state_key - ) + with Measure(self.clock, "state._resolve_events"): + new_state, _ = resolve_events( + state_sets, event_type, state_key + ) new_state = { key: e.event_id for key, e in new_state.items() } @@ -392,152 +394,152 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - if event.is_state(): - return self._resolve_events( - state_sets, event.type, event.state_key - ) - else: - return self._resolve_events(state_sets) - - def _resolve_events(self, state_sets, event_type=None, state_key=""): - """ - Returns - (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple - (new_state, prev_states). new_state is a map from (type, state_key) - to event. prev_states is a list of event_ids. - """ with Measure(self.clock, "state._resolve_events"): - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e - - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } - - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } - - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] + if event.is_state(): + return resolve_events( + state_sets, event.type, event.state_key ) - prev_states = [s.event_id for s in prev_states_events] else: - prev_states = [] + return resolve_events(state_sets) - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } - try: - resolved_state = self._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise +def _ordered_events(events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() - new_state = unconflicted_state - new_state.update(resolved_state) + return sorted(events, key=key_func) - return new_state, prev_states - @log_function - def _resolve_state_events(self, conflicted_state, auth_events): - """ This is where we actually decide which of the conflicted state to - use. +def resolve_events(state_sets, event_type=None, state_key=""): + """ + Returns + (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple + (new_state, prev_states). new_state is a map from (type, state_key) + to event. prev_states is a list of event_ids. + """ + state = {} + for st in state_sets: + for e in st: + state.setdefault( + (e.type, e.state_key), + {} + )[e.event_id] = e - We resolve conflicts in the following order: - 1. power levels - 2. join rules - 3. memberships - 4. other events. - """ - resolved_state = {} - power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state: - events = conflicted_state[power_key] - logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = self._resolve_auth_events( - events, auth_events) + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } - auth_events.update(resolved_state) + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } - for key, events in conflicted_state.items(): - if key[0] == EventTypes.JoinRules: - logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = self._resolve_auth_events( - events, - auth_events - ) + if event_type: + prev_states_events = conflicted_state.get( + (event_type, state_key), [] + ) + prev_states = [s.event_id for s in prev_states_events] + else: + prev_states = [] - auth_events.update(resolved_state) + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } - for key, events in conflicted_state.items(): - if key[0] == EventTypes.Member: - logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = self._resolve_auth_events( - events, - auth_events - ) + try: + resolved_state = _resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise - auth_events.update(resolved_state) + new_state = unconflicted_state + new_state.update(resolved_state) - for key, events in conflicted_state.items(): - if key not in resolved_state: - logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = self._resolve_normal_events( - events, auth_events - ) + return new_state, prev_states - return resolved_state - def _resolve_auth_events(self, events, auth_events): - reverse = [i for i in reversed(self._ordered_events(events))] +def _resolve_state_events(conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. - auth_events = dict(auth_events) + We resolve conflicts in the following order: + 1. power levels + 2. join rules + 3. memberships + 4. other events. + """ + resolved_state = {} + power_key = (EventTypes.PowerLevels, "") + if power_key in conflicted_state: + events = conflicted_state[power_key] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[power_key] = _resolve_auth_events( + events, auth_events) - prev_event = reverse[0] - for event in reverse[1:]: - auth_events[(prev_event.type, prev_event.state_key)] = prev_event - try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. - # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) - prev_event = event - except AuthError: - return prev_event + auth_events.update(resolved_state) - return event + for key, events in conflicted_state.items(): + if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) - def _resolve_normal_events(self, events, auth_events): - for event in self._ordered_events(events): - try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. - # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) - return event - except AuthError: - pass + auth_events.update(resolved_state) - # Use the last event (the one with the least depth) if they all fail - # the auth check. - return event + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) - def _ordered_events(self, events): - def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + auth_events.update(resolved_state) - return sorted(events, key=key_func) + for key, events in conflicted_state.items(): + if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) + resolved_state[key] = _resolve_normal_events( + events, auth_events + ) + + return resolved_state + + +def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] + + auth_events = dict(auth_events) + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False) + prev_event = event + except AuthError: + return prev_event + + return event + + +def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False) + return event + except AuthError: + pass + + # Use the last event (the one with the least depth) if they all fail + # the auth check. + return event