Move where we ack federation

This commit is contained in:
Erik Johnston 2017-04-04 10:29:29 +01:00
parent 36c28bc467
commit 6ce6bbedcb

View file

@ -153,15 +153,13 @@ class FederationSenderServer(HomeServer):
class FederationSenderReplicationHandler(ReplicationClientHandler):
def __init__(self, hs):
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
self.send_handler = FederationSenderHandler(hs)
self.send_handler = FederationSenderHandler(hs, self)
def on_rdata(self, stream_name, token, rows):
super(FederationSenderReplicationHandler, self).on_rdata(
stream_name, token, rows
)
self.send_handler.process_replication_rows(stream_name, token, rows)
if stream_name == "federation":
self.send_federation_ack(token)
def get_streams_to_replicate(self):
args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
@ -248,13 +246,16 @@ class FederationSenderHandler(object):
"""Processes the replication stream and forwards the appropriate entries
to the federation sender.
"""
def __init__(self, hs):
def __init__(self, hs, replication_client):
self.store = hs.get_datastore()
self.federation_sender = hs.get_federation_sender()
self.replication_client = replication_client
self.federation_position = self.store.federation_out_pos_startup
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
self._last_ack = self.federation_position
self._room_serials = {}
self._room_typing = {}
@ -345,11 +346,19 @@ class FederationSenderHandler(object):
@defer.inlineCallbacks
def update_token(self, token):
self.federation_position = token
# We linearize here to ensure we don't have races updating the token
with (yield self._fed_position_linearizer.queue(None)):
if self._last_ack < self.federation_position:
yield self.store.update_federation_out_pos(
"federation", self.federation_position
)
# We ACK this token over replication so that the master can drop
# its in memory queues
self.replication_client.send_federation_ack(self.federation_position)
self._last_ack = self.federation_position
if __name__ == '__main__':
with LoggingContext("main"):