Merge pull request #2459 from matrix-org/rav/keyring_cleanups

Clean up Keyring code
This commit is contained in:
Richard van der Hoff 2017-09-20 11:29:39 +01:00 committed by GitHub
commit c94ab5976a
5 changed files with 355 additions and 216 deletions

View file

@ -18,7 +18,7 @@ from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util import unwrapFirstError, logcontext from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import ( from synapse.util.logcontext import (
preserve_context_over_fn, PreserveLoggingContext, PreserveLoggingContext,
preserve_fn preserve_fn
) )
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -57,7 +57,8 @@ Attributes:
json_object(dict): The JSON object to verify. json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred): deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched a verify key has been fetched. The deferreds' callbacks are run with no
logcontext.
""" """
@ -82,9 +83,11 @@ class Keyring(object):
self.key_downloads = {} self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
return self.verify_json_objects_for_server( return logcontext.make_deferred_yieldable(
[(server_name, json_object)] self.verify_json_objects_for_server(
)[0] [(server_name, json_object)]
)[0]
)
def verify_json_objects_for_server(self, server_and_json): def verify_json_objects_for_server(self, server_and_json):
"""Bulk verifies signatures of json objects, bulk fetching keys as """Bulk verifies signatures of json objects, bulk fetching keys as
@ -94,8 +97,10 @@ class Keyring(object):
server_and_json (list): List of pairs of (server_name, json_object) server_and_json (list): List of pairs of (server_name, json_object)
Returns: Returns:
list of deferreds indicating success or failure to verify each List<Deferred>: for each input pair, a deferred indicating success
json object's signature for the given server_name. or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
""" """
verify_requests = [] verify_requests = []
@ -122,96 +127,72 @@ class Keyring(object):
verify_requests.append(verify_request) verify_requests.append(verify_request)
@defer.inlineCallbacks preserve_fn(self._start_key_lookups)(verify_requests)
def handle_key_deferred(verify_request):
server_name = verify_request.server_name
try:
_, key_id, verify_key = yield verify_request.deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
json_object = verify_request.json_object
logger.debug("Got key %s %s:%s for server %s, verifying" % (
key_id, verify_key.alg, verify_key.version, server_name,
))
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
server_to_deferred = {
server_name: defer.Deferred()
for server_name, _ in server_and_json
}
with PreserveLoggingContext():
# We want to wait for any previous lookups to complete before
# proceeding.
wait_on_deferred = self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json],
server_to_deferred,
)
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(verify_requests)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_request_ids = {}
def remove_deferreds(res, server_name, verify_request):
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
verify_request.deferred.addBoth(
remove_deferreds, server_name, verify_request,
)
# Pass those keys to handle_key_deferred so that the json object # Pass those keys to handle_key_deferred so that the json object
# signatures can be verified # signatures can be verified
handle = preserve_fn(_handle_key_deferred)
return [ return [
preserve_context_over_fn(handle_key_deferred, verify_request) handle(rq) for rq in verify_requests
for verify_request in verify_requests
] ]
@defer.inlineCallbacks
def _start_key_lookups(self, verify_requests):
"""Sets off the key fetches for each verify request
Once each fetch completes, verify_request.deferred will be resolved.
Args:
verify_requests (List[VerifyKeyRequest]):
"""
# create a deferred for each server we're going to look up the keys
# for; we'll resolve them once we have completed our lookups.
# These will be passed into wait_for_previous_lookups to block
# any other lookups until we have finished.
# The deferreds are called with no logcontext.
server_to_deferred = {
rq.server_name: defer.Deferred()
for rq in verify_requests
}
# We want to wait for any previous lookups to complete before
# proceeding.
yield self.wait_for_previous_lookups(
[rq.server_name for rq in verify_requests],
server_to_deferred,
)
# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
#
# map from server name to a set of request ids
server_to_request_ids = {}
for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
def remove_deferreds(res, verify_request):
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for verify_request in verify_requests:
verify_request.deferred.addBoth(
remove_deferreds, verify_request,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred): def wait_for_previous_lookups(self, server_names, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish. """Waits for any previous key lookups for the given servers to finish.
@ -247,7 +228,7 @@ class Keyring(object):
self.key_downloads[server_name] = deferred self.key_downloads[server_name] = deferred
deferred.addBoth(rm, server_name) deferred.addBoth(rm, server_name)
def get_server_verify_keys(self, verify_requests): def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request """Tries to find at least one key for each verify request
For each verify_request, verify_request.deferred is called back with For each verify_request, verify_request.deferred is called back with
@ -316,21 +297,23 @@ class Keyring(object):
if not missing_keys: if not missing_keys:
break break
for verify_request in requests_missing_keys.values(): with PreserveLoggingContext():
verify_request.deferred.errback(SynapseError( for verify_request in requests_missing_keys.values():
401, verify_request.deferred.errback(SynapseError(
"No key for %s with id %s" % ( 401,
verify_request.server_name, verify_request.key_ids, "No key for %s with id %s" % (
), verify_request.server_name, verify_request.key_ids,
Codes.UNAUTHORIZED, ),
)) Codes.UNAUTHORIZED,
))
def on_err(err): def on_err(err):
for verify_request in verify_requests: with PreserveLoggingContext():
if not verify_request.deferred.called: for verify_request in verify_requests:
verify_request.deferred.errback(err) if not verify_request.deferred.called:
verify_request.deferred.errback(err)
do_iterations().addErrback(on_err) preserve_fn(do_iterations)().addErrback(on_err)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): def get_keys_from_store(self, server_name_and_key_ids):
@ -740,3 +723,47 @@ class Keyring(object):
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError)) ).addErrback(unwrapFirstError))
@defer.inlineCallbacks
def _handle_key_deferred(verify_request):
server_name = verify_request.server_name
try:
with PreserveLoggingContext():
_, key_id, verify_key = yield verify_request.deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, verify_request.key_ids),
Codes.UNAUTHORIZED,
)
json_object = verify_request.json_object
logger.debug("Got key %s %s:%s for server %s, verifying" % (
key_id, verify_key.alg, verify_key.version, server_name,
))
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)

