Merge branch 'release-v0.6.0' of github.com:matrix-org/synapse into erikj-perf

This commit is contained in:
Erik Johnston 2014-12-18 18:57:21 +00:00
commit 41ce544abe
5 changed files with 80 additions and 45 deletions

View file

@ -256,31 +256,35 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_state_for_context(self, destination, context, event_id=None): def get_state_for_context(self, destination, context, event_id):
"""Requests all of the `current` state PDUs for a given context from """Requests all of the `current` state PDUs for a given context from
a remote home server. a remote home server.
Args: Args:
destination (str): The remote homeserver to query for the state. destination (str): The remote homeserver to query for the state.
context (str): The context we're interested in. context (str): The context we're interested in.
event_id (str): The id of the event we want the state at.
Returns: Returns:
Deferred: Results in a list of PDUs. Deferred: Results in a list of PDUs.
""" """
transaction_data = yield self.transport_layer.get_context_state( result = yield self.transport_layer.get_context_state(
destination, destination,
context, context,
event_id=event_id, event_id=event_id,
) )
transaction = Transaction(**transaction_data)
pdus = [ pdus = [
self.event_from_pdu_json(p, outlier=True) self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
for p in transaction.pdus
] ]
defer.returnValue(pdus) auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
defer.returnValue((pdus, auth_chain))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -383,10 +387,16 @@ class ReplicationLayer(object):
context, context,
event_id, event_id,
) )
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
else: else:
raise NotImplementedError("Specify an event") raise NotImplementedError("Specify an event")
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -573,6 +583,8 @@ class ReplicationLayer(object):
state = None state = None
auth_chain = []
# We need to make sure we have all the auth events. # We need to make sure we have all the auth events.
# for e_id, _ in pdu.auth_events: # for e_id, _ in pdu.auth_events:
# exists = yield self._get_persisted_pdu( # exists = yield self._get_persisted_pdu(
@ -645,7 +657,7 @@ class ReplicationLayer(object):
"_handle_new_pdu getting state for %s", "_handle_new_pdu getting state for %s",
pdu.room_id pdu.room_id
) )
state = yield self.get_state_for_context( state, auth_chain = yield self.get_state_for_context(
origin, pdu.room_id, pdu.event_id, origin, pdu.room_id, pdu.event_id,
) )
@ -655,6 +667,7 @@ class ReplicationLayer(object):
pdu, pdu,
backfilled=backfilled, backfilled=backfilled,
state=state, state=state,
auth_chain=auth_chain,
) )
else: else:
ret = None ret = None

View file

@ -95,7 +95,8 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, backfilled, state=None): def on_receive_pdu(self, origin, pdu, backfilled, state=None,
auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
""" """
@ -150,8 +151,15 @@ class FederationHandler(BaseHandler):
if not is_in_room and not event.internal_metadata.outlier: if not is_in_room and not event.internal_metadata.outlier:
logger.debug("Got event for room we're not in.") logger.debug("Got event for room we're not in.")
replication_layer = self.replication_layer replication = self.replication_layer
auth_chain = yield replication_layer.get_event_auth(
if not state:
state, auth_chain = yield replication.get_state_for_context(
origin, context=event.room_id, event_id=event.event_id,
)
if not auth_chain:
auth_chain = yield replication.get_event_auth(
origin, origin,
context=event.room_id, context=event.room_id,
event_id=event.event_id, event_id=event.event_id,
@ -160,25 +168,18 @@ class FederationHandler(BaseHandler):
for e in auth_chain: for e in auth_chain:
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e, fetch_missing=False) yield self._handle_new_event(e, fetch_auth_from=origin)
except: except:
logger.exception( logger.exception(
"Failed to handle auth event %s", "Failed to handle auth event %s",
e.event_id, e.event_id,
) )
if not state:
state = yield replication_layer.get_state_for_context(
origin,
context=event.room_id,
event_id=event.event_id,
)
# FIXME: Get auth chain for these state events
current_state = state current_state = state
if state: if state:
for e in state: for e in state:
logging.info("A :) %r", e)
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e) yield self._handle_new_event(e)
@ -392,7 +393,7 @@ class FederationHandler(BaseHandler):
for e in auth_chain: for e in auth_chain:
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e, fetch_missing=False) yield self._handle_new_event(e)
except: except:
logger.exception( logger.exception(
"Failed to handle auth event %s", "Failed to handle auth event %s",
@ -404,8 +405,7 @@ class FederationHandler(BaseHandler):
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event( yield self._handle_new_event(
e, e, fetch_auth_from=target_host
fetch_missing=True
) )
except: except:
logger.exception( logger.exception(
@ -682,7 +682,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False, def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None, fetch_missing=True): current_state=None, fetch_auth_from=None):
logger.debug( logger.debug(
"_handle_new_event: Before annotate: %s, sigs: %s", "_handle_new_event: Before annotate: %s, sigs: %s",
@ -703,11 +703,20 @@ class FederationHandler(BaseHandler):
known_ids = set( known_ids = set(
[s.event_id for s in context.auth_events.values()] [s.event_id for s in context.auth_events.values()]
) )
for e_id, _ in event.auth_events: for e_id, _ in event.auth_events:
if e_id not in known_ids: if e_id not in known_ids:
e = yield self.store.get_event( e = yield self.store.get_event(e_id, allow_none=True)
e_id, allow_none=True,
if not e and fetch_auth_from is not None:
# Grab the auth_chain over federation if we are missing
# auth events.
auth_chain = yield self.replication_layer.get_event_auth(
fetch_auth_from, event.event_id, event.room_id
) )
for auth_event in auth_chain:
yield self._handle_new_event(auth_event)
e = yield self.store.get_event(e_id, allow_none=True)
if not e: if not e:
# TODO: Do some conflict res to make sure that we're # TODO: Do some conflict res to make sure that we're

View file

@ -115,10 +115,10 @@ class Signal(object):
failure.value, failure.value,
failure.getTracebackObject())) failure.getTracebackObject()))
if not self.suppress_failures: if not self.suppress_failures:
raise failure failure.raiseException()
deferreds.append(d.addErrback(eb)) deferreds.append(d.addErrback(eb))
results = []
result = yield defer.DeferredList( for deferred in deferreds:
deferreds, fireOnOneErrback=not self.suppress_failures result = yield deferred
) results.append(result)
defer.returnValue(result) defer.returnValue(results)

