Merge branch 'develop' of github.com:matrix-org/synapse into erikj/factor_remote_leave

This commit is contained in:
Erik Johnston 2018-03-13 13:17:08 +00:00
commit bf8e97bd3c
21 changed files with 154 additions and 120 deletions

View file

@ -27,7 +27,13 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.store = hs.get_datastore()
self._clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,

View file

@ -58,6 +58,7 @@ class FederationClient(FederationBase):
self._clear_tried_cache, 60 * 1000, self._clear_tried_cache, 60 * 1000,
) )
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
def _clear_tried_cache(self): def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache""" """Clear pdu_destination_tried cache"""

View file

@ -17,12 +17,14 @@ import logging
import simplejson as json import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, FederationError, SynapseError from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import ( from synapse.federation.federation_base import (
FederationBase, FederationBase,
event_from_pdu_json, event_from_pdu_json,
) )
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
import synapse.metrics import synapse.metrics
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -56,6 +58,12 @@ class FederationServer(FederationBase):
self._server_linearizer = async.Linearizer("fed_server") self._server_linearizer = async.Linearizer("fed_server")
self._transaction_linearizer = async.Linearizer("fed_txn_handler") self._transaction_linearizer = async.Linearizer("fed_txn_handler")
self.transaction_actions = TransactionActions(self.store)
self.handler = None
self.registry = hs.get_federation_registry()
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
@ -67,35 +75,6 @@ class FederationServer(FederationBase):
""" """
self.handler = handler self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_backfill_request(self, origin, room_id, versions, limit): def on_backfill_request(self, origin, room_id, versions, limit):
@ -229,16 +208,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def received_edu(self, origin, edu_type, content): def received_edu(self, origin, edu_type, content):
received_edus_counter.inc() received_edus_counter.inc()
yield self.registry.on_edu(edu_type, origin, content)
if edu_type in self.edu_handlers:
try:
yield self.edu_handlers[edu_type](origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -328,14 +298,8 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_request(self, query_type, args): def on_query_request(self, query_type, args):
received_queries_counter.inc(query_type) received_queries_counter.inc(query_type)
resp = yield self.registry.on_query(query_type, args)
if query_type in self.query_handlers: defer.returnValue((200, resp))
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id): def on_make_join_request(self, room_id, user_id):
@ -607,3 +571,66 @@ class FederationServer(FederationBase):
origin, room_id, event_dict origin, room_id, event_dict
) )
defer.returnValue(ret) defer.returnValue(ret)
class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
"""
def __init__(self):
self.edu_handlers = {}
self.query_handlers = {}
def register_edu_handler(self, edu_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
Args:
edu_type (str): The type of the incoming EDU to register handler for
handler (Callable[[str, dict]]): A callable invoked on incoming EDU
of the given type. The arguments are the origin server name and
the EDU contents.
"""
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (Callable[[dict], Deferred[dict]]): Invoked to handle
incoming queries of this type. The return will be yielded
on and the result used as the response to the query request.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
def on_edu(self, edu_type, origin, content):
handler = self.edu_handlers.get(edu_type)
if not handler:
logger.warn("No handler registered for EDU type %s", edu_type)
try:
yield handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type)
if not handler:
logger.warn("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)

View file

@ -20,8 +20,6 @@ a given transport.
from .federation_client import FederationClient from .federation_client import FederationClient
from .federation_server import FederationServer from .federation_server import FederationServer
from .persistence import TransactionActions
import logging import logging
@ -47,26 +45,6 @@ class ReplicationLayer(FederationClient, FederationServer):
""" """
def __init__(self, hs, transport_layer): def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer
self.federation_client = self
self.store = hs.get_datastore()
self.handler = None
self.edu_handlers = {}
self.query_handlers = {}
self._clock = hs.get_clock()
self.transaction_actions = TransactionActions(self.store)
self.hs = hs
super(ReplicationLayer, self).__init__(hs) super(ReplicationLayer, self).__init__(hs)
def __str__(self): def __str__(self):

View file

@ -41,10 +41,12 @@ class DeviceHandler(BaseHandler):
self._edu_updater = DeviceListEduUpdater(hs, self) self._edu_updater = DeviceListEduUpdater(hs, self)
self.federation.register_edu_handler( federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update, "m.device_list_update", self._edu_updater.incoming_device_list_update,
) )
self.federation.register_query_handler( federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices, "user_devices", self.on_federation_query_user_devices,
) )

View file

@ -37,7 +37,7 @@ class DeviceMessageHandler(object):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler( hs.get_federation_registry().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu "m.direct_to_device", self.on_direct_to_device_edu
) )

View file

@ -37,7 +37,7 @@ class DirectoryHandler(BaseHandler):
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"directory", self.on_directory_query "directory", self.on_directory_query
) )

View file

@ -40,7 +40,7 @@ class E2eKeysHandler(object):
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the # query request requires an object POST, but we abuse the
# "query handler" interface. # "query handler" interface.
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"client_keys", self.on_federation_query_client_keys "client_keys", self.on_federation_query_client_keys
) )

View file

@ -667,7 +667,7 @@ class EventCreationHandler(object):
event (FrozenEvent) event (FrozenEvent)
context (EventContext) context (EventContext)
ratelimit (bool) ratelimit (bool)
extra_users (list(str)): Any extra users to notify about event extra_users (list(UserID)): Any extra users to notify about event
""" """
try: try:

View file

@ -98,24 +98,26 @@ class PresenceHandler(object):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.replication.register_edu_handler( federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.presence", self.incoming_presence "m.presence", self.incoming_presence
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_invite", "m.presence_invite",
lambda origin, content: self.invite_presence( lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_accept", "m.presence_accept",
lambda origin, content: self.accept_presence( lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_deny", "m.presence_deny",
lambda origin, content: self.deny_presence( lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),

View file

@ -32,7 +32,7 @@ class ProfileHandler(BaseHandler):
super(ProfileHandler, self).__init__(hs) super(ProfileHandler, self).__init__(hs)
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"profile", self.on_profile_query "profile", self.on_profile_query
) )

View file

@ -35,7 +35,7 @@ class ReceiptsHandler(BaseHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler( hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt "m.receipt", self._received_remote_receipt
) )
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()

View file

@ -446,16 +446,34 @@ class RegistrationHandler(BaseHandler):
return self.hs.get_auth_handler() return self.hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id): def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
"""Get a guest access token for a 3PID, creating a guest account if
one doesn't already exist.
Args:
medium (str)
address (str)
inviter_user_id (str): The user ID who is trying to invite the
3PID
Returns:
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
3PID guest account.
"""
access_token = yield self.store.get_3pid_guest_access_token(medium, address) access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token: if access_token:
defer.returnValue(access_token) user_info = yield self.auth.get_user_by_access_token(
access_token
)
_, access_token = yield self.register( defer.returnValue((user_info["user"].to_string(), access_token))
user_id, access_token = yield self.register(
generate_token=True, generate_token=True,
make_guest=True make_guest=True
) )
access_token = yield self.store.save_or_get_3pid_guest_access_token( access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id medium, address, access_token, inviter_user_id
) )
defer.returnValue(access_token)
defer.returnValue((user_id, access_token))

View file

@ -138,7 +138,7 @@ class RoomMemberHandler(object):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content): def _remote_join(self, remote_room_hosts, room_id, user, content):
"""Try and join a room that this server is not in """Try and join a room that this server is not in
Args: Args:
@ -342,7 +342,7 @@ class RoomMemberHandler(object):
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if not is_host_in_room: if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id) inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter): if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain) remote_room_hosts.append(inviter.domain)
@ -356,7 +356,7 @@ class RoomMemberHandler(object):
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
ret = yield self.remote_join( ret = yield self._remote_join(
remote_room_hosts, room_id, target, content remote_room_hosts, room_id, target, content
) )
defer.returnValue(ret) defer.returnValue(ret)
@ -364,7 +364,7 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE: elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room: if not is_host_in_room:
# perhaps we've been invited # perhaps we've been invited
inviter = yield self.get_inviter(target.to_string(), room_id) inviter = yield self._get_inviter(target.to_string(), room_id)
if not inviter: if not inviter:
raise SynapseError(404, "Not a known room") raise SynapseError(404, "Not a known room")
@ -528,7 +528,7 @@ class RoomMemberHandler(object):
defer.returnValue((RoomID.from_string(room_id), servers)) defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_inviter(self, user_id, room_id): def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room( invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id, user_id=user_id,
room_id=room_id, room_id=room_id,
@ -605,7 +605,7 @@ class RoomMemberHandler(object):
if "mxid" in data: if "mxid" in data:
if "signatures" not in data: if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding") raise AuthError(401, "No signatures on 3pid binding")
yield self.verify_any_signature(data, id_server) yield self._verify_any_signature(data, id_server)
defer.returnValue(data["mxid"]) defer.returnValue(data["mxid"])
except IOError as e: except IOError as e:
@ -613,7 +613,7 @@ class RoomMemberHandler(object):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname): def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]: if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,)) raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items(): for key_name, signature in data["signatures"][server_hostname].items():
@ -767,20 +767,16 @@ class RoomMemberHandler(object):
} }
if self.config.invite_3pid_guest: if self.config.invite_3pid_guest:
registration_handler = self.registration_handler rh = self.registration_handler
guest_access_token = yield registration_handler.guest_access_token_for( guest_user_id, guest_access_token = yield rh.get_or_register_3pid_guest(
medium=medium, medium=medium,
address=address, address=address,
inviter_user_id=inviter_user_id, inviter_user_id=inviter_user_id,
) )
guest_user_info = yield self.auth.get_user_by_access_token(
guest_access_token
)
invite_config.update({ invite_config.update({
"guest_access_token": guest_access_token, "guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(), "guest_user_id": guest_user_id,
}) })
data = yield self.simple_http_client.post_urlencoded_get_json( data = yield self.simple_http_client.post_urlencoded_get_json(

View file

@ -56,7 +56,7 @@ class TypingHandler(object):
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)

View file

@ -25,7 +25,7 @@ from synapse.util.async import sleep
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.types import Requester from synapse.types import Requester, UserID
import logging import logging
import re import re
@ -46,7 +46,7 @@ def send_event_to_master(client, host, port, requester, event, context,
event (FrozenEvent) event (FrozenEvent)
context (EventContext) context (EventContext)
ratelimit (bool) ratelimit (bool)
extra_users (list(str)): Any extra users to notify about event extra_users (list(UserID)): Any extra users to notify about event
""" """
uri = "http://%s:%s/_synapse/replication/send_event/%s" % ( uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
host, port, event.event_id, host, port, event.event_id,
@ -59,7 +59,7 @@ def send_event_to_master(client, host, port, requester, event, context,
"context": context.serialize(event), "context": context.serialize(event),
"requester": requester.serialize(), "requester": requester.serialize(),
"ratelimit": ratelimit, "ratelimit": ratelimit,
"extra_users": extra_users, "extra_users": [u.to_string() for u in extra_users],
} }
try: try:
@ -143,7 +143,7 @@ class ReplicationSendEventRestServlet(RestServlet):
context = yield EventContext.deserialize(self.store, content["context"]) context = yield EventContext.deserialize(self.store, content["context"])
ratelimit = content["ratelimit"] ratelimit = content["ratelimit"]
extra_users = content["extra_users"] extra_users = [UserID.from_string(u) for u in content["extra_users"]]
if requester.user: if requester.user:
request.authenticated_entity = requester.user.to_string() request.authenticated_entity = requester.user.to_string()

View file

@ -599,7 +599,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|unban|kick|forget)") "(?P<membership_action>join|invite|leave|ban|unban|kick)")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -34,6 +34,7 @@ from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker from synapse.events.spamcheck import SpamChecker
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.federation_server import FederationHandlerRegistry
from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers from synapse.handlers import Handlers
@ -147,6 +148,7 @@ class HomeServer(object):
'groups_attestation_renewer', 'groups_attestation_renewer',
'spam_checker', 'spam_checker',
'room_member_handler', 'room_member_handler',
'federation_registry',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -387,6 +389,9 @@ class HomeServer(object):
def build_room_member_handler(self): def build_room_member_handler(self):
return RoomMemberHandler(self) return RoomMemberHandler(self)
def build_federation_registry(self):
return FederationHandlerRegistry()
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View file

@ -283,10 +283,11 @@ class EventsStore(EventsWorkerStore):
def _maybe_start_persisting(self, room_id): def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks @defer.inlineCallbacks
def persisting_queue(item): def persisting_queue(item):
yield self._persist_events( with Measure(self._clock, "persist_events"):
item.events_and_contexts, yield self._persist_events(
backfilled=item.backfilled, item.events_and_contexts,
) backfilled=item.backfilled,
)
self._event_persist_queue.handle_queue(room_id, persisting_queue) self._event_persist_queue.handle_queue(room_id, persisting_queue)

View file

@ -35,21 +35,20 @@ class DirectoryTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock()
"make_query", self.mock_registry = Mock()
"register_edu_handler",
])
self.query_handlers = {} self.query_handlers = {}
def register_query_handler(query_type, handler): def register_query_handler(query_type, handler):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
http_client=None, http_client=None,
resource_for_federation=Mock(), resource_for_federation=Mock(),
replication_layer=self.mock_federation, replication_layer=self.mock_federation,
federation_registry=self.mock_registry,
) )
hs.handlers = DirectoryHandlers(hs) hs.handlers = DirectoryHandlers(hs)

View file

@ -37,23 +37,22 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock()
"make_query", self.mock_registry = Mock()
"register_edu_handler",
])
self.query_handlers = {} self.query_handlers = {}
def register_query_handler(query_type, handler): def register_query_handler(query_type, handler):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
http_client=None, http_client=None,
handlers=None, handlers=None,
resource_for_federation=Mock(), resource_for_federation=Mock(),
replication_layer=self.mock_federation, replication_layer=self.mock_federation,
federation_registry=self.mock_registry,
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]) ])