View file

@ -18,8 +18,7 @@ from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import spamcheck from synapse.events import spamcheck
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import preserve_context_over_deferred, preserve_fn
from twisted.internet import defer from twisted.internet import defer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,56 +50,52 @@ class FederationBase(object):
""" """
deferreds = self._check_sigs_and_hashes(pdus) deferreds = self._check_sigs_and_hashes(pdus)
def callback(pdu): @defer.inlineCallbacks
return pdu def handle_check_result(pdu, deferred):
try:
res = yield logcontext.make_deferred_yieldable(deferred)
except SynapseError:
res = None
def errback(failure, pdu):
failure.trap(SynapseError)
return None
def try_local_db(res, pdu):
if not res: if not res:
# Check local db. # Check local db.
return self.store.get_event( res = yield self.store.get_event(
pdu.event_id, pdu.event_id,
allow_rejected=True, allow_rejected=True,
allow_none=True, allow_none=True,
) )
return res
def try_remote(res, pdu):
if not res and pdu.origin != origin: if not res and pdu.origin != origin:
return self.get_pdu( try:
destinations=[pdu.origin], res = yield self.get_pdu(
event_id=pdu.event_id, destinations=[pdu.origin],
outlier=outlier, event_id=pdu.event_id,
timeout=10000, outlier=outlier,
).addErrback(lambda e: None) timeout=10000,
return res )
except SynapseError:
pass
def warn(res, pdu):
if not res: if not res:
logger.warn( logger.warn(
"Failed to find copy of %s with valid signature", "Failed to find copy of %s with valid signature",
pdu.event_id, pdu.event_id,
) )
return res
for pdu, deferred in zip(pdus, deferreds): defer.returnValue(res)
deferred.addCallbacks(
callback, errback, errbackArgs=[pdu] handle = logcontext.preserve_fn(handle_check_result)
).addCallback( deferreds2 = [
try_local_db, pdu handle(pdu, deferred)
).addCallback( for pdu, deferred in zip(pdus, deferreds)
try_remote, pdu ]
).addCallback(
warn, pdu valid_pdus = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
deferreds2,
consumeErrors=True,
) )
).addErrback(unwrapFirstError)
valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
deferreds,
consumeErrors=True
)).addErrback(unwrapFirstError)
if include_none: if include_none:
defer.returnValue(valid_pdus) defer.returnValue(valid_pdus)
@ -108,7 +103,9 @@ class FederationBase(object):
defer.returnValue([p for p in valid_pdus if p]) defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu): def _check_sigs_and_hash(self, pdu):
return self._check_sigs_and_hashes([pdu])[0] return logcontext.make_deferred_yieldable(
self._check_sigs_and_hashes([pdu])[0],
)
def _check_sigs_and_hashes(self, pdus): def _check_sigs_and_hashes(self, pdus):
"""Checks that each of the received events is correctly signed by the """Checks that each of the received events is correctly signed by the
@ -123,6 +120,7 @@ class FederationBase(object):
* returns a redacted version of the event (if the signature * returns a redacted version of the event (if the signature
matched but the hash did not) matched but the hash did not)
* throws a SynapseError if the signature check failed. * throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel logcontext.
""" """
redacted_pdus = [ redacted_pdus = [
@ -130,34 +128,38 @@ class FederationBase(object):
for pdu in pdus for pdu in pdus
] ]
deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([ deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json()) (p.origin, p.get_pdu_json())
for p in redacted_pdus for p in redacted_pdus
]) ])
ctx = logcontext.LoggingContext.current_context()
def callback(_, pdu, redacted): def callback(_, pdu, redacted):
if not check_event_content_hash(pdu): with logcontext.PreserveLoggingContext(ctx):
logger.warn( if not check_event_content_hash(pdu):
"Event content has been tampered, redacting %s: %s", logger.warn(
pdu.event_id, pdu.get_pdu_json() "Event content has been tampered, redacting %s: %s",
) pdu.event_id, pdu.get_pdu_json()
return redacted )
return redacted
if spamcheck.check_event_for_spam(pdu): if spamcheck.check_event_for_spam(pdu):
logger.warn( logger.warn(
"Event contains spam, redacting %s: %s", "Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json() pdu.event_id, pdu.get_pdu_json()
) )
return redacted return redacted
return pdu return pdu
def errback(failure, pdu): def errback(failure, pdu):
failure.trap(SynapseError) failure.trap(SynapseError)
logger.warn( with logcontext.PreserveLoggingContext(ctx):
"Signature check failed for %s", logger.warn(
pdu.event_id, "Signature check failed for %s",
) pdu.event_id,
)
return failure return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):