View file

@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
"get_received_txn_response", "get_received_txn_response",
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings", "get_destination_retry_timings",
"get_auth_chain",
]) ])
self.mock_persistence.get_received_txn_response.return_value = ( self.mock_persistence.get_received_txn_response.return_value = (
defer.succeed(None) defer.succeed(None)
@ -59,6 +60,7 @@ class FederationTestCase(unittest.TestCase):
self.mock_persistence.get_destination_retry_timings.return_value = ( self.mock_persistence.get_destination_retry_timings.return_value = (
defer.succeed(DestinationsTable.EntryType("", 0, 0)) defer.succeed(DestinationsTable.EntryType("", 0, 0))
) )
self.mock_persistence.get_auth_chain.return_value = []
self.mock_config = Mock() self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()] self.mock_config.signing_key = [MockKey()]
self.clock = MockClock() self.clock = MockClock()

View file

@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from tests import unittest from . import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock, patch from mock import Mock, patch
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
from synapse.util.async import run_on_reactor
class DistributorTestCase(unittest.TestCase): class DistributorTestCase(unittest.TestCase):
@ -26,6 +27,7 @@ class DistributorTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dist = Distributor() self.dist = Distributor()
@defer.inlineCallbacks
def test_signal_dispatch(self): def test_signal_dispatch(self):
self.dist.declare("alert") self.dist.declare("alert")
@ -33,10 +35,11 @@ class DistributorTestCase(unittest.TestCase):
self.dist.observe("alert", observer) self.dist.observe("alert", observer)
d = self.dist.fire("alert", 1, 2, 3) d = self.dist.fire("alert", 1, 2, 3)
yield d
self.assertTrue(d.called) self.assertTrue(d.called)
observer.assert_called_with(1, 2, 3) observer.assert_called_with(1, 2, 3)
@defer.inlineCallbacks
def test_signal_dispatch_deferred(self): def test_signal_dispatch_deferred(self):
self.dist.declare("whine") self.dist.declare("whine")
@ -50,8 +53,10 @@ class DistributorTestCase(unittest.TestCase):
self.assertFalse(d_outer.called) self.assertFalse(d_outer.called)
d_inner.callback(None) d_inner.callback(None)
yield d_outer
self.assertTrue(d_outer.called) self.assertTrue(d_outer.called)
@defer.inlineCallbacks
def test_signal_catch(self): def test_signal_catch(self):
self.dist.declare("alarm") self.dist.declare("alarm")
@ -65,6 +70,7 @@ class DistributorTestCase(unittest.TestCase):
spec=["warning"] spec=["warning"]
) as mock_logger: ) as mock_logger:
d = self.dist.fire("alarm", "Go") d = self.dist.fire("alarm", "Go")
yield d
self.assertTrue(d.called) self.assertTrue(d.called)
observers[0].assert_called_once("Go") observers[0].assert_called_once("Go")
@ -81,23 +87,28 @@ class DistributorTestCase(unittest.TestCase):
self.dist.declare("whail") self.dist.declare("whail")
observer = Mock() class MyException(Exception):
observer.return_value = defer.fail( pass
Exception("Oopsie")
) @defer.inlineCallbacks
def observer():
yield run_on_reactor()
raise MyException("Oopsie")
self.dist.observe("whail", observer) self.dist.observe("whail", observer)
d = self.dist.fire("whail") d = self.dist.fire("whail")
yield self.assertFailure(d, Exception) yield self.assertFailure(d, MyException)
self.dist.suppress_failures = True
@defer.inlineCallbacks
def test_signal_prereg(self): def test_signal_prereg(self):
observer = Mock() observer = Mock()
self.dist.observe("flare", observer) self.dist.observe("flare", observer)
self.dist.declare("flare") self.dist.declare("flare")
self.dist.fire("flare", 4, 5) yield self.dist.fire("flare", 4, 5)
observer.assert_called_with(4, 5) observer.assert_called_with(4, 5)