Remove the is_new_state argument to persist event.

Move the checks for whether an event is new state inside persist
event itself.

This was harder than expected because there wasn't enough information
passed to persist event to correctly handle invites from remote servers
for new rooms.
This commit is contained in:
Mark Haines 2016-03-31 15:00:42 +01:00
parent 62e395f0e3
commit 76503f95ed
3 changed files with 56 additions and 55 deletions

View file

@ -33,6 +33,9 @@ class _EventInternalMetadata(object):
def is_outlier(self): def is_outlier(self):
return getattr(self, "outlier", False) return getattr(self, "outlier", False)
def is_invite_from_remote(self):
return getattr(self, "invite_from_remote", False)
def _event_dict_property(key): def _event_dict_property(key):
def getter(self): def getter(self):

View file

@ -102,8 +102,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, state=None, def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
""" """
@ -174,11 +173,7 @@ class FederationHandler(BaseHandler):
}) })
seen_ids.add(e.event_id) seen_ids.add(e.event_id)
yield self._handle_new_events( yield self._handle_new_events(origin, event_infos)
origin,
event_infos,
outliers=True
)
try: try:
context, event_stream_id, max_stream_id = yield self._handle_new_event( context, event_stream_id, max_stream_id = yield self._handle_new_event(
@ -761,6 +756,7 @@ class FederationHandler(BaseHandler):
event = pdu event = pdu
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
@ -1069,9 +1065,6 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _handle_new_event(self, origin, event, state=None, auth_events=None): def _handle_new_event(self, origin, event, state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self._prep_event( context = yield self._prep_event(
origin, event, origin, event,
state=state, state=state,
@ -1087,14 +1080,12 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
is_new_state=not outlier,
) )
defer.returnValue((context, event_stream_id, max_stream_id)) defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False, def _handle_new_events(self, origin, event_infos, backfilled=False):
outliers=False):
contexts = yield defer.gatherResults( contexts = yield defer.gatherResults(
[ [
self._prep_event( self._prep_event(
@ -1113,7 +1104,6 @@ class FederationHandler(BaseHandler):
for ev_info, context in itertools.izip(event_infos, contexts) for ev_info, context in itertools.izip(event_infos, contexts)
], ],
backfilled=backfilled, backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1176,7 +1166,6 @@ class FederationHandler(BaseHandler):
(e, events_to_context[e.event_id]) (e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state) for e in itertools.chain(auth_events, state)
], ],
is_new_state=False,
) )
new_event_context = yield self.state_handler.compute_event_context( new_event_context = yield self.state_handler.compute_event_context(
@ -1185,7 +1174,6 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context, event, new_event_context,
is_new_state=True,
current_state=state, current_state=state,
) )

View file

@ -61,8 +61,7 @@ class EventsStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False, def persist_events(self, events_and_contexts, backfilled=False):
is_new_state=True):
if not events_and_contexts: if not events_and_contexts:
return return
@ -110,13 +109,11 @@ class EventsStore(SQLBaseStore):
self._persist_events_txn, self._persist_events_txn,
events_and_contexts=chunk, events_and_contexts=chunk,
backfilled=backfilled, backfilled=backfilled,
is_new_state=is_new_state,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, def persist_event(self, event, context, current_state=None):
is_new_state=True, current_state=None):
try: try:
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
@ -128,7 +125,6 @@ class EventsStore(SQLBaseStore):
self._persist_event_txn, self._persist_event_txn,
event=event, event=event,
context=context, context=context,
is_new_state=is_new_state,
current_state=current_state, current_state=current_state,
) )
except _RollbackButIsFineException: except _RollbackButIsFineException:
@ -194,8 +190,7 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events}) defer.returnValue({e.event_id: e for e in events})
@log_function @log_function
def _persist_event_txn(self, txn, event, context, def _persist_event_txn(self, txn, event, context, current_state):
is_new_state, current_state):
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
if current_state: if current_state:
@ -236,12 +231,10 @@ class EventsStore(SQLBaseStore):
txn, txn,
[(event, context)], [(event, context)],
backfilled=False, backfilled=False,
is_new_state=is_new_state,
) )
@log_function @log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled, def _persist_events_txn(self, txn, events_and_contexts, backfilled):
is_new_state):
depth_updates = {} depth_updates = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids # Remove the any existing cache entries for the event_ids
@ -452,10 +445,9 @@ class EventsStore(SQLBaseStore):
txn, [event for event, _ in events_and_contexts] txn, [event for event, _ in events_and_contexts]
) )
state_events_and_contexts = filter( state_events_and_contexts = [
lambda i: i[0].is_state(), ec for ec in events_and_contexts if ec[0].is_state()
events_and_contexts, ]
)
state_values = [] state_values = []
for event, context in state_events_and_contexts: for event, context in state_events_and_contexts:
@ -493,32 +485,50 @@ class EventsStore(SQLBaseStore):
], ],
) )
if is_new_state: for event, _ in state_events_and_contexts:
for event, _ in state_events_and_contexts: if backfilled:
if not context.rejected: # Backfilled events come before the current state so shouldn't
txn.call_after( # clobber it.
self._get_current_state_for_key.invalidate, continue
(event.room_id, event.type, event.state_key,)
)
if event.type in [EventTypes.Name, EventTypes.Aliases]: if (not event.internal_metadata.is_invite_from_remote()
txn.call_after( and event.internal_metadata.is_outlier()):
self.get_room_name_and_aliases.invalidate, # Outlier events generally shouldn't clobber the current state.
(event.room_id,) # However invites from remote severs for rooms we aren't in
) # are a bit special: they don't come with any associated
# state so are technically an outlier, however all the
# client-facing code assumes that they are in the current
# state table so we insert the event anyway.
continue
self._simple_upsert_txn( if context.rejected:
txn, # If the event failed it's auth checks then it shouldn't
"current_state_events", # clobbler the current state.
keyvalues={ continue
"room_id": event.room_id,
"type": event.type, txn.call_after(
"state_key": event.state_key, self._get_current_state_for_key.invalidate,
}, (event.room_id, event.type, event.state_key,)
values={ )
"event_id": event.event_id,
} if event.type in [EventTypes.Name, EventTypes.Aliases]:
) txn.call_after(
self.get_room_name_and_aliases.invalidate,
(event.room_id,)
)
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
return return