View file

@ -22,7 +22,7 @@ from synapse.api.constants import Membership
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@ -189,10 +189,10 @@ class FederationClient(FederationBase):
] ]
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
pdus[:] = yield preserve_context_over_deferred(defer.gatherResults( pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(pdus), self._check_sigs_and_hashes(pdus),
consumeErrors=True, consumeErrors=True,
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError))
defer.returnValue(pdus) defer.returnValue(pdus)
@ -252,7 +252,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] signed_pdu = yield self._check_sigs_and_hash(pdu)
break break

View file

@ -113,30 +113,37 @@ class KeyStore(SQLBaseStore):
keys[key_id] = key keys[key_id] = key
defer.returnValue(keys) defer.returnValue(keys)
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms, def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key): verify_key):
"""Stores a NACL verification key for the given server. """Stores a NACL verification key for the given server.
Args: Args:
server_name (str): The name of the server. server_name (str): The name of the server.
key_id (str): The version of the key for the server.
from_server (str): Where the verification key was looked up from_server (str): Where the verification key was looked up
ts_now_ms (int): The time now in milliseconds time_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key. verify_key (nacl.signing.VerifyKey): The NACL verify key.
""" """
yield self._simple_upsert( key_id = "%s:%s" % (verify_key.alg, verify_key.version)
table="server_signature_keys",
keyvalues={ def _txn(txn):
"server_name": server_name, self._simple_upsert_txn(
"key_id": "%s:%s" % (verify_key.alg, verify_key.version), txn,
}, table="server_signature_keys",
values={ keyvalues={
"from_server": from_server, "server_name": server_name,
"ts_added_ms": time_now_ms, "key_id": key_id,
"verify_key": buffer(verify_key.encode()), },
}, values={
desc="store_server_verify_key", "from_server": from_server,
) "ts_added_ms": time_now_ms,
"verify_key": buffer(verify_key.encode()),
},
)
txn.call_after(
self._get_server_verify_key.invalidate,
(server_name, key_id)
)
return self.runInteraction("store_server_verify_key", _txn)
def store_server_keys_json(self, server_name, key_id, from_server, def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes): ts_now_ms, ts_expires_ms, key_json_bytes):

View file

