Re-introduce the outbound federation proxy (#15913)

Allow configuring the set of workers to proxy outbound federation traffic through (`outbound_federation_restricted_to`).

This is useful when you have a worker setup with `federation_sender` instances responsible for sending outbound federation requests and want to make sure *all* outbound federation traffic goes through those instances. Before this change, the generic workers would still contact federation themselves for things like profile lookups, backfill, etc. This PR allows you to set more strict access controls/firewall for all workers and only allow the `federation_sender`'s to contact the outside world.
This commit is contained in:
Eric Eastwood 2023-07-18 03:49:21 -05:00 committed by GitHub
parent c692283751
commit 1c802de626
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 1128 additions and 96 deletions

View file

@ -0,0 +1 @@
Allow configuring the set of workers to proxy outbound federation traffic through via `outbound_federation_restricted_to`.

View file

@ -3960,13 +3960,14 @@ federation_sender_instances:
--- ---
### `instance_map` ### `instance_map`
When using workers this should be a map from [`worker_name`](#worker_name) to the When using workers this should be a map from [`worker_name`](#worker_name) to the HTTP
HTTP replication listener of the worker, if configured, and to the main process. replication listener of the worker, if configured, and to the main process. Each worker
Each worker declared under [`stream_writers`](../../workers.md#stream-writers) needs declared under [`stream_writers`](../../workers.md#stream-writers) and
a HTTP replication listener, and that listener should be included in the `instance_map`. [`outbound_federation_restricted_to`](#outbound_federation_restricted_to) needs a HTTP
The main process also needs an entry on the `instance_map`, and it should be listed under replication listener, and that listener should be included in the `instance_map`. The
`main` **if even one other worker exists**. Ensure the port matches with what is declared main process also needs an entry on the `instance_map`, and it should be listed under
inside the `listener` block for a `replication` listener. `main` **if even one other worker exists**. Ensure the port matches with what is
declared inside the `listener` block for a `replication` listener.
Example configuration: Example configuration:
@ -4004,6 +4005,24 @@ stream_writers:
typing: worker1 typing: worker1
``` ```
--- ---
### `outbound_federation_restricted_to`
When using workers, you can restrict outbound federation traffic to only go through a
specific subset of workers. Any worker specified here must also be in the
[`instance_map`](#instance_map).
[`worker_replication_secret`](#worker_replication_secret) must also be configured to
authorize inter-worker communication.
```yaml
outbound_federation_restricted_to:
- federation_sender1
- federation_sender2
```
Also see the [worker
documentation](../../workers.md#restrict-outbound-federation-traffic-to-a-specific-set-of-workers)
for more info.
---
### `run_background_tasks_on` ### `run_background_tasks_on`
The [worker](../../workers.md#background-tasks) that is used to run The [worker](../../workers.md#background-tasks) that is used to run

View file

@ -531,6 +531,30 @@ the stream writer for the `presence` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/
#### Restrict outbound federation traffic to a specific set of workers
The
[`outbound_federation_restricted_to`](usage/configuration/config_documentation.md#outbound_federation_restricted_to)
configuration is useful to make sure outbound federation traffic only goes through a
specified subset of workers. This allows you to set more strict access controls (like a
firewall) for all workers and only allow the `federation_sender`'s to contact the
outside world.
```yaml
instance_map:
main:
host: localhost
port: 8030
federation_sender1:
host: localhost
port: 8034
outbound_federation_restricted_to:
- federation_sender1
worker_replication_secret: "secret_secret"
```
#### Background tasks #### Background tasks
There is also support for moving background tasks to a separate There is also support for moving background tasks to a separate

View file

@ -217,6 +217,13 @@ class InvalidAPICallError(SynapseError):
super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON) super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON)
class InvalidProxyCredentialsError(SynapseError):
"""Error raised when the proxy credentials are invalid."""
def __init__(self, msg: str, errcode: str = Codes.UNKNOWN):
super().__init__(401, msg, errcode)
class ProxiedRequestError(SynapseError): class ProxiedRequestError(SynapseError):
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call. """An error from a general matrix endpoint, eg. from a proxied Matrix API call.

View file

@ -386,6 +386,7 @@ def listen_unix(
def listen_http( def listen_http(
hs: "HomeServer",
listener_config: ListenerConfig, listener_config: ListenerConfig,
root_resource: Resource, root_resource: Resource,
version_string: str, version_string: str,
@ -406,6 +407,7 @@ def listen_http(
version_string, version_string,
max_request_body_size=max_request_body_size, max_request_body_size=max_request_body_size,
reactor=reactor, reactor=reactor,
hs=hs,
) )
if isinstance(listener_config, TCPListenerConfig): if isinstance(listener_config, TCPListenerConfig):

View file

@ -221,6 +221,7 @@ class GenericWorkerServer(HomeServer):
root_resource = create_resource_tree(resources, OptionsResource()) root_resource = create_resource_tree(resources, OptionsResource())
_base.listen_http( _base.listen_http(
self,
listener_config, listener_config,
root_resource, root_resource,
self.version_string, self.version_string,

View file

@ -139,6 +139,7 @@ class SynapseHomeServer(HomeServer):
root_resource = OptionsResource() root_resource = OptionsResource()
ports = listen_http( ports = listen_http(
self,
listener_config, listener_config,
create_resource_tree(resources, root_resource), create_resource_tree(resources, root_resource),
self.version_string, self.version_string,

View file

@ -15,7 +15,7 @@
import argparse import argparse
import logging import logging
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Optional, Union
import attr import attr
from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
@ -171,6 +171,27 @@ class WriterLocations:
) )
@attr.s(auto_attribs=True)
class OutboundFederationRestrictedTo:
"""Whether we limit outbound federation to a certain set of instances.
Attributes:
instances: optional list of instances that can make outbound federation
requests. If None then all instances can make federation requests.
locations: list of instance locations to connect to proxy via.
"""
instances: Optional[List[str]]
locations: List[InstanceLocationConfig] = attr.Factory(list)
def __contains__(self, instance: str) -> bool:
# It feels a bit dirty to return `True` if `instances` is `None`, but it makes
# sense in downstream usage in the sense that if
# `outbound_federation_restricted_to` is not configured, then any instance can
# talk to federation (no restrictions so always return `True`).
return self.instances is None or instance in self.instances
class WorkerConfig(Config): class WorkerConfig(Config):
"""The workers are processes run separately to the main synapse process. """The workers are processes run separately to the main synapse process.
They have their own pid_file and listener configuration. They use the They have their own pid_file and listener configuration. They use the
@ -385,6 +406,28 @@ class WorkerConfig(Config):
new_option_name="update_user_directory_from_worker", new_option_name="update_user_directory_from_worker",
) )
outbound_federation_restricted_to = config.get(
"outbound_federation_restricted_to", None
)
self.outbound_federation_restricted_to = OutboundFederationRestrictedTo(
outbound_federation_restricted_to
)
if outbound_federation_restricted_to:
if not self.worker_replication_secret:
raise ConfigError(
"`worker_replication_secret` must be configured when using `outbound_federation_restricted_to`."
)
for instance in outbound_federation_restricted_to:
if instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured in 'outbound_federation_restricted_to' but does not appear in `instance_map` config."
% (instance,)
)
self.outbound_federation_restricted_to.locations.append(
self.instance_map[instance]
)
def _should_this_worker_perform_duty( def _should_this_worker_perform_duty(
self, self,
config: Dict[str, Any], config: Dict[str, Any],

View file

@ -1037,7 +1037,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
if reason.check(ResponseDone): if reason.check(ResponseDone):
self.deferred.callback(self.length) self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss): elif reason.check(PotentialDataLoss):
# stolen from https://github.com/twisted/treq/pull/49/files # This applies to requests which don't set `Content-Length` or a
# `Transfer-Encoding` in the response because in this case the end of the
# response is indicated by the connection being closed, an event which may
# also be due to a transient network problem or other error. But since this
# behavior is expected of some servers (like YouTube), let's ignore it.
# Stolen from https://github.com/twisted/treq/pull/49/files
# http://twistedmatrix.com/trac/ticket/4840 # http://twistedmatrix.com/trac/ticket/4840
self.deferred.callback(self.length) self.deferred.callback(self.length)
else: else:

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import base64 import base64
import logging import logging
from typing import Optional, Union from typing import Optional, Union
@ -39,8 +40,14 @@ class ProxyConnectError(ConnectError):
pass pass
@attr.s(auto_attribs=True)
class ProxyCredentials: class ProxyCredentials:
@abc.abstractmethod
def as_proxy_authorization_value(self) -> bytes:
raise NotImplementedError()
@attr.s(auto_attribs=True)
class BasicProxyCredentials(ProxyCredentials):
username_password: bytes username_password: bytes
def as_proxy_authorization_value(self) -> bytes: def as_proxy_authorization_value(self) -> bytes:
@ -55,6 +62,17 @@ class ProxyCredentials:
return b"Basic " + base64.encodebytes(self.username_password) return b"Basic " + base64.encodebytes(self.username_password)
@attr.s(auto_attribs=True)
class BearerProxyCredentials(ProxyCredentials):
access_token: bytes
def as_proxy_authorization_value(self) -> bytes:
"""
Return the value for a Proxy-Authorization header (i.e. 'Bearer xxx').
"""
return b"Bearer " + self.access_token
@implementer(IStreamClientEndpoint) @implementer(IStreamClientEndpoint)
class HTTPConnectProxyEndpoint: class HTTPConnectProxyEndpoint:
"""An Endpoint implementation which will send a CONNECT request to an http proxy """An Endpoint implementation which will send a CONNECT request to an http proxy

View file

@ -50,7 +50,7 @@ from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import Cooperator from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse from twisted.web.iweb import IAgent, IBodyProducer, IResponse
import synapse.metrics import synapse.metrics
import synapse.util.retryutils import synapse.util.retryutils
@ -71,7 +71,9 @@ from synapse.http.client import (
encode_query_args, encode_query_args,
read_body_with_max_size, read_body_with_max_size,
) )
from synapse.http.connectproxyclient import BearerProxyCredentials
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.proxyagent import ProxyAgent
from synapse.http.types import QueryParams from synapse.http.types import QueryParams
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -393,17 +395,41 @@ class MatrixFederationHttpClient:
if hs.config.server.user_agent_suffix: if hs.config.server.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix) user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix)
federation_agent = MatrixFederationAgent( outbound_federation_restricted_to = (
self.reactor, hs.config.worker.outbound_federation_restricted_to
tls_client_options_factory,
user_agent.encode("ascii"),
hs.config.server.federation_ip_range_allowlist,
hs.config.server.federation_ip_range_blocklist,
) )
if hs.get_instance_name() in outbound_federation_restricted_to:
# Talk to federation directly
federation_agent: IAgent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
user_agent.encode("ascii"),
hs.config.server.federation_ip_range_allowlist,
hs.config.server.federation_ip_range_blocklist,
)
else:
proxy_authorization_secret = hs.config.worker.worker_replication_secret
assert (
proxy_authorization_secret is not None
), "`worker_replication_secret` must be set when using `outbound_federation_restricted_to` (used to authenticate requests across workers)"
federation_proxy_credentials = BearerProxyCredentials(
proxy_authorization_secret.encode("ascii")
)
# We need to talk to federation via the proxy via one of the configured
# locations
federation_proxy_locations = outbound_federation_restricted_to.locations
federation_agent = ProxyAgent(
self.reactor,
self.reactor,
tls_client_options_factory,
federation_proxy_locations=federation_proxy_locations,
federation_proxy_credentials=federation_proxy_credentials,
)
# Use a BlocklistingAgentWrapper to prevent circumventing the IP # Use a BlocklistingAgentWrapper to prevent circumventing the IP
# blocking via IP literals in server names # blocking via IP literals in server names
self.agent = BlocklistingAgentWrapper( self.agent: IAgent = BlocklistingAgentWrapper(
federation_agent, federation_agent,
ip_blocklist=hs.config.server.federation_ip_range_blocklist, ip_blocklist=hs.config.server.federation_ip_range_blocklist,
) )
@ -412,7 +438,6 @@ class MatrixFederationHttpClient:
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
self.version_string_bytes = hs.version_string.encode("ascii") self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout_seconds = hs.config.federation.client_timeout_ms / 1000 self.default_timeout_seconds = hs.config.federation.client_timeout_ms / 1000
self.max_long_retry_delay_seconds = ( self.max_long_retry_delay_seconds = (
hs.config.federation.max_long_retry_delay_ms / 1000 hs.config.federation.max_long_retry_delay_ms / 1000
) )
@ -1131,6 +1156,101 @@ class MatrixFederationHttpClient:
Succeeds when we get a 2xx HTTP response. The Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body. result will be the decoded JSON body.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
json_dict, _ = await self.get_json_with_headers(
destination=destination,
path=path,
args=args,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
ignore_backoff=ignore_backoff,
try_trailing_slash_on_400=try_trailing_slash_on_400,
parser=parser,
)
return json_dict
@overload
async def get_json_with_headers(
self,
destination: str,
path: str,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
) -> Tuple[JsonDict, Dict[bytes, List[bytes]]]:
...
@overload
async def get_json_with_headers(
self,
destination: str,
path: str,
args: Optional[QueryParams] = ...,
retry_on_dns_fail: bool = ...,
timeout: Optional[int] = ...,
ignore_backoff: bool = ...,
try_trailing_slash_on_400: bool = ...,
parser: ByteParser[T] = ...,
) -> Tuple[T, Dict[bytes, List[bytes]]]:
...
async def get_json_with_headers(
self,
destination: str,
path: str,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser[T]] = None,
) -> Tuple[Union[JsonDict, T], Dict[bytes, List[bytes]]]:
"""GETs some json from the given host homeserver and path
Args:
destination: The remote server to send the HTTP request to.
path: The HTTP path.
args: A dictionary used to create query strings, defaults to
None.
retry_on_dns_fail: true if the request should be retried on DNS failures
timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
timeout applies to the time spent waiting for response headers for
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
parser: The parser to use to decode the response. Defaults to
parsing as JSON.
Returns:
Succeeds when we get a 2xx HTTP response. The result will be a tuple of the
decoded JSON body and a dict of the response headers.
Raises: Raises:
HttpResponseException: If we get an HTTP response code >= 300 HttpResponseException: If we get an HTTP response code >= 300
(except 429). (except 429).
@ -1156,6 +1276,8 @@ class MatrixFederationHttpClient:
timeout=timeout, timeout=timeout,
) )
headers = dict(response.headers.getAllRawHeaders())
if timeout is not None: if timeout is not None:
_sec_timeout = timeout / 1000 _sec_timeout = timeout / 1000
else: else:
@ -1173,7 +1295,7 @@ class MatrixFederationHttpClient:
parser=parser, parser=parser,
) )
return body return body, headers
async def delete_json( async def delete_json(
self, self,

283
synapse/http/proxy.py Normal file
View file

@ -0,0 +1,283 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import urllib.parse
from typing import TYPE_CHECKING, Any, Optional, Set, Tuple, cast
from twisted.internet import protocol
from twisted.internet.interfaces import ITCPTransport
from twisted.internet.protocol import connectionDone
from twisted.python import failure
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.api.errors import Codes, InvalidProxyCredentialsError
from synapse.http import QuieterFileBodyProducer
from synapse.http.server import _AsyncResource
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import ISynapseReactor
from synapse.util.async_helpers import timeout_deferred
if TYPE_CHECKING:
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# "Hop-by-hop" headers (as opposed to "end-to-end" headers) as defined by RFC2616
# section 13.5.1 and referenced in RFC9110 section 7.6.1. These are meant to only be
# consumed by the immediate recipient and not be forwarded on.
HOP_BY_HOP_HEADERS = {
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"TE",
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
def parse_connection_header_value(
connection_header_value: Optional[bytes],
) -> Set[str]:
"""
Parse the `Connection` header to determine which headers we should not be copied
over from the remote response.
As defined by RFC2616 section 14.10 and RFC9110 section 7.6.1
Example: `Connection: close, X-Foo, X-Bar` will return `{"Close", "X-Foo", "X-Bar"}`
Even though "close" is a special directive, let's just treat it as just another
header for simplicity. If people want to check for this directive, they can simply
check for `"Close" in headers`.
Args:
connection_header_value: The value of the `Connection` header.
Returns:
The set of header names that should not be copied over from the remote response.
The keys are capitalized in canonical capitalization.
"""
headers = Headers()
extra_headers_to_remove: Set[str] = set()
if connection_header_value:
extra_headers_to_remove = {
headers._canonicalNameCaps(connection_option.strip()).decode("ascii")
for connection_option in connection_header_value.split(b",")
}
return extra_headers_to_remove
class ProxyResource(_AsyncResource):
"""
A stub resource that proxies any requests with a `matrix-federation://` scheme
through the given `federation_agent` to the remote homeserver and ferries back the
info.
"""
isLeaf = True
def __init__(self, reactor: ISynapseReactor, hs: "HomeServer"):
super().__init__(True)
self.reactor = reactor
self.agent = hs.get_federation_http_client().agent
self._proxy_authorization_secret = hs.config.worker.worker_replication_secret
def _check_auth(self, request: Request) -> None:
# The `matrix-federation://` proxy functionality can only be used with auth.
# Protect homserver admins forgetting to configure a secret.
assert self._proxy_authorization_secret is not None
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Proxy-Authorization")
if not auth_headers:
raise InvalidProxyCredentialsError(
"Missing Proxy-Authorization header.", Codes.MISSING_TOKEN
)
if len(auth_headers) > 1:
raise InvalidProxyCredentialsError(
"Too many Proxy-Authorization headers.", Codes.UNAUTHORIZED
)
parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
received_secret = parts[1].decode("ascii")
if self._proxy_authorization_secret == received_secret:
# Success!
return
raise InvalidProxyCredentialsError(
"Invalid Proxy-Authorization header.", Codes.UNAUTHORIZED
)
async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]:
uri = urllib.parse.urlparse(request.uri)
assert uri.scheme == b"matrix-federation"
# Check the authorization headers before handling the request.
self._check_auth(request)
headers = Headers()
for header_name in (b"User-Agent", b"Authorization", b"Content-Type"):
header_value = request.getHeader(header_name)
if header_value:
headers.addRawHeader(header_name, header_value)
request_deferred = run_in_background(
self.agent.request,
request.method,
request.uri,
headers=headers,
bodyProducer=QuieterFileBodyProducer(request.content),
)
request_deferred = timeout_deferred(
request_deferred,
# This should be set longer than the timeout in `MatrixFederationHttpClient`
# so that it has enough time to complete and pass us the data before we give
# up.
timeout=90,
reactor=self.reactor,
)
response = await make_deferred_yieldable(request_deferred)
return response.code, response
def _send_response(
self,
request: "SynapseRequest",
code: int,
response_object: Any,
) -> None:
response = cast(IResponse, response_object)
response_headers = cast(Headers, response.headers)
request.setResponseCode(code)
# The `Connection` header also defines which headers should not be copied over.
connection_header = response_headers.getRawHeaders(b"connection")
extra_headers_to_remove = parse_connection_header_value(
connection_header[0] if connection_header else None
)
# Copy headers.
for k, v in response_headers.getAllRawHeaders():
# Do not copy over any hop-by-hop headers. These are meant to only be
# consumed by the immediate recipient and not be forwarded on.
header_key = k.decode("ascii")
if (
header_key in HOP_BY_HOP_HEADERS
or header_key in extra_headers_to_remove
):
continue
request.responseHeaders.setRawHeaders(k, v)
response.deliverBody(_ProxyResponseBody(request))
def _send_error_response(
self,
f: failure.Failure,
request: "SynapseRequest",
) -> None:
if isinstance(f.value, InvalidProxyCredentialsError):
error_response_code = f.value.code
error_response_json = {"errcode": f.value.errcode, "err": f.value.msg}
else:
error_response_code = 502
error_response_json = {
"errcode": Codes.UNKNOWN,
"err": "ProxyResource: Error when proxying request: %s %s -> %s"
% (
request.method.decode("ascii"),
request.uri.decode("ascii"),
f,
),
}
request.setResponseCode(error_response_code)
request.setHeader(b"Content-Type", b"application/json")
request.write((json.dumps(error_response_json)).encode())
request.finish()
class _ProxyResponseBody(protocol.Protocol):
"""
A protocol that proxies the given remote response data back out to the given local
request.
"""
transport: Optional[ITCPTransport] = None
def __init__(self, request: "SynapseRequest") -> None:
self._request = request
def dataReceived(self, data: bytes) -> None:
# Avoid sending response data to the local request that already disconnected
if self._request._disconnected and self.transport is not None:
# Close the connection (forcefully) since all the data will get
# discarded anyway.
self.transport.abortConnection()
return
self._request.write(data)
def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the local request is already finished (successfully or failed), don't
# worry about sending anything back.
if self._request.finished:
return
if reason.check(ResponseDone):
self._request.finish()
else:
# Abort the underlying request since our remote request also failed.
self._request.transport.abortConnection()
class ProxySite(Site):
"""
Proxies any requests with a `matrix-federation://` scheme through the given
`federation_agent`. Otherwise, behaves like a normal `Site`.
"""
def __init__(
self,
resource: IResource,
reactor: ISynapseReactor,
hs: "HomeServer",
):
super().__init__(resource, reactor=reactor)
self._proxy_resource = ProxyResource(reactor, hs=hs)
def getResourceFor(self, request: "SynapseRequest") -> IResource:
uri = urllib.parse.urlparse(request.uri)
if uri.scheme == b"matrix-federation":
return self._proxy_resource
return super().getResourceFor(request)

View file

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import random
import re import re
from typing import Any, Dict, Optional, Tuple from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from urllib.request import ( # type: ignore[attr-defined] from urllib.request import ( # type: ignore[attr-defined]
getproxies_environment, getproxies_environment,
@ -23,8 +24,17 @@ from urllib.request import ( # type: ignore[attr-defined]
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import (
from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint HostnameEndpoint,
UNIXClientEndpoint,
wrapClientTLS,
)
from twisted.internet.interfaces import (
IProtocol,
IProtocolFactory,
IReactorCore,
IStreamClientEndpoint,
)
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.client import ( from twisted.web.client import (
URI, URI,
@ -36,8 +46,18 @@ from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse
from synapse.config.workers import (
InstanceLocationConfig,
InstanceTcpLocationConfig,
InstanceUnixLocationConfig,
)
from synapse.http import redact_uri from synapse.http import redact_uri
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials from synapse.http.connectproxyclient import (
BasicProxyCredentials,
HTTPConnectProxyEndpoint,
ProxyCredentials,
)
from synapse.logging.context import run_in_background
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -74,6 +94,14 @@ class ProxyAgent(_AgentBase):
use_proxy: Whether proxy settings should be discovered and used use_proxy: Whether proxy settings should be discovered and used
from conventional environment variables. from conventional environment variables.
federation_proxy_locations: An optional list of locations to proxy outbound federation
traffic through (only requests that use the `matrix-federation://` scheme
will be proxied).
federation_proxy_credentials: Required if `federation_proxy_locations` is set. The
credentials to use when proxying outbound federation traffic through another
worker.
Raises: Raises:
ValueError if use_proxy is set and the environment variables ValueError if use_proxy is set and the environment variables
contain an invalid proxy specification. contain an invalid proxy specification.
@ -89,6 +117,8 @@ class ProxyAgent(_AgentBase):
bindAddress: Optional[bytes] = None, bindAddress: Optional[bytes] = None,
pool: Optional[HTTPConnectionPool] = None, pool: Optional[HTTPConnectionPool] = None,
use_proxy: bool = False, use_proxy: bool = False,
federation_proxy_locations: Collection[InstanceLocationConfig] = (),
federation_proxy_credentials: Optional[ProxyCredentials] = None,
): ):
contextFactory = contextFactory or BrowserLikePolicyForHTTPS() contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
@ -127,6 +157,47 @@ class ProxyAgent(_AgentBase):
self._policy_for_https = contextFactory self._policy_for_https = contextFactory
self._reactor = reactor self._reactor = reactor
self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None
self._federation_proxy_credentials: Optional[ProxyCredentials] = None
if federation_proxy_locations:
assert (
federation_proxy_credentials is not None
), "`federation_proxy_credentials` are required when using `federation_proxy_locations`"
endpoints: List[IStreamClientEndpoint] = []
for federation_proxy_location in federation_proxy_locations:
endpoint: IStreamClientEndpoint
if isinstance(federation_proxy_location, InstanceTcpLocationConfig):
endpoint = HostnameEndpoint(
self.proxy_reactor,
federation_proxy_location.host,
federation_proxy_location.port,
)
if federation_proxy_location.tls:
tls_connection_creator = (
self._policy_for_https.creatorForNetloc(
federation_proxy_location.host.encode("utf-8"),
federation_proxy_location.port,
)
)
endpoint = wrapClientTLS(tls_connection_creator, endpoint)
elif isinstance(federation_proxy_location, InstanceUnixLocationConfig):
endpoint = UNIXClientEndpoint(
self.proxy_reactor, federation_proxy_location.path
)
else:
# It is supremely unlikely we ever hit this
raise SchemeNotSupported(
f"Unknown type of Endpoint requested, check {federation_proxy_location}"
)
endpoints.append(endpoint)
self._federation_proxy_endpoint = _RandomSampleEndpoints(endpoints)
self._federation_proxy_credentials = federation_proxy_credentials
def request( def request(
self, self,
method: bytes, method: bytes,
@ -214,6 +285,25 @@ class ProxyAgent(_AgentBase):
parsed_uri.port, parsed_uri.port,
self.https_proxy_creds, self.https_proxy_creds,
) )
elif (
parsed_uri.scheme == b"matrix-federation"
and self._federation_proxy_endpoint
):
assert (
self._federation_proxy_credentials is not None
), "`federation_proxy_credentials` are required when using `federation_proxy_locations`"
# Set a Proxy-Authorization header
if headers is None:
headers = Headers()
# We always need authentication for the outbound federation proxy
headers.addRawHeader(
b"Proxy-Authorization",
self._federation_proxy_credentials.as_proxy_authorization_value(),
)
endpoint = self._federation_proxy_endpoint
request_path = uri
else: else:
# not using a proxy # not using a proxy
endpoint = HostnameEndpoint( endpoint = HostnameEndpoint(
@ -233,6 +323,11 @@ class ProxyAgent(_AgentBase):
endpoint = wrapClientTLS(tls_connection_creator, endpoint) endpoint = wrapClientTLS(tls_connection_creator, endpoint)
elif parsed_uri.scheme == b"http": elif parsed_uri.scheme == b"http":
pass pass
elif (
parsed_uri.scheme == b"matrix-federation"
and self._federation_proxy_endpoint
):
pass
else: else:
return defer.fail( return defer.fail(
Failure( Failure(
@ -334,6 +429,42 @@ def parse_proxy(
credentials = None credentials = None
if url.username and url.password: if url.username and url.password:
credentials = ProxyCredentials(b"".join([url.username, b":", url.password])) credentials = BasicProxyCredentials(
b"".join([url.username, b":", url.password])
)
return url.scheme, url.hostname, url.port or default_port, credentials return url.scheme, url.hostname, url.port or default_port, credentials
@implementer(IStreamClientEndpoint)
class _RandomSampleEndpoints:
"""An endpoint that randomly iterates through a given list of endpoints at
each connection attempt.
"""
def __init__(
self,
endpoints: Sequence[IStreamClientEndpoint],
) -> None:
assert endpoints
self._endpoints = endpoints
def __repr__(self) -> str:
return f"<_RandomSampleEndpoints endpoints={self._endpoints}>"
def connect(
self, protocol_factory: IProtocolFactory
) -> "defer.Deferred[IProtocol]":
"""Implements IStreamClientEndpoint interface"""
return run_in_background(self._do_connect, protocol_factory)
async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
failures: List[Failure] = []
for endpoint in random.sample(self._endpoints, k=len(self._endpoints)):
try:
return await endpoint.connect(protocol_factory)
except Exception:
failures.append(Failure())
failures.pop().raiseException()

View file

@ -18,6 +18,7 @@ import html
import logging import logging
import types import types
import urllib import urllib
import urllib.parse
from http import HTTPStatus from http import HTTPStatus
from http.client import FOUND from http.client import FOUND
from inspect import isawaitable from inspect import isawaitable
@ -65,7 +66,6 @@ from synapse.api.errors import (
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import active_span, start_active_span, trace_servlet from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
from synapse.util import json_encoder from synapse.util import json_encoder
@ -76,6 +76,7 @@ from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING: if TYPE_CHECKING:
import opentracing import opentracing
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -102,7 +103,7 @@ HTTP_STATUS_REQUEST_CANCELLED = 499
def return_json_error( def return_json_error(
f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig] f: failure.Failure, request: "SynapseRequest", config: Optional[HomeServerConfig]
) -> None: ) -> None:
"""Sends a JSON error response to clients.""" """Sends a JSON error response to clients."""
@ -220,8 +221,8 @@ def return_html_error(
def wrap_async_request_handler( def wrap_async_request_handler(
h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]] h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]]
) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]: ) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]:
"""Wraps an async request handler so that it calls request.processing. """Wraps an async request handler so that it calls request.processing.
This helps ensure that work done by the request handler after the request is completed This helps ensure that work done by the request handler after the request is completed
@ -235,7 +236,7 @@ def wrap_async_request_handler(
""" """
async def wrapped_async_request_handler( async def wrapped_async_request_handler(
self: "_AsyncResource", request: SynapseRequest self: "_AsyncResource", request: "SynapseRequest"
) -> None: ) -> None:
with request.processing(): with request.processing():
await h(self, request) await h(self, request)
@ -300,7 +301,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
self._extract_context = extract_context self._extract_context = extract_context
def render(self, request: SynapseRequest) -> int: def render(self, request: "SynapseRequest") -> int:
"""This gets called by twisted every time someone sends us a request.""" """This gets called by twisted every time someone sends us a request."""
request.render_deferred = defer.ensureDeferred( request.render_deferred = defer.ensureDeferred(
self._async_render_wrapper(request) self._async_render_wrapper(request)
@ -308,7 +309,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
return NOT_DONE_YET return NOT_DONE_YET
@wrap_async_request_handler @wrap_async_request_handler
async def _async_render_wrapper(self, request: SynapseRequest) -> None: async def _async_render_wrapper(self, request: "SynapseRequest") -> None:
"""This is a wrapper that delegates to `_async_render` and handles """This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc. exceptions, return values, metrics, etc.
""" """
@ -326,9 +327,15 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
# of our stack, and thus gives us a sensible stack # of our stack, and thus gives us a sensible stack
# trace. # trace.
f = failure.Failure() f = failure.Failure()
logger.exception(
"Error handling request",
exc_info=(f.type, f.value, f.getTracebackObject()),
)
self._send_error_response(f, request) self._send_error_response(f, request)
async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]: async def _async_render(
self, request: "SynapseRequest"
) -> Optional[Tuple[int, Any]]:
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for no appropriate method exists. Can be overridden in sub classes for
different routing. different routing.
@ -358,7 +365,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def _send_response( def _send_response(
self, self,
request: SynapseRequest, request: "SynapseRequest",
code: int, code: int,
response_object: Any, response_object: Any,
) -> None: ) -> None:
@ -368,7 +375,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
def _send_error_response( def _send_error_response(
self, self,
f: failure.Failure, f: failure.Failure,
request: SynapseRequest, request: "SynapseRequest",
) -> None: ) -> None:
raise NotImplementedError() raise NotImplementedError()
@ -384,7 +391,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_response( def _send_response(
self, self,
request: SynapseRequest, request: "SynapseRequest",
code: int, code: int,
response_object: Any, response_object: Any,
) -> None: ) -> None:
@ -401,7 +408,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_error_response( def _send_error_response(
self, self,
f: failure.Failure, f: failure.Failure,
request: SynapseRequest, request: "SynapseRequest",
) -> None: ) -> None:
"""Implements _AsyncResource._send_error_response""" """Implements _AsyncResource._send_error_response"""
return_json_error(f, request, None) return_json_error(f, request, None)
@ -473,7 +480,7 @@ class JsonResource(DirectServeJsonResource):
) )
def _get_handler_for_request( def _get_handler_for_request(
self, request: SynapseRequest self, request: "SynapseRequest"
) -> Tuple[ServletCallback, str, Dict[str, str]]: ) -> Tuple[ServletCallback, str, Dict[str, str]]:
"""Finds a callback method to handle the given request. """Finds a callback method to handle the given request.
@ -503,7 +510,7 @@ class JsonResource(DirectServeJsonResource):
# Huh. No one wanted to handle that? Fiiiiiine. # Huh. No one wanted to handle that? Fiiiiiine.
raise UnrecognizedRequestError(code=404) raise UnrecognizedRequestError(code=404)
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request) callback, servlet_classname, group_dict = self._get_handler_for_request(request)
request.is_render_cancellable = is_function_cancellable(callback) request.is_render_cancellable = is_function_cancellable(callback)
@ -535,7 +542,7 @@ class JsonResource(DirectServeJsonResource):
def _send_error_response( def _send_error_response(
self, self,
f: failure.Failure, f: failure.Failure,
request: SynapseRequest, request: "SynapseRequest",
) -> None: ) -> None:
"""Implements _AsyncResource._send_error_response""" """Implements _AsyncResource._send_error_response"""
return_json_error(f, request, self.hs.config) return_json_error(f, request, self.hs.config)
@ -551,7 +558,7 @@ class DirectServeHtmlResource(_AsyncResource):
def _send_response( def _send_response(
self, self,
request: SynapseRequest, request: "SynapseRequest",
code: int, code: int,
response_object: Any, response_object: Any,
) -> None: ) -> None:
@ -565,7 +572,7 @@ class DirectServeHtmlResource(_AsyncResource):
def _send_error_response( def _send_error_response(
self, self,
f: failure.Failure, f: failure.Failure,
request: SynapseRequest, request: "SynapseRequest",
) -> None: ) -> None:
"""Implements _AsyncResource._send_error_response""" """Implements _AsyncResource._send_error_response"""
return_html_error(f, request, self.ERROR_TEMPLATE) return_html_error(f, request, self.ERROR_TEMPLATE)
@ -592,7 +599,7 @@ class UnrecognizedRequestResource(resource.Resource):
errcode of M_UNRECOGNIZED. errcode of M_UNRECOGNIZED.
""" """
def render(self, request: SynapseRequest) -> int: def render(self, request: "SynapseRequest") -> int:
f = failure.Failure(UnrecognizedRequestError(code=404)) f = failure.Failure(UnrecognizedRequestError(code=404))
return_json_error(f, request, None) return_json_error(f, request, None)
# A response has already been sent but Twisted requires either NOT_DONE_YET # A response has already been sent but Twisted requires either NOT_DONE_YET
@ -622,7 +629,7 @@ class RootRedirect(resource.Resource):
class OptionsResource(resource.Resource): class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children.""" """Responds to OPTION requests for itself and all children."""
def render_OPTIONS(self, request: SynapseRequest) -> bytes: def render_OPTIONS(self, request: "SynapseRequest") -> bytes:
request.setResponseCode(204) request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0") request.setHeader(b"Content-Length", b"0")
@ -737,7 +744,7 @@ def _encode_json_bytes(json_object: object) -> bytes:
def respond_with_json( def respond_with_json(
request: SynapseRequest, request: "SynapseRequest",
code: int, code: int,
json_object: Any, json_object: Any,
send_cors: bool = False, send_cors: bool = False,
@ -787,7 +794,7 @@ def respond_with_json(
def respond_with_json_bytes( def respond_with_json_bytes(
request: SynapseRequest, request: "SynapseRequest",
code: int, code: int,
json_bytes: bytes, json_bytes: bytes,
send_cors: bool = False, send_cors: bool = False,
@ -825,7 +832,7 @@ def respond_with_json_bytes(
async def _async_write_json_to_request_in_thread( async def _async_write_json_to_request_in_thread(
request: SynapseRequest, request: "SynapseRequest",
json_encoder: Callable[[Any], bytes], json_encoder: Callable[[Any], bytes],
json_object: Any, json_object: Any,
) -> None: ) -> None:
@ -883,7 +890,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator) _ByteProducer(request, bytes_generator)
def set_cors_headers(request: SynapseRequest) -> None: def set_cors_headers(request: "SynapseRequest") -> None:
"""Set the CORS headers so that javascript running in a web browsers can """Set the CORS headers so that javascript running in a web browsers can
use this API use this API
@ -981,7 +988,7 @@ def set_clickjacking_protection_headers(request: Request) -> None:
def respond_with_redirect( def respond_with_redirect(
request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False request: "SynapseRequest", url: bytes, statusCode: int = FOUND, cors: bool = False
) -> None: ) -> None:
""" """
Write a 302 (or other specified status code) response to the request, if it is still alive. Write a 302 (or other specified status code) response to the request, if it is still alive.

View file

@ -21,25 +21,29 @@ from zope.interface import implementer
from twisted.internet.address import UNIXAddress from twisted.internet.address import UNIXAddress
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IReactorTime from twisted.internet.interfaces import IAddress
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from twisted.web.resource import IResource, Resource from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site from twisted.web.server import Request
from synapse.config.server import ListenerConfig from synapse.config.server import ListenerConfig
from synapse.http import get_request_user_agent, redact_uri from synapse.http import get_request_user_agent, redact_uri
from synapse.http.proxy import ProxySite
from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import ( from synapse.logging.context import (
ContextRequest, ContextRequest,
LoggingContext, LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
) )
from synapse.types import Requester from synapse.types import ISynapseReactor, Requester
if TYPE_CHECKING: if TYPE_CHECKING:
import opentracing import opentracing
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_next_request_seq = 0 _next_request_seq = 0
@ -102,7 +106,7 @@ class SynapseRequest(Request):
# A boolean indicating whether `render_deferred` should be cancelled if the # A boolean indicating whether `render_deferred` should be cancelled if the
# client disconnects early. Expected to be set by the coroutine started by # client disconnects early. Expected to be set by the coroutine started by
# `Resource.render`, if rendering is asynchronous. # `Resource.render`, if rendering is asynchronous.
self.is_render_cancellable = False self.is_render_cancellable: bool = False
global _next_request_seq global _next_request_seq
self.request_seq = _next_request_seq self.request_seq = _next_request_seq
@ -601,7 +605,7 @@ class _XForwardedForAddress:
host: str host: str
class SynapseSite(Site): class SynapseSite(ProxySite):
""" """
Synapse-specific twisted http Site Synapse-specific twisted http Site
@ -623,7 +627,8 @@ class SynapseSite(Site):
resource: IResource, resource: IResource,
server_version_string: str, server_version_string: str,
max_request_body_size: int, max_request_body_size: int,
reactor: IReactorTime, reactor: ISynapseReactor,
hs: "HomeServer",
): ):
""" """
@ -638,7 +643,11 @@ class SynapseSite(Site):
dropping the connection dropping the connection
reactor: reactor to be used to manage connection timeouts reactor: reactor to be used to manage connection timeouts
""" """
Site.__init__(self, resource, reactor=reactor) super().__init__(
resource=resource,
reactor=reactor,
hs=hs,
)
self.site_tag = site_tag self.site_tag = site_tag
self.reactor = reactor self.reactor = reactor
@ -649,7 +658,9 @@ class SynapseSite(Site):
request_id_header = config.http_options.request_id_header request_id_header = config.http_options.request_id_header
self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886 self.experimental_cors_msc3886: bool = (
config.http_options.experimental_cors_msc3886
)
def request_factory(channel: HTTPChannel, queued: bool) -> Request: def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class( return request_class(

View file

@ -31,9 +31,7 @@ from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase): class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs return hs
def default_config(self) -> JsonDict: def default_config(self) -> JsonDict:
@ -91,9 +89,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
@patch("synapse.app.homeserver.KeyResource", new=Mock()) @patch("synapse.app.homeserver.KeyResource", new=Mock())
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
federation_http_client=None, homeserver_to_use=SynapseHomeServer
)
return hs return hs
@parameterized.expand( @parameterized.expand(

View file

@ -41,7 +41,6 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.appservice_api = mock.Mock() self.appservice_api = mock.Mock()
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"server", "server",
federation_http_client=None,
application_service_api=self.appservice_api, application_service_api=self.appservice_api,
) )
handler = hs.get_device_handler() handler = hs.get_device_handler()
@ -401,7 +400,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None) hs = self.setup_test_homeserver("server")
handler = hs.get_device_handler() handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler) assert isinstance(handler, DeviceHandler)
self.handler = handler self.handler = handler

View file

@ -57,7 +57,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
] ]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(federation_http_client=None) hs = self.setup_test_homeserver()
self.handler = hs.get_federation_handler() self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
return hs return hs

View file

@ -993,7 +993,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"server", "server",
federation_http_client=None,
federation_sender=Mock(spec=FederationSender), federation_sender=Mock(spec=FederationSender),
) )
return hs return hs

View file

@ -17,6 +17,8 @@ import json
from typing import Dict, List, Set from typing import Dict, List, Set
from unittest.mock import ANY, Mock, call from unittest.mock import ANY, Mock, call
from netaddr import IPSet
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -24,6 +26,7 @@ from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.typing import TypingWriterHandler from synapse.handlers.typing import TypingWriterHandler
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
@ -76,6 +79,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# we mock out the federation client too # we mock out the federation client too
self.mock_federation_client = Mock(spec=["put_json"]) self.mock_federation_client = Mock(spec=["put_json"])
self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
self.mock_federation_client.agent = MatrixFederationAgent(
reactor,
tls_client_options_factory=None,
user_agent=b"SynapseInTrialTest/0.0.0",
ip_allowlist=None,
ip_blocklist=IPSet(),
)
# the tests assume that we are starting at unix time 1000 # the tests assume that we are starting at unix time 1000
reactor.pump((1000,)) reactor.pump((1000,))

View file

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Generator from typing import Any, Dict, Generator
from unittest.mock import Mock from unittest.mock import ANY, Mock, create_autospec
from netaddr import IPSet from netaddr import IPSet
from parameterized import parameterized from parameterized import parameterized
@ -21,10 +21,12 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred, TimeoutError from twisted.internet.defer import Deferred, TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError from twisted.internet.error import ConnectingCancelledError, DNSLookupError
from twisted.test.proto_helpers import MemoryReactor, StringTransport from twisted.test.proto_helpers import MemoryReactor, StringTransport
from twisted.web.client import ResponseNeverReceived from twisted.web.client import Agent, ResponseNeverReceived
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from twisted.web.http_headers import Headers
from synapse.api.errors import RequestSendFailed from synapse.api.errors import HttpResponseException, RequestSendFailed
from synapse.config._base import ConfigError
from synapse.http.matrixfederationclient import ( from synapse.http.matrixfederationclient import (
ByteParser, ByteParser,
MatrixFederationHttpClient, MatrixFederationHttpClient,
@ -39,7 +41,9 @@ from synapse.logging.context import (
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeTransport from tests.server import FakeTransport
from tests.test_utils import FakeResponse
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@ -658,3 +662,275 @@ class FederationClientTests(HomeserverTestCase):
self.assertEqual(self.cl.max_short_retry_delay_seconds, 7) self.assertEqual(self.cl.max_short_retry_delay_seconds, 7)
self.assertEqual(self.cl.max_long_retries, 20) self.assertEqual(self.cl.max_long_retries, 20)
self.assertEqual(self.cl.max_short_retries, 5) self.assertEqual(self.cl.max_short_retries, 5)
class FederationClientProxyTests(BaseMultiWorkerStreamTestCase):
def default_config(self) -> Dict[str, Any]:
conf = super().default_config()
conf["instance_map"] = {
"main": {"host": "testserv", "port": 8765},
"federation_sender": {"host": "testserv", "port": 1001},
}
return conf
@override_config(
{
"outbound_federation_restricted_to": ["federation_sender"],
"worker_replication_secret": "secret",
}
)
def test_proxy_requests_through_federation_sender_worker(self) -> None:
"""
Test that all outbound federation requests go through the `federation_sender`
worker
"""
# Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
# so we can act like some remote server responding to requests
mock_client_on_federation_sender = Mock()
mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
# Create the `federation_sender` worker
self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "federation_sender"},
federation_http_client=mock_client_on_federation_sender,
)
# Fake `remoteserv:8008` responding to requests
mock_agent_on_federation_sender.request.side_effect = (
lambda *args, **kwargs: defer.succeed(
FakeResponse.json(
payload={
"foo": "bar",
}
)
)
)
# This federation request from the main process should be proxied through the
# `federation_sender` worker off to the remote server
test_request_from_main_process_d = defer.ensureDeferred(
self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar")
)
# Pump the reactor so our deferred goes through the motions
self.pump()
# Make sure that the request was proxied through the `federation_sender` worker
mock_agent_on_federation_sender.request.assert_called_once_with(
b"GET",
b"matrix-federation://remoteserv:8008/foo/bar",
headers=ANY,
bodyProducer=ANY,
)
# Make sure the response is as expected back on the main worker
res = self.successResultOf(test_request_from_main_process_d)
self.assertEqual(res, {"foo": "bar"})
@override_config(
{
"outbound_federation_restricted_to": ["federation_sender"],
"worker_replication_secret": "secret",
}
)
def test_proxy_request_with_network_error_through_federation_sender_worker(
self,
) -> None:
"""
Test that when the outbound federation request fails with a network related
error, a sensible error makes its way back to the main process.
"""
# Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
# so we can act like some remote server responding to requests
mock_client_on_federation_sender = Mock()
mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
# Create the `federation_sender` worker
self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "federation_sender"},
federation_http_client=mock_client_on_federation_sender,
)
# Fake `remoteserv:8008` responding to requests
mock_agent_on_federation_sender.request.side_effect = (
lambda *args, **kwargs: defer.fail(ResponseNeverReceived("fake error"))
)
# This federation request from the main process should be proxied through the
# `federation_sender` worker off to the remote server
test_request_from_main_process_d = defer.ensureDeferred(
self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar")
)
# Pump the reactor so our deferred goes through the motions. We pump with 10
# seconds (0.1 * 100) so the `MatrixFederationHttpClient` runs out of retries
# and finally passes along the error response.
self.pump(0.1)
# Make sure that the request was proxied through the `federation_sender` worker
mock_agent_on_federation_sender.request.assert_called_with(
b"GET",
b"matrix-federation://remoteserv:8008/foo/bar",
headers=ANY,
bodyProducer=ANY,
)
# Make sure we get some sort of error back on the main worker
failure_res = self.failureResultOf(test_request_from_main_process_d)
self.assertIsInstance(failure_res.value, RequestSendFailed)
self.assertIsInstance(failure_res.value.inner_exception, HttpResponseException)
self.assertEqual(failure_res.value.inner_exception.code, 502)
@override_config(
{
"outbound_federation_restricted_to": ["federation_sender"],
"worker_replication_secret": "secret",
}
)
def test_proxy_requests_and_discards_hop_by_hop_headers(self) -> None:
"""
Test to make sure hop-by-hop headers and addional headers defined in the
`Connection` header are discarded when proxying requests
"""
# Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
# so we can act like some remote server responding to requests
mock_client_on_federation_sender = Mock()
mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
# Create the `federation_sender` worker
self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "federation_sender"},
federation_http_client=mock_client_on_federation_sender,
)
# Fake `remoteserv:8008` responding to requests
mock_agent_on_federation_sender.request.side_effect = lambda *args, **kwargs: defer.succeed(
FakeResponse(
code=200,
body=b'{"foo": "bar"}',
headers=Headers(
{
"Content-Type": ["application/json"],
"Connection": ["close, X-Foo, X-Bar"],
# Should be removed because it's defined in the `Connection` header
"X-Foo": ["foo"],
"X-Bar": ["bar"],
# Should be removed because it's a hop-by-hop header
"Proxy-Authorization": "abcdef",
}
),
)
)
# This federation request from the main process should be proxied through the
# `federation_sender` worker off to the remote server
test_request_from_main_process_d = defer.ensureDeferred(
self.hs.get_federation_http_client().get_json_with_headers(
"remoteserv:8008", "foo/bar"
)
)
# Pump the reactor so our deferred goes through the motions
self.pump()
# Make sure that the request was proxied through the `federation_sender` worker
mock_agent_on_federation_sender.request.assert_called_once_with(
b"GET",
b"matrix-federation://remoteserv:8008/foo/bar",
headers=ANY,
bodyProducer=ANY,
)
res, headers = self.successResultOf(test_request_from_main_process_d)
header_names = set(headers.keys())
# Make sure the response does not include the hop-by-hop headers
self.assertNotIn(b"X-Foo", header_names)
self.assertNotIn(b"X-Bar", header_names)
self.assertNotIn(b"Proxy-Authorization", header_names)
# Make sure the response is as expected back on the main worker
self.assertEqual(res, {"foo": "bar"})
@override_config(
{
"outbound_federation_restricted_to": ["federation_sender"],
# `worker_replication_secret` is set here so that the test setup is able to pass
# but the actual homserver creation test is in the test body below
"worker_replication_secret": "secret",
}
)
def test_not_able_to_proxy_requests_through_federation_sender_worker_when_no_secret_configured(
self,
) -> None:
"""
Test that we aren't able to proxy any outbound federation requests when
`worker_replication_secret` is not configured.
"""
with self.assertRaises(ConfigError):
# Create the `federation_sender` worker
self.make_worker_hs(
"synapse.app.generic_worker",
{
"worker_name": "federation_sender",
# Test that we aren't able to proxy any outbound federation requests
# when `worker_replication_secret` is not configured.
"worker_replication_secret": None,
},
)
@override_config(
{
"outbound_federation_restricted_to": ["federation_sender"],
"worker_replication_secret": "secret",
}
)
def test_not_able_to_proxy_requests_through_federation_sender_worker_when_wrong_auth_given(
self,
) -> None:
"""
Test that we aren't able to proxy any outbound federation requests when the
wrong authorization is given.
"""
# Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
# so we can act like some remote server responding to requests
mock_client_on_federation_sender = Mock()
mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
# Create the `federation_sender` worker
self.make_worker_hs(
"synapse.app.generic_worker",
{
"worker_name": "federation_sender",
# Test that we aren't able to proxy any outbound federation requests
# when `worker_replication_secret` is wrong.
"worker_replication_secret": "wrong",
},
federation_http_client=mock_client_on_federation_sender,
)
# This federation request from the main process should be proxied through the
# `federation_sender` worker off but will fail here because it's using the wrong
# authorization.
test_request_from_main_process_d = defer.ensureDeferred(
self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar")
)
# Pump the reactor so our deferred goes through the motions. We pump with 10
# seconds (0.1 * 100) so the `MatrixFederationHttpClient` runs out of retries
# and finally passes along the error response.
self.pump(0.1)
# Make sure that the request was *NOT* proxied through the `federation_sender`
# worker
mock_agent_on_federation_sender.request.assert_not_called()
failure_res = self.failureResultOf(test_request_from_main_process_d)
self.assertIsInstance(failure_res.value, HttpResponseException)
self.assertEqual(failure_res.value.code, 401)

53
tests/http/test_proxy.py Normal file
View file

@ -0,0 +1,53 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Set
from parameterized import parameterized
from synapse.http.proxy import parse_connection_header_value
from tests.unittest import TestCase
class ProxyTests(TestCase):
@parameterized.expand(
[
[b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}],
# No whitespace
[b"close,X-Foo,X-Bar", {"Close", "X-Foo", "X-Bar"}],
# More whitespace
[b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}],
# "close" directive in not the first position
[b"X-Foo, X-Bar, close", {"X-Foo", "X-Bar", "Close"}],
# Normalizes header capitalization
[b"keep-alive, x-fOo, x-bAr", {"Keep-Alive", "X-Foo", "X-Bar"}],
# Handles header names with whitespace
[
b"keep-alive, x foo, x bar",
{"Keep-Alive", "X foo", "X bar"},
],
]
)
def test_parse_connection_header_value(
self,
connection_header_value: bytes,
expected_extra_headers_to_remove: Set[str],
) -> None:
"""
Tests that the connection header value is parsed correctly
"""
self.assertEqual(
expected_extra_headers_to_remove,
parse_connection_header_value(connection_header_value),
)

View file

@ -33,7 +33,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from synapse.http.client import BlocklistingReactorWrapper from synapse.http.client import BlocklistingReactorWrapper
from synapse.http.connectproxyclient import ProxyCredentials from synapse.http.connectproxyclient import BasicProxyCredentials
from synapse.http.proxyagent import ProxyAgent, parse_proxy from synapse.http.proxyagent import ProxyAgent, parse_proxy
from tests.http import ( from tests.http import (
@ -205,7 +205,7 @@ class ProxyParserTests(TestCase):
""" """
proxy_cred = None proxy_cred = None
if expected_credentials: if expected_credentials:
proxy_cred = ProxyCredentials(expected_credentials) proxy_cred = BasicProxyCredentials(expected_credentials)
self.assertEqual( self.assertEqual(
( (
expected_scheme, expected_scheme,

View file

@ -70,10 +70,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Make a new HomeServer object for the worker # Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver( self.worker_hs = self.setup_test_homeserver(
federation_http_client=None,
homeserver_to_use=GenericWorkerServer, homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(), config=self._get_worker_hs_config(),
reactor=self.reactor, reactor=self.reactor,
federation_http_client=None,
) )
# Since we use sqlite in memory databases we need to make sure the # Since we use sqlite in memory databases we need to make sure the
@ -385,6 +385,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
server_version_string="1", server_version_string="1",
max_request_body_size=8192, max_request_body_size=8192,
reactor=self.reactor, reactor=self.reactor,
hs=worker_hs,
) )
worker_hs.get_replication_command_handler().start_replication(worker_hs) worker_hs.get_replication_command_handler().start_replication(worker_hs)

View file

@ -14,14 +14,18 @@
import logging import logging
from unittest.mock import Mock from unittest.mock import Mock
from netaddr import IPSet
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.handlers.typing import TypingWriterHandler from synapse.handlers.typing import TypingWriterHandler
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import get_clock
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,13 +45,25 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
room.register_servlets, room.register_servlets,
] ]
def setUp(self) -> None:
super().setUp()
reactor, _ = get_clock()
self.matrix_federation_agent = MatrixFederationAgent(
reactor,
tls_client_options_factory=None,
user_agent=b"SynapseInTrialTest/0.0.0",
ip_allowlist=None,
ip_blocklist=IPSet(),
)
def test_send_event_single_sender(self) -> None: def test_send_event_single_sender(self) -> None:
"""Test that using a single federation sender worker correctly sends a """Test that using a single federation sender worker correctly sends a
new event. new event.
""" """
mock_client = Mock(spec=["put_json"]) mock_client = Mock(spec=["put_json"])
mock_client.put_json.return_value = make_awaitable({}) mock_client.put_json.return_value = make_awaitable({})
mock_client.agent = self.matrix_federation_agent
self.make_worker_hs( self.make_worker_hs(
"synapse.app.generic_worker", "synapse.app.generic_worker",
{ {
@ -78,6 +94,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
""" """
mock_client1 = Mock(spec=["put_json"]) mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({}) mock_client1.put_json.return_value = make_awaitable({})
mock_client1.agent = self.matrix_federation_agent
self.make_worker_hs( self.make_worker_hs(
"synapse.app.generic_worker", "synapse.app.generic_worker",
{ {
@ -92,6 +109,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client2 = Mock(spec=["put_json"]) mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({}) mock_client2.put_json.return_value = make_awaitable({})
mock_client2.agent = self.matrix_federation_agent
self.make_worker_hs( self.make_worker_hs(
"synapse.app.generic_worker", "synapse.app.generic_worker",
{ {
@ -145,6 +163,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
""" """
mock_client1 = Mock(spec=["put_json"]) mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({}) mock_client1.put_json.return_value = make_awaitable({})
mock_client1.agent = self.matrix_federation_agent
self.make_worker_hs( self.make_worker_hs(
"synapse.app.generic_worker", "synapse.app.generic_worker",
{ {
@ -159,6 +178,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client2 = Mock(spec=["put_json"]) mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({}) mock_client2.put_json.return_value = make_awaitable({})
mock_client2.agent = self.matrix_federation_agent
self.make_worker_hs( self.make_worker_hs(
"synapse.app.generic_worker", "synapse.app.generic_worker",
{ {

View file

@ -40,7 +40,6 @@ class PresenceTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
"red", "red",
federation_http_client=None,
federation_client=Mock(), federation_client=Mock(),
presence_handler=self.presence_handler, presence_handler=self.presence_handler,
) )

View file

@ -67,8 +67,6 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver( self.hs = self.setup_test_homeserver(
"red", "red",
federation_http_client=None,
federation_client=Mock(),
) )
self.hs.get_federation_handler = Mock() # type: ignore[assignment] self.hs.get_federation_handler = Mock() # type: ignore[assignment]

View file

@ -31,7 +31,7 @@ room_key: RoomKey = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None) hs = self.setup_test_homeserver("server")
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
return hs return hs

View file

@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase):
servlets = [room.register_servlets] servlets = [room.register_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None) hs = self.setup_test_homeserver("server")
return hs return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

View file

@ -45,9 +45,7 @@ def fake_listdir(filepath: str) -> List[str]:
class WorkerSchemaTests(HomeserverTestCase): class WorkerSchemaTests(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs return hs
def default_config(self) -> JsonDict: def default_config(self) -> JsonDict:

View file

@ -38,7 +38,7 @@ from tests.http.server._base import test_disconnect
from tests.server import ( from tests.server import (
FakeChannel, FakeChannel,
FakeSite, FakeSite,
ThreadedMemoryReactorClock, get_clock,
make_request, make_request,
setup_test_homeserver, setup_test_homeserver,
) )
@ -46,12 +46,11 @@ from tests.server import (
class JsonResourceTests(unittest.TestCase): class JsonResourceTests(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock() reactor, clock = get_clock()
self.hs_clock = Clock(self.reactor) self.reactor = reactor
self.homeserver = setup_test_homeserver( self.homeserver = setup_test_homeserver(
self.addCleanup, self.addCleanup,
federation_http_client=None, clock=clock,
clock=self.hs_clock,
reactor=self.reactor, reactor=self.reactor,
) )
@ -209,7 +208,13 @@ class JsonResourceTests(unittest.TestCase):
class OptionsResourceTests(unittest.TestCase): class OptionsResourceTests(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock() reactor, clock = get_clock()
self.reactor = reactor
self.homeserver = setup_test_homeserver(
self.addCleanup,
clock=clock,
reactor=self.reactor,
)
class DummyResource(Resource): class DummyResource(Resource):
isLeaf = True isLeaf = True
@ -242,6 +247,7 @@ class OptionsResourceTests(unittest.TestCase):
"1.0", "1.0",
max_request_body_size=4096, max_request_body_size=4096,
reactor=self.reactor, reactor=self.reactor,
hs=self.homeserver,
) )
# render the request and return the channel # render the request and return the channel
@ -344,7 +350,8 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
await self.callback(request) await self.callback(request)
def setUp(self) -> None: def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock() reactor, _ = get_clock()
self.reactor = reactor
def test_good_response(self) -> None: def test_good_response(self) -> None:
async def callback(request: SynapseRequest) -> None: async def callback(request: SynapseRequest) -> None:
@ -462,9 +469,9 @@ class DirectServeJsonResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeJsonResource` cancellation.""" """Tests for `DirectServeJsonResource` cancellation."""
def setUp(self) -> None: def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock() reactor, clock = get_clock()
self.clock = Clock(self.reactor) self.reactor = reactor
self.resource = CancellableDirectServeJsonResource(self.clock) self.resource = CancellableDirectServeJsonResource(clock)
self.site = FakeSite(self.resource, self.reactor) self.site = FakeSite(self.resource, self.reactor)
def test_cancellable_disconnect(self) -> None: def test_cancellable_disconnect(self) -> None:
@ -496,9 +503,9 @@ class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeHtmlResource` cancellation.""" """Tests for `DirectServeHtmlResource` cancellation."""
def setUp(self) -> None: def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock() reactor, clock = get_clock()
self.clock = Clock(self.reactor) self.reactor = reactor
self.resource = CancellableDirectServeHtmlResource(self.clock) self.resource = CancellableDirectServeHtmlResource(clock)
self.site = FakeSite(self.resource, self.reactor) self.site = FakeSite(self.resource, self.reactor)
def test_cancellable_disconnect(self) -> None: def test_cancellable_disconnect(self) -> None:

View file

@ -358,6 +358,7 @@ class HomeserverTestCase(TestCase):
server_version_string="1", server_version_string="1",
max_request_body_size=4096, max_request_body_size=4096,
reactor=self.reactor, reactor=self.reactor,
hs=self.hs,
) )
from tests.rest.client.utils import RestHelper from tests.rest.client.utils import RestHelper