diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 63e633548d..a333acc4aa 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -832,7 +832,11 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def on_event_auth(self, event_id): - auth = yield self.store.get_auth_chain([event_id]) + event = yield self.store.get_event(event_id) + auth = yield self.store.get_auth_chain( + [auth_id for auth_id, _ in event.auth_events], + include_given=True + ) for event in auth: event.signatures.update( @@ -1047,9 +1051,7 @@ class FederationHandler(BaseHandler): yield user_joined_room(self.distributor, user, event.room_id) state_ids = context.prev_state_ids.values() - auth_chain = yield self.store.get_auth_chain(set( - [event.event_id] + state_ids - )) + auth_chain = yield self.store.get_auth_chain(state_ids) state = yield self.store.get_events(context.prev_state_ids.values()) @@ -1598,7 +1600,11 @@ class FederationHandler(BaseHandler): pass # Now get the current auth_chain for the event. - local_auth_chain = yield self.store.get_auth_chain([event_id]) + event = yield self.store.get_event(event_id) + local_auth_chain = yield self.store.get_auth_chain( + [auth_id for auth_id, _ in event.auth_events], + include_given=True + ) # TODO: Check if we would now reject event_id. If so we need to tell # everyone. @@ -1791,7 +1797,9 @@ class FederationHandler(BaseHandler): auth_ids = yield self.auth.compute_auth_events( event, context.prev_state_ids ) - local_auth_chain = yield self.store.get_auth_chain(auth_ids) + local_auth_chain = yield self.store.get_auth_chain( + auth_ids, include_given=True + ) try: # 2. Get remote difference. diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 519059c306..72126c682e 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -44,18 +44,41 @@ class EventFederationStore(SQLBaseStore): self._delete_old_forward_extrem_cache, 60 * 60 * 1000 ) - def get_auth_chain(self, event_ids): - return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) + def get_auth_chain(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. - def get_auth_chain_ids(self, event_ids): + Args: + event_ids (list): state events + include_given (bool): include the given events in result + + Returns: + list of events + """ + return self.get_auth_chain_ids( + event_ids, include_given=include_given, + ).addCallback(self._get_events) + + def get_auth_chain_ids(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. + + Args: + event_ids (list): state events + include_given (bool): include the given events in result + + Returns: + list of event_ids + """ return self.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, - event_ids + event_ids, include_given ) - def _get_auth_chain_ids_txn(self, txn, event_ids): - results = set() + def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): + if include_given: + results = set(event_ids) + else: + results = set() base_sql = ( "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c4aeb48800..3d4f53ea55 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1120,6 +1120,7 @@ class EventsStore(SQLBaseStore): } for event, _ in events_and_contexts for auth_id, _ in event.auth_events + if event.is_state() ], )