Only store event_auth for state events

This commit is contained in:
Erik Johnston 2017-05-24 14:22:41 +01:00
parent 58c4720293
commit c049472b8a
3 changed files with 44 additions and 12 deletions

View file

@ -832,7 +832,11 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_event_auth(self, event_id): def on_event_auth(self, event_id):
auth = yield self.store.get_auth_chain([event_id]) event = yield self.store.get_event(event_id)
auth = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events],
include_given=True
)
for event in auth: for event in auth:
event.signatures.update( event.signatures.update(
@ -1047,9 +1051,7 @@ class FederationHandler(BaseHandler):
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
state_ids = context.prev_state_ids.values() state_ids = context.prev_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set( auth_chain = yield self.store.get_auth_chain(state_ids)
[event.event_id] + state_ids
))
state = yield self.store.get_events(context.prev_state_ids.values()) state = yield self.store.get_events(context.prev_state_ids.values())
@ -1598,7 +1600,11 @@ class FederationHandler(BaseHandler):
pass pass
# Now get the current auth_chain for the event. # Now get the current auth_chain for the event.
local_auth_chain = yield self.store.get_auth_chain([event_id]) event = yield self.store.get_event(event_id)
local_auth_chain = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events],
include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell # TODO: Check if we would now reject event_id. If so we need to tell
# everyone. # everyone.
@ -1791,7 +1797,9 @@ class FederationHandler(BaseHandler):
auth_ids = yield self.auth.compute_auth_events( auth_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids event, context.prev_state_ids
) )
local_auth_chain = yield self.store.get_auth_chain(auth_ids) local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
)
try: try:
# 2. Get remote difference. # 2. Get remote difference.

View file

@ -44,18 +44,41 @@ class EventFederationStore(SQLBaseStore):
self._delete_old_forward_extrem_cache, 60 * 60 * 1000 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
) )
def get_auth_chain(self, event_ids): def get_auth_chain(self, event_ids, include_given=False):
return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) """Get auth events for given event_ids. The events *must* be state events.
def get_auth_chain_ids(self, event_ids): Args:
event_ids (list): state events
include_given (bool): include the given events in result
Returns:
list of events
"""
return self.get_auth_chain_ids(
event_ids, include_given=include_given,
).addCallback(self._get_events)
def get_auth_chain_ids(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
event_ids (list): state events
include_given (bool): include the given events in result
Returns:
list of event_ids
"""
return self.runInteraction( return self.runInteraction(
"get_auth_chain_ids", "get_auth_chain_ids",
self._get_auth_chain_ids_txn, self._get_auth_chain_ids_txn,
event_ids event_ids, include_given
) )
def _get_auth_chain_ids_txn(self, txn, event_ids): def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
results = set() if include_given:
results = set(event_ids)
else:
results = set()
base_sql = ( base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)" "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"

View file

@ -1120,6 +1120,7 @@ class EventsStore(SQLBaseStore):
} }
for event, _ in events_and_contexts for event, _ in events_and_contexts
for auth_id, _ in event.auth_events for auth_id, _ in event.auth_events
if event.is_state()
], ],
) )