From ec0a523ac338bab1eb23a6b21227b8f7402cc2d4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Jan 2017 18:37:18 +0000 Subject: [PATCH 1/2] Split out static state methods from StateHandler --- synapse/state.py | 142 ++++++++++++++++++++++++----------------------- 1 file changed, 73 insertions(+), 69 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index b9d5627a82..c75499c3e0 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,6 +16,7 @@ from twisted.internet import defer +from synapse import event_auth from synapse.util.logutils import log_function from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure @@ -335,9 +336,10 @@ class StateHandler(object): [state_map[e_id] for key, e_id in st.items() if e_id in state_map] for st in state_groups_ids.values() ] - new_state, _ = self._resolve_events( - state_sets, event_type, state_key - ) + with Measure(self.clock, "state._resolve_events"): + new_state, _ = Resolver.resolve_events( + state_sets, event_type, state_key + ) new_state = { key: e.event_id for key, e in new_state.items() } @@ -388,68 +390,78 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - if event.is_state(): - return self._resolve_events( - state_sets, event.type, event.state_key - ) - else: - return self._resolve_events(state_sets) + with Measure(self.clock, "state._resolve_events"): + if event.is_state(): + return Resolver.resolve_events( + state_sets, event.type, event.state_key + ) + else: + return Resolver.resolve_events(state_sets) - def _resolve_events(self, state_sets, event_type=None, state_key=""): + +def _ordered_events(events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + + return sorted(events, key=key_func) + + +class Resolver(object): + @staticmethod + def resolve_events(state_sets, event_type=None, state_key=""): """ Returns (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple (new_state, prev_states). new_state is a map from (type, state_key) to event. prev_states is a list of event_ids. """ - with Measure(self.clock, "state._resolve_events"): - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e + state = {} + for st in state_sets: + for e in st: + state.setdefault( + (e.type, e.state_key), + {} + )[e.event_id] = e - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] - ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] + if event_type: + prev_states_events = conflicted_state.get( + (event_type, state_key), [] + ) + prev_states = [s.event_id for s in prev_states_events] + else: + prev_states = [] - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } - try: - resolved_state = self._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise + try: + resolved_state = Resolver._resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise - new_state = unconflicted_state - new_state.update(resolved_state) + new_state = unconflicted_state + new_state.update(resolved_state) return new_state, prev_states - @log_function - def _resolve_state_events(self, conflicted_state, auth_events): + @staticmethod + def _resolve_state_events(conflicted_state, auth_events): """ This is where we actually decide which of the conflicted state to use. @@ -464,7 +476,7 @@ class StateHandler(object): if power_key in conflicted_state: events = conflicted_state[power_key] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = self._resolve_auth_events( + resolved_state[power_key] = Resolver._resolve_auth_events( events, auth_events) auth_events.update(resolved_state) @@ -472,7 +484,7 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = self._resolve_auth_events( + resolved_state[key] = Resolver._resolve_auth_events( events, auth_events ) @@ -482,7 +494,7 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = self._resolve_auth_events( + resolved_state[key] = Resolver._resolve_auth_events( events, auth_events ) @@ -492,14 +504,15 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = self._resolve_normal_events( + resolved_state[key] = Resolver._resolve_normal_events( events, auth_events ) return resolved_state - def _resolve_auth_events(self, events, auth_events): - reverse = [i for i in reversed(self._ordered_events(events))] + @staticmethod + def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] auth_events = dict(auth_events) @@ -507,23 +520,20 @@ class StateHandler(object): for event in reverse[1:]: auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False) prev_event = event except AuthError: return prev_event return event - def _resolve_normal_events(self, events, auth_events): - for event in self._ordered_events(events): + @staticmethod + def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): try: - # FIXME: hs.get_auth() is bad style, but we need to do it to - # get around circular deps. # The signatures have already been checked at this point - self.hs.get_auth().check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False) return event except AuthError: pass @@ -531,9 +541,3 @@ class StateHandler(object): # Use the last event (the one with the least depth) if they all fail # the auth check. return event - - def _ordered_events(self, events): - def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() - - return sorted(events, key=key_func) From beda469bc6e96a0b776c3d6742cf97950819b2f0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 15:05:24 +0000 Subject: [PATCH 2/2] Put staticmethods at module level --- synapse/state.py | 252 +++++++++++++++++++++++------------------------ 1 file changed, 125 insertions(+), 127 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index c75499c3e0..90b14e758c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -337,7 +337,7 @@ class StateHandler(object): for st in state_groups_ids.values() ] with Measure(self.clock, "state._resolve_events"): - new_state, _ = Resolver.resolve_events( + new_state, _ = resolve_events( state_sets, event_type, state_key ) new_state = { @@ -392,11 +392,11 @@ class StateHandler(object): ) with Measure(self.clock, "state._resolve_events"): if event.is_state(): - return Resolver.resolve_events( + return resolve_events( state_sets, event.type, event.state_key ) else: - return Resolver.resolve_events(state_sets) + return resolve_events(state_sets) def _ordered_events(events): @@ -406,138 +406,136 @@ def _ordered_events(events): return sorted(events, key=key_func) -class Resolver(object): - @staticmethod - def resolve_events(state_sets, event_type=None, state_key=""): - """ - Returns - (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple - (new_state, prev_states). new_state is a map from (type, state_key) - to event. prev_states is a list of event_ids. - """ - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e +def resolve_events(state_sets, event_type=None, state_key=""): + """ + Returns + (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple + (new_state, prev_states). new_state is a map from (type, state_key) + to event. prev_states is a list of event_ids. + """ + state = {} + for st in state_sets: + for e in st: + state.setdefault( + (e.type, e.state_key), + {} + )[e.event_id] = e - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + unconflicted_state = { + k: v.values()[0] for k, v in state.items() + if len(v.values()) == 1 + } - conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 - } + conflicted_state = { + k: v.values() + for k, v in state.items() + if len(v.values()) > 1 + } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] + if event_type: + prev_states_events = conflicted_state.get( + (event_type, state_key), [] + ) + prev_states = [s.event_id for s in prev_states_events] + else: + prev_states = [] + + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } + + try: + resolved_state = _resolve_state_events( + conflicted_state, auth_events + ) + except: + logger.exception("Failed to resolve state") + raise + + new_state = unconflicted_state + new_state.update(resolved_state) + + return new_state, prev_states + + +def _resolve_state_events(conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. join rules + 3. memberships + 4. other events. + """ + resolved_state = {} + power_key = (EventTypes.PowerLevels, "") + if power_key in conflicted_state: + events = conflicted_state[power_key] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[power_key] = _resolve_auth_events( + events, auth_events) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] - auth_events = { - k: e for k, e in unconflicted_state.items() - if k[0] in AuthEventTypes - } + auth_events.update(resolved_state) + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) + resolved_state[key] = _resolve_normal_events( + events, auth_events + ) + + return resolved_state + + +def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] + + auth_events = dict(auth_events) + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: - resolved_state = Resolver._resolve_state_events( - conflicted_state, auth_events - ) - except: - logger.exception("Failed to resolve state") - raise + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False) + prev_event = event + except AuthError: + return prev_event - new_state = unconflicted_state - new_state.update(resolved_state) + return event - return new_state, prev_states - @staticmethod - def _resolve_state_events(conflicted_state, auth_events): - """ This is where we actually decide which of the conflicted state to - use. +def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False) + return event + except AuthError: + pass - We resolve conflicts in the following order: - 1. power levels - 2. join rules - 3. memberships - 4. other events. - """ - resolved_state = {} - power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state: - events = conflicted_state[power_key] - logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = Resolver._resolve_auth_events( - events, auth_events) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.JoinRules: - logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = Resolver._resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key[0] == EventTypes.Member: - logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = Resolver._resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in conflicted_state.items(): - if key not in resolved_state: - logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = Resolver._resolve_normal_events( - events, auth_events - ) - - return resolved_state - - @staticmethod - def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] - - auth_events = dict(auth_events) - - prev_event = reverse[0] - for event in reverse[1:]: - auth_events[(prev_event.type, prev_event.state_key)] = prev_event - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) - prev_event = event - except AuthError: - return prev_event - - return event - - @staticmethod - def _resolve_normal_events(events, auth_events): - for event in _ordered_events(events): - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) - return event - except AuthError: - pass - - # Use the last event (the one with the least depth) if they all fail - # the auth check. - return event + # Use the last event (the one with the least depth) if they all fail + # the auth check. + return event