@ -12,39 +12,72 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import signedjson import time
import signedjson.key
import signedjson.sign
from mock import Mock from mock import Mock
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto import keyring from synapse.crypto import keyring
from synapse.util import async from synapse.util import async, logcontext
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from tests import unittest, utils from tests import unittest, utils
from twisted.internet import defer from twisted.internet import defer
class MockPerspectiveServer(object):
def __init__(self):
self.server_name = "mock_server"
self.key = signedjson.key.generate_signing_key(0)
def get_verify_keys(self):
vk = signedjson.key.get_verify_key(self.key)
return {
"%s:%s" % (vk.alg, vk.version): vk,
}
def get_signed_key(self, server_name, verify_key):
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
res = {
"server_name": server_name,
"old_verify_keys": {},
"valid_until_ts": time.time() * 1000 + 3600,
"verify_keys": {
key_id: {
"key": signedjson.key.encode_verify_key_base64(verify_key)
}
}
}
signedjson.sign.sign_json(res, self.server_name, self.key)
return res
class KeyringTestCase(unittest.TestCase): class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock() self.http_client = Mock()
self.hs = yield utils.setup_test_homeserver( self.hs = yield utils.setup_test_homeserver(
handlers=None, handlers=None,
http_client=self.http_client, http_client=self.http_client,
) )
self.hs.config.perspectives = { self.hs.config.perspectives = {
"persp_server": {"k": "v"} self.mock_perspective_server.server_name:
self.mock_perspective_server.get_verify_keys()
} }
def check_context(self, _, expected):
self.assertEquals(
getattr(LoggingContext.current_context(), "test_key", None),
expected
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_wait_for_previous_lookups(self): def test_wait_for_previous_lookups(self):
sentinel_context = LoggingContext.current_context() sentinel_context = LoggingContext.current_context()
kr = keyring.Keyring(self.hs) kr = keyring.Keyring(self.hs)
def check_context(_, expected):
self.assertEquals(
LoggingContext.current_context().test_key, expected
)
lookup_1_deferred = defer.Deferred() lookup_1_deferred = defer.Deferred()
lookup_2_deferred = defer.Deferred() lookup_2_deferred = defer.Deferred()
@ -60,7 +93,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertTrue(wait_1_deferred.called) self.assertTrue(wait_1_deferred.called)
# ... so we should have preserved the LoggingContext. # ... so we should have preserved the LoggingContext.
self.assertIs(LoggingContext.current_context(), context_one) self.assertIs(LoggingContext.current_context(), context_one)
wait_1_deferred.addBoth(check_context, "one") wait_1_deferred.addBoth(self.check_context, "one")
with LoggingContext("two") as context_two: with LoggingContext("two") as context_two:
context_two.test_key = "two" context_two.test_key = "two"
@ -74,7 +107,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertFalse(wait_2_deferred.called) self.assertFalse(wait_2_deferred.called)
# ... so we should have reset the LoggingContext. # ... so we should have reset the LoggingContext.
self.assertIs(LoggingContext.current_context(), sentinel_context) self.assertIs(LoggingContext.current_context(), sentinel_context)
wait_2_deferred.addBoth(check_context, "two") wait_2_deferred.addBoth(self.check_context, "two")
# let the first lookup complete (in the sentinel context) # let the first lookup complete (in the sentinel context)
lookup_1_deferred.callback(None) lookup_1_deferred.callback(None)
@ -89,38 +122,108 @@ class KeyringTestCase(unittest.TestCase):
kr = keyring.Keyring(self.hs) kr = keyring.Keyring(self.hs)
json1 = {} json1 = {}
signedjson.sign.sign_json(json1, "server1", key1) signedjson.sign.sign_json(json1, "server10", key1)
self.http_client.post_json.return_value = defer.Deferred() persp_resp = {
"server_keys": [
self.mock_perspective_server.get_signed_key(
"server10",
signedjson.key.get_verify_key(key1)
),
]
}
persp_deferred = defer.Deferred()
# start off a first set of lookups @defer.inlineCallbacks
res_deferreds = kr.verify_json_objects_for_server( def get_perspectives(**kwargs):
[("server1", json1), self.assertEquals(
("server2", {}) LoggingContext.current_context().test_key, "11",
] )
with logcontext.PreserveLoggingContext():
yield persp_deferred
defer.returnValue(persp_resp)
self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11:
context_11.test_key = "11"
# start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1),
("server11", {})
]
)
# the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try:
yield res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
yield async.sleep(0.005)
self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11)
context_12 = LoggingContext("12")
context_12.test_key = "12"
with logcontext.PreserveLoggingContext(context_12):
# a second request for a server with outstanding requests
# should block rather than start a second call
self.http_client.post_json.reset_mock()
self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)],
)
yield async.sleep(0.005)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
# complete the first request
with logcontext.PreserveLoggingContext():
persp_deferred.callback(persp_resp)
self.assertIs(LoggingContext.current_context(), context_11)
with logcontext.PreserveLoggingContext():
yield res_deferreds[0]
yield res_deferreds_2[0]
@defer.inlineCallbacks
def test_verify_json_for_server(self):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
yield self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000,
signedjson.key.get_verify_key(key1),
) )
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
# the unsigned json should be rejected pretty quickly sentinel_context = LoggingContext.current_context()
try:
yield res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
self.assertFalse(res_deferreds[0].called) with LoggingContext("one") as context_one:
context_one.test_key = "one"
# wait a tick for it to send the request to the perspectives server defer = kr.verify_json_for_server("server9", {})
# (it first tries the datastore) try:
yield async.sleep(0.005) yield defer
self.http_client.post_json.assert_called_once() self.fail("should fail on unsigned json")
except SynapseError:
pass
self.assertIs(LoggingContext.current_context(), context_one)
# a second request for a server with outstanding requests should defer = kr.verify_json_for_server("server9", json1)
# block rather than start a second call self.assertFalse(defer.called)
self.http_client.post_json.reset_mock() self.assertIs(LoggingContext.current_context(), sentinel_context)
self.http_client.post_json.return_value = defer.Deferred() yield defer
kr.verify_json_objects_for_server( self.assertIs(LoggingContext.current_context(), context_one)
[("server1", json1)],
)
yield async.sleep(0.005)
self.http_client.post_json.assert_not_called()