Fix unit tests

on_notifier_poke no longer runs synchonously, so we have to do a different hack
to make sure that the replication data has been sent. Let's actually listen for
its arrival.
This commit is contained in:
Richard van der Hoff 2018-07-25 10:30:36 +01:00
parent 3f11d84534
commit f59be4eb0e
2 changed files with 31 additions and 8 deletions

View file

@ -192,7 +192,7 @@ class ReplicationClientHandler(object):
"""Returns a deferred that is resolved when we receive a SYNC command """Returns a deferred that is resolved when we receive a SYNC command
with given data. with given data.
Used by tests. [Not currently] used by tests.
""" """
return self.awaiting_syncs.setdefault(data, defer.Deferred()) return self.awaiting_syncs.setdefault(data, defer.Deferred())

View file

@ -11,23 +11,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile import tempfile
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred
from synapse.replication.tcp.client import ( from synapse.replication.tcp.client import (
ReplicationClientFactory, ReplicationClientFactory,
ReplicationClientHandler, ReplicationClientHandler,
) )
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
class TestReplicationClientHandler(ReplicationClientHandler):
"""Overrides on_rdata so that we can wait for it to happen"""
def __init__(self, store):
super(TestReplicationClientHandler, self).__init__(store)
self._rdata_awaiters = []
def await_replication(self):
d = Deferred()
self._rdata_awaiters.append(d)
return make_deferred_yieldable(d)
def on_rdata(self, stream_name, token, rows):
awaiters = self._rdata_awaiters
self._rdata_awaiters = []
super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
with PreserveLoggingContext():
for a in awaiters:
a.callback(None)
class BaseSlavedStoreTestCase(unittest.TestCase): class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
@ -52,7 +73,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.addCleanup(listener.stopListening) self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer self.streamer = server_factory.streamer
self.replication_handler = ReplicationClientHandler(self.slaved_store) self.replication_handler = TestReplicationClientHandler(self.slaved_store)
client_factory = ReplicationClientFactory( client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler self.hs, "client_name", self.replication_handler
) )
@ -60,12 +81,14 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.addCleanup(client_factory.stopTrying) self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect) self.addCleanup(client_connector.disconnect)
@defer.inlineCallbacks
def replicate(self): def replicate(self):
yield self.streamer.on_notifier_poke() """Tell the master side of replication that something has happened, and then
d = self.replication_handler.await_sync("replication_test") wait for the replication to occur.
self.streamer.send_sync_to_all_connections("replication_test") """
yield d # xxx: should we be more specific in what we wait for?
d = self.replication_handler.await_replication()
self.streamer.on_notifier_poke()
return d
@defer.inlineCallbacks @defer.inlineCallbacks
def check(self, method, args, expected_result=None): def check(self, method, args, expected_result=None):