Port synapse.replication.tcp to async/await (#6666)

* Port synapse.replication.tcp to async/await

* Newsfile

* Correctly document type of on_<FOO> functions as async

* Don't be overenthusiastic with the asyncing....
This commit is contained in:
Erik Johnston 2020-01-16 09:16:12 +00:00 committed by GitHub
parent 19a1aac48c
commit 48c3a96886
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 80 additions and 105 deletions

1
changelog.d/6666.misc Normal file
View file

@ -0,0 +1 @@
Port `synapse.replication.tcp` to async/await.

View file

@ -84,8 +84,7 @@ class AdminCmdServer(HomeServer):
class AdminCmdReplicationHandler(ReplicationClientHandler): class AdminCmdReplicationHandler(ReplicationClientHandler):
@defer.inlineCallbacks async def on_rdata(self, stream_name, token, rows):
def on_rdata(self, stream_name, token, rows):
pass pass
def get_streams_to_replicate(self): def get_streams_to_replicate(self):

View file

@ -115,9 +115,8 @@ class ASReplicationHandler(ReplicationClientHandler):
super(ASReplicationHandler, self).__init__(hs.get_datastore()) super(ASReplicationHandler, self).__init__(hs.get_datastore())
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks async def on_rdata(self, stream_name, token, rows):
def on_rdata(self, stream_name, token, rows): await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
if stream_name == "events": if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering() max_stream_id = self.store.get_room_max_stream_ordering()

View file

@ -145,9 +145,8 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
self.send_handler = FederationSenderHandler(hs, self) self.send_handler = FederationSenderHandler(hs, self)
@defer.inlineCallbacks async def on_rdata(self, stream_name, token, rows):
def on_rdata(self, stream_name, token, rows): await super(FederationSenderReplicationHandler, self).on_rdata(
yield super(FederationSenderReplicationHandler, self).on_rdata(
stream_name, token, rows stream_name, token, rows
) )
self.send_handler.process_replication_rows(stream_name, token, rows) self.send_handler.process_replication_rows(stream_name, token, rows)

View file

@ -141,9 +141,8 @@ class PusherReplicationHandler(ReplicationClientHandler):
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
@defer.inlineCallbacks async def on_rdata(self, stream_name, token, rows):
def on_rdata(self, stream_name, token, rows): await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.poke_pushers, stream_name, token, rows) run_in_background(self.poke_pushers, stream_name, token, rows)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -358,9 +358,8 @@ class SyncReplicationHandler(ReplicationClientHandler):
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks async def on_rdata(self, stream_name, token, rows):
def on_rdata(self, stream_name, token, rows): await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.process_and_notify, stream_name, token, rows) run_in_background(self.process_and_notify, stream_name, token, rows)
def get_streams_to_replicate(self): def get_streams_to_replicate(self):

View file

@ -172,9 +172,8 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
self.user_directory = hs.get_user_directory_handler() self.user_directory = hs.get_user_directory_handler()
@defer.inlineCallbacks async def on_rdata(self, stream_name, token, rows):
def on_rdata(self, stream_name, token, rows): await super(UserDirectoryReplicationHandler, self).on_rdata(
yield super(UserDirectoryReplicationHandler, self).on_rdata(
stream_name, token, rows stream_name, token, rows
) )
if stream_name == EventsStream.NAME: if stream_name == EventsStream.NAME:

View file

@ -259,7 +259,9 @@ class FederationRemoteSendQueue(object):
def federation_ack(self, token): def federation_ack(self, token):
self._clear_queue_before_pos(token) self._clear_queue_before_pos(token)
def get_replication_rows(self, from_token, to_token, limit, federation_ack=None): async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
"""Get rows to be sent over federation between the two tokens """Get rows to be sent over federation between the two tokens
Args: Args:

View file

@ -257,7 +257,7 @@ class TypingHandler(object):
"typing_key", self._latest_room_serial, rooms=[member.room_id] "typing_key", self._latest_room_serial, rooms=[member.room_id]
) )
def get_all_typing_updates(self, last_id, current_id): async def get_all_typing_updates(self, last_id, current_id):
if last_id == current_id: if last_id == current_id:
return [] return []

View file

