From 51958df766667a75236764c9df57cfd60d57bd5e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Sun, 27 Jan 2019 23:24:17 +0000 Subject: [PATCH] MatrixFederationAgent: factor out routing logic This is going to get too big and unmanageable. --- .../federation/matrix_federation_agent.py | 80 ++++++++++++++----- 1 file changed, 62 insertions(+), 18 deletions(-) diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 9526f39cca..2a8bceed9a 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import attr from zope.interface import implementer from twisted.internet import defer @@ -85,9 +86,11 @@ class MatrixFederationAgent(object): response from being received (including problems that prevent the request from being sent). """ - parsed_uri = URI.fromBytes(uri, defaultPort=-1) + res = yield self._route_matrix_uri(parsed_uri) + # set up the TLS connection params + # # XXX disabling TLS is really only supported here for the benefit of the # unit tests. We should make the UTs cope with TLS rather than having to make # the code support the unit tests. @@ -95,22 +98,9 @@ class MatrixFederationAgent(object): tls_options = None else: tls_options = self._tls_client_options_factory.get_options( - parsed_uri.host.decode("ascii") + res.tls_server_name.decode("ascii") ) - if parsed_uri.port != -1: - # there was an explicit port in the URI - target = parsed_uri.host, parsed_uri.port - else: - service_name = b"_matrix._tcp.%s" % (parsed_uri.host, ) - server_list = yield self._srv_resolver.resolve_service(service_name) - if not server_list: - target = (parsed_uri.host, 8448) - logger.debug( - "No SRV record for %s, using %s", service_name, target) - else: - target = pick_server_from_list(server_list) - # make sure that the Host header is set correctly if headers is None: headers = Headers() @@ -118,13 +108,13 @@ class MatrixFederationAgent(object): headers = headers.copy() if not headers.hasHeader(b'host'): - headers.addRawHeader(b'host', parsed_uri.netloc) + headers.addRawHeader(b'host', res.host_header) class EndpointFactory(object): @staticmethod def endpointForURI(_uri): - logger.info("Connecting to %s:%s", target[0], target[1]) - ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1]) + logger.info("Connecting to %s:%s", res.target_host, res.target_port) + ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port) if tls_options is not None: ep = wrapClientTLS(tls_options, ep) return ep @@ -134,3 +124,57 @@ class MatrixFederationAgent(object): agent.request(method, uri, headers, bodyProducer) ) defer.returnValue(res) + + @defer.inlineCallbacks + def _route_matrix_uri(self, parsed_uri): + """Helper for `request`: determine the routing for a Matrix URI + + Args: + parsed_uri (twisted.web.client.URI): uri to route. Note that it should be + parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1 + if there is no explicit port given. + + Returns: + Deferred[_RoutingResult] + """ + if parsed_uri.port != -1: + # there is an explicit port + defer.returnValue(_RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=parsed_uri.port, + )) + + # try a SRV lookup + service_name = b"_matrix._tcp.%s" % (parsed_uri.host,) + server_list = yield self._srv_resolver.resolve_service(service_name) + + if not server_list: + target_host = parsed_uri.host + port = 8448 + logger.debug( + "No SRV record for %s, using %s:%i", + parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port, + ) + else: + target_host, port = pick_server_from_list(server_list) + logger.debug( + "Picked %s:%i from SRV records for %s", + target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"), + ) + + defer.returnValue(_RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=target_host, + target_port=port, + )) + + +@attr.s +class _RoutingResult(object): + host_header = attr.ib() + tls_server_name = attr.ib() + target_host = attr.ib() + target_port = attr.ib()