@ -110,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
port = hs.config.worker_replication_port port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory) hs.get_reactor().connectTCP(host, port, self.factory)
def on_rdata(self, stream_name, token, rows): async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token. """Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to By default this just pokes the slave store. Can be overridden in subclasses to
@ -121,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
token (int): stream token for this batch of rows token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row. Stream.parse_row.
Returns:
Deferred|None
""" """
logger.debug("Received rdata %s -> %s", stream_name, token) logger.debug("Received rdata %s -> %s", stream_name, token)
return self.store.process_replication_rows(stream_name, token, rows) self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token): async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes """Called when we get new position data. By default this just pokes
the slave store. the slave store.
Can be overriden in subclasses to handle more. Can be overriden in subclasses to handle more.
""" """
return self.store.process_replication_rows(stream_name, token, []) self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data): def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting """When we received a SYNC we wake up any deferreds that were waiting

View file

@ -81,12 +81,11 @@ from synapse.replication.tcp.commands import (
SyncCommand, SyncCommand,
UserSyncCommand, UserSyncCommand,
) )
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.types import Collection from synapse.types import Collection
from synapse.util import Clock from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from .streams import STREAMS_MAP
connection_close_counter = Counter( connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
) )
@ -241,19 +240,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
) )
def handle_command(self, cmd): async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream. """Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND> By default delegates to on_<COMMAND>, which should return an awaitable.
Args: Args:
cmd (synapse.replication.tcp.commands.Command): received command cmd: received command
Returns:
Deferred
""" """
handler = getattr(self, "on_%s" % (cmd.NAME,)) handler = getattr(self, "on_%s" % (cmd.NAME,))
return handler(cmd) await handler(cmd)
def close(self): def close(self):
logger.warning("[%s] Closing connection", self.id()) logger.warning("[%s] Closing connection", self.id())
@ -326,10 +322,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending: for cmd in pending:
self.send_command(cmd) self.send_command(cmd)
def on_PING(self, line): async def on_PING(self, line):
self.received_ping = True self.received_ping = True
def on_ERROR(self, cmd): async def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self): def pauseProducing(self):
@ -429,16 +425,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
BaseReplicationStreamProtocol.connectionMade(self) BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self) self.streamer.new_connection(self)
def on_NAME(self, cmd): async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data) logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data self.name = cmd.data
def on_USER_SYNC(self, cmd): async def on_USER_SYNC(self, cmd):
return self.streamer.on_user_sync( await self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
) )
def on_REPLICATE(self, cmd): async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name stream_name = cmd.stream_name
token = cmd.token token = cmd.token
@ -449,23 +445,23 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
for stream in iterkeys(self.streamer.streams_by_name) for stream in iterkeys(self.streamer.streams_by_name)
] ]
return make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True) defer.gatherResults(deferreds, consumeErrors=True)
) )
else: else:
return self.subscribe_to_stream(stream_name, token) await self.subscribe_to_stream(stream_name, token)
def on_FEDERATION_ACK(self, cmd): async def on_FEDERATION_ACK(self, cmd):
return self.streamer.federation_ack(cmd.token) self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd): async def on_REMOVE_PUSHER(self, cmd):
return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
def on_INVALIDATE_CACHE(self, cmd): async def on_INVALIDATE_CACHE(self, cmd):
return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd): async def on_USER_IP(self, cmd):
return self.streamer.on_user_ip( self.streamer.on_user_ip(
cmd.user_id, cmd.user_id,
cmd.access_token, cmd.access_token,
cmd.ip, cmd.ip,
@ -474,8 +470,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen, cmd.last_seen,
) )
@defer.inlineCallbacks async def subscribe_to_stream(self, stream_name, token):
def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream. """Subscribe the remote to a stream.
This invloves checking if they've missed anything and sending those This invloves checking if they've missed anything and sending those
@ -487,7 +482,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
try: try:
# Get missing updates # Get missing updates
updates, current_token = yield self.streamer.get_stream_updates( updates, current_token = await self.streamer.get_stream_updates(
stream_name, token stream_name, token
) )
@ -572,7 +567,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def on_rdata(self, stream_name, token, rows): async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token. """Called to handle a batch of replication data with a given stream token.
Args: Args:
@ -580,14 +575,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
token (int): stream token for this batch of rows token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row. Stream.parse_row.
Returns:
Deferred|None
""" """
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def on_position(self, stream_name, token): async def on_position(self, stream_name, token):
"""Called when we get new position data.""" """Called when we get new position data."""
raise NotImplementedError() raise NotImplementedError()
@ -676,12 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
if not self.streams_connecting: if not self.streams_connecting:
self.handler.finished_connecting() self.handler.finished_connecting()
def on_SERVER(self, cmd): async def on_SERVER(self, cmd):
if cmd.data != self.server_name: if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote") self.send_error("Wrong remote")
def on_RDATA(self, cmd): async def on_RDATA(self, cmd):
stream_name = cmd.stream_name stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc() inbound_rdata_count.labels(stream_name).inc()
@ -701,19 +693,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates # Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, []) rows = self.pending_batches.pop(stream_name, [])
rows.append(row) rows.append(row)
return self.handler.on_rdata(stream_name, cmd.token, rows) await self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd): async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting # When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date. # missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name) self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting: if not self.streams_connecting:
self.handler.finished_connecting() self.handler.finished_connecting()
return self.handler.on_position(cmd.stream_name, cmd.token) await self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd): async def on_SYNC(self, cmd):
return self.handler.on_sync(cmd.data) self.handler.on_sync(cmd.data)
def replicate(self, stream_name, token): def replicate(self, stream_name, token):
"""Send the subscription request to the server """Send the subscription request to the server

View file

@ -23,7 +23,6 @@ from six import itervalues
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
@ -155,8 +154,7 @@ class ReplicationStreamer(object):
run_as_background_process("replication_notifier", self._run_notifier_loop) run_as_background_process("replication_notifier", self._run_notifier_loop)
@defer.inlineCallbacks async def _run_notifier_loop(self):
def _run_notifier_loop(self):
self.is_looping = True self.is_looping = True
try: try:
@ -185,7 +183,7 @@ class ReplicationStreamer(object):
continue continue
if self._replication_torture_level: if self._replication_torture_level:
yield self.clock.sleep( await self.clock.sleep(
self._replication_torture_level / 1000.0 self._replication_torture_level / 1000.0
) )
@ -196,7 +194,7 @@ class ReplicationStreamer(object):
stream.upto_token, stream.upto_token,
) )
try: try:
updates, current_token = yield stream.get_updates() updates, current_token = await stream.get_updates()
except Exception: except Exception:
logger.info("Failed to handle stream %s", stream.NAME) logger.info("Failed to handle stream %s", stream.NAME)
raise raise
@ -233,7 +231,7 @@ class ReplicationStreamer(object):
self.is_looping = False self.is_looping = False
@measure_func("repl.get_stream_updates") @measure_func("repl.get_stream_updates")
def get_stream_updates(self, stream_name, token): async def get_stream_updates(self, stream_name, token):
"""For a given stream get all updates since token. This is called when """For a given stream get all updates since token. This is called when
a client first subscribes to a stream. a client first subscribes to a stream.
""" """
@ -241,7 +239,7 @@ class ReplicationStreamer(object):
if not stream: if not stream:
raise Exception("unknown stream %s", stream_name) raise Exception("unknown stream %s", stream_name)
return stream.get_updates_since(token) return await stream.get_updates_since(token)
@measure_func("repl.federation_ack") @measure_func("repl.federation_ack")
def federation_ack(self, token): def federation_ack(self, token):
@ -252,22 +250,20 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token) self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync") @measure_func("repl.on_user_sync")
@defer.inlineCallbacks async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker. """A client has started/stopped syncing on a worker.
""" """
user_sync_counter.inc() user_sync_counter.inc()
yield self.presence_handler.update_external_syncs_row( await self.presence_handler.update_external_syncs_row(
conn_id, user_id, is_syncing, last_sync_ms conn_id, user_id, is_syncing, last_sync_ms
) )
@measure_func("repl.on_remove_pusher") @measure_func("repl.on_remove_pusher")
@defer.inlineCallbacks async def on_remove_pusher(self, app_id, push_key, user_id):
def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher """A client has asked us to remove a pusher
""" """
remove_pusher_counter.inc() remove_pusher_counter.inc()
yield self.store.delete_pusher_by_app_id_pushkey_user_id( await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=app_id, pushkey=push_key, user_id=user_id app_id=app_id, pushkey=push_key, user_id=user_id
) )
@ -281,15 +277,16 @@ class ReplicationStreamer(object):
getattr(self.store, cache_func).invalidate(tuple(keys)) getattr(self.store, cache_func).invalidate(tuple(keys))
@measure_func("repl.on_user_ip") @measure_func("repl.on_user_ip")
@defer.inlineCallbacks async def on_user_ip(
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): self, user_id, access_token, ip, user_agent, device_id, last_seen
):
"""The client saw a user request """The client saw a user request
""" """
user_ip_cache_counter.inc() user_ip_cache_counter.inc()
yield self.store.insert_client_ip( await self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen user_id, access_token, ip, user_agent, device_id, last_seen
) )
yield self._server_notices_sender.on_user_ip(user_id) await self._server_notices_sender.on_user_ip(user_id)
def send_sync_to_all_connections(self, data): def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients. """Sends a SYNC command to all clients.

View file

@ -19,8 +19,6 @@ import logging
from collections import namedtuple from collections import namedtuple
from typing import Any from typing import Any
from twisted.internet import defer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -144,8 +142,7 @@ class Stream(object):
self.upto_token = self.current_token() self.upto_token = self.current_token()
self.last_token = self.upto_token self.last_token = self.upto_token
@defer.inlineCallbacks async def get_updates(self):
def get_updates(self):
"""Gets all updates since the last time this function was called (or """Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before), since the stream was constructed if it hadn't been called before),
until the `upto_token` until the `upto_token`
@ -156,13 +153,12 @@ class Stream(object):
list of ``(token, row)`` entries. ``row`` will be json-serialised and list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication steam. sent over the replication steam.
""" """
updates, current_token = yield self.get_updates_since(self.last_token) updates, current_token = await self.get_updates_since(self.last_token)
self.last_token = current_token self.last_token = current_token
return updates, current_token return updates, current_token
@defer.inlineCallbacks async def get_updates_since(self, from_token):
def get_updates_since(self, from_token):
"""Like get_updates except allows specifying from when we should """Like get_updates except allows specifying from when we should
stream updates stream updates
@ -182,15 +178,16 @@ class Stream(object):
if from_token == current_token: if from_token == current_token:
return [], current_token return [], current_token
logger.info("get_updates_since: %s", self.__class__)
if self._LIMITED: if self._LIMITED:
rows = yield self.update_function( rows = await self.update_function(
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
) )
# never turn more than MAX_EVENTS_BEHIND + 1 into updates. # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else: else:
rows = yield self.update_function(from_token, current_token) rows = await self.update_function(from_token, current_token)
updates = [(row[0], row[1:]) for row in rows] updates = [(row[0], row[1:]) for row in rows]
@ -295,9 +292,8 @@ class PushRulesStream(Stream):
push_rules_token, _ = self.store.get_push_rules_stream_token() push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token return push_rules_token
@defer.inlineCallbacks async def update_function(self, from_token, to_token, limit):
def update_function(self, from_token, to_token, limit): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
return [(row[0], row[2]) for row in rows] return [(row[0], row[2]) for row in rows]
@ -413,9 +409,8 @@ class AccountDataStream(Stream):
super(AccountDataStream, self).__init__(hs) super(AccountDataStream, self).__init__(hs)
@defer.inlineCallbacks async def update_function(self, from_token, to_token, limit):
def update_function(self, from_token, to_token, limit): global_results, room_results = await self.store.get_all_updated_account_data(
global_results, room_results = yield self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit from_token, from_token, to_token, limit
) )

View file

@ -19,8 +19,6 @@ from typing import Tuple, Type
import attr import attr
from twisted.internet import defer
from ._base import Stream from ._base import Stream
@ -122,16 +120,15 @@ class EventsStream(Stream):
super(EventsStream, self).__init__(hs) super(EventsStream, self).__init__(hs)
@defer.inlineCallbacks async def update_function(self, from_token, current_token, limit=None):
def update_function(self, from_token, current_token, limit=None): event_rows = await self._store.get_all_new_forward_event_rows(
event_rows = yield self._store.get_all_new_forward_event_rows(
from_token, current_token, limit from_token, current_token, limit
) )
event_updates = ( event_updates = (
(row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
) )
state_rows = yield self._store.get_all_updated_current_state_deltas( state_rows = await self._store.get_all_updated_current_state_deltas(
from_token, current_token, limit from_token, current_token, limit
) )
state_updates = ( state_updates = (

View file

@ -73,6 +73,6 @@ class TestReplicationClientHandler(object):
def finished_connecting(self): def finished_connecting(self):
pass pass
def on_rdata(self, stream_name, token, rows): async def on_rdata(self, stream_name, token, rows):
for r in rows: for r in rows:
self.received_rdata_rows.append((stream_name, token, r)) self.received_rdata_rows.append((stream_name, token, r))