mirror of
https://github.com/element-hq/synapse
synced 2024-09-28 16:32:40 +00:00
Merge branch 'erikj/reduce_size_of_cache' into erikj/merge_cache_prs
This commit is contained in:
commit
a99c692906
25 changed files with 364 additions and 248 deletions
1
changelog.d/9726.bugfix
Normal file
1
changelog.d/9726.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fixes the OIDC SSO flow when using a `public_baseurl` value including a non-root URL path.
|
1
changelog.d/9817.misc
Normal file
1
changelog.d/9817.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix a long-standing bug which caused `max_upload_size` to not be correctly enforced.
|
1
changelog.d/9874.misc
Normal file
1
changelog.d/9874.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Pass a reactor into `SynapseSite` to make testing easier.
|
1
changelog.d/9876.misc
Normal file
1
changelog.d/9876.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add type hints to `synapse.api.auth` and `synapse.api.auth_blocking` modules.
|
1
changelog.d/9878.misc
Normal file
1
changelog.d/9878.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Remove redundant `_PushHTTPChannel` test class.
|
|
@ -12,14 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import pymacaroons
|
||||
from netaddr import IPAddress
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
import synapse.types
|
||||
from synapse import event_auth
|
||||
from synapse.api.auth_blocking import AuthBlocking
|
||||
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
||||
|
@ -36,11 +35,14 @@ from synapse.http import get_request_user_agent
|
|||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging import opentracing as opentracing
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import StateMap, UserID
|
||||
from synapse.types import Requester, StateMap, UserID, create_requester
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -68,7 +70,7 @@ class Auth:
|
|||
The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -88,13 +90,13 @@ class Auth:
|
|||
|
||||
async def check_from_context(
|
||||
self, room_version: str, event, context, do_sig_check=True
|
||||
):
|
||||
) -> None:
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
auth_events_ids = self.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=True
|
||||
)
|
||||
auth_events = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||
auth_events_by_id = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
|
||||
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
event_auth.check(
|
||||
|
@ -151,17 +153,11 @@ class Auth:
|
|||
|
||||
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
||||
|
||||
async def check_host_in_room(self, room_id, host):
|
||||
async def check_host_in_room(self, room_id: str, host: str) -> bool:
|
||||
with Measure(self.clock, "check_host_in_room"):
|
||||
latest_event_ids = await self.store.is_host_joined(room_id, host)
|
||||
return latest_event_ids
|
||||
return await self.store.is_host_joined(room_id, host)
|
||||
|
||||
def can_federate(self, event, auth_events):
|
||||
creation_event = auth_events.get((EventTypes.Create, ""))
|
||||
|
||||
return creation_event.content.get("m.federate", True) is True
|
||||
|
||||
def get_public_keys(self, invite_event):
|
||||
def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
|
||||
return event_auth.get_public_keys(invite_event)
|
||||
|
||||
async def get_user_by_req(
|
||||
|
@ -170,7 +166,7 @@ class Auth:
|
|||
allow_guest: bool = False,
|
||||
rights: str = "access",
|
||||
allow_expired: bool = False,
|
||||
) -> synapse.types.Requester:
|
||||
) -> Requester:
|
||||
"""Get a registered user's ID.
|
||||
|
||||
Args:
|
||||
|
@ -196,7 +192,7 @@ class Auth:
|
|||
access_token = self.get_access_token_from_request(request)
|
||||
|
||||
user_id, app_service = await self._get_appservice_user_id(request)
|
||||
if user_id:
|
||||
if user_id and app_service:
|
||||
if ip_addr and self._track_appservice_user_ips:
|
||||
await self.store.insert_client_ip(
|
||||
user_id=user_id,
|
||||
|
@ -206,9 +202,7 @@ class Auth:
|
|||
device_id="dummy-device", # stubbed
|
||||
)
|
||||
|
||||
requester = synapse.types.create_requester(
|
||||
user_id, app_service=app_service
|
||||
)
|
||||
requester = create_requester(user_id, app_service=app_service)
|
||||
|
||||
request.requester = user_id
|
||||
opentracing.set_tag("authenticated_entity", user_id)
|
||||
|
@ -251,7 +245,7 @@ class Auth:
|
|||
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
|
||||
)
|
||||
|
||||
requester = synapse.types.create_requester(
|
||||
requester = create_requester(
|
||||
user_info.user_id,
|
||||
token_id,
|
||||
is_guest,
|
||||
|
@ -271,7 +265,9 @@ class Auth:
|
|||
except KeyError:
|
||||
raise MissingClientTokenError()
|
||||
|
||||
async def _get_appservice_user_id(self, request):
|
||||
async def _get_appservice_user_id(
|
||||
self, request: Request
|
||||
) -> Tuple[Optional[str], Optional[ApplicationService]]:
|
||||
app_service = self.store.get_app_service_by_token(
|
||||
self.get_access_token_from_request(request)
|
||||
)
|
||||
|
@ -283,6 +279,9 @@ class Auth:
|
|||
if ip_address not in app_service.ip_range_whitelist:
|
||||
return None, None
|
||||
|
||||
# This will always be set by the time Twisted calls us.
|
||||
assert request.args is not None
|
||||
|
||||
if b"user_id" not in request.args:
|
||||
return app_service.sender, app_service
|
||||
|
||||
|
@ -387,7 +386,9 @@ class Auth:
|
|||
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
|
||||
raise InvalidClientTokenError("Invalid macaroon passed.")
|
||||
|
||||
def _parse_and_validate_macaroon(self, token, rights="access"):
|
||||
def _parse_and_validate_macaroon(
|
||||
self, token: str, rights: str = "access"
|
||||
) -> Tuple[str, bool]:
|
||||
"""Takes a macaroon and tries to parse and validate it. This is cached
|
||||
if and only if rights == access and there isn't an expiry.
|
||||
|
||||
|
@ -432,15 +433,16 @@ class Auth:
|
|||
|
||||
return user_id, guest
|
||||
|
||||
def validate_macaroon(self, macaroon, type_string, user_id):
|
||||
def validate_macaroon(
|
||||
self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
validate that a Macaroon is understood by and was signed by this server.
|
||||
|
||||
Args:
|
||||
macaroon(pymacaroons.Macaroon): The macaroon to validate
|
||||
type_string(str): The kind of token required (e.g. "access",
|
||||
"delete_pusher")
|
||||
user_id (str): The user_id required
|
||||
macaroon: The macaroon to validate
|
||||
type_string: The kind of token required (e.g. "access", "delete_pusher")
|
||||
user_id: The user_id required
|
||||
"""
|
||||
v = pymacaroons.Verifier()
|
||||
|
||||
|
@ -465,9 +467,7 @@ class Auth:
|
|||
if not service:
|
||||
logger.warning("Unrecognised appservice access token.")
|
||||
raise InvalidClientTokenError()
|
||||
request.requester = synapse.types.create_requester(
|
||||
service.sender, app_service=service
|
||||
)
|
||||
request.requester = create_requester(service.sender, app_service=service)
|
||||
return service
|
||||
|
||||
async def is_server_admin(self, user: UserID) -> bool:
|
||||
|
@ -519,7 +519,7 @@ class Auth:
|
|||
|
||||
return auth_ids
|
||||
|
||||
async def check_can_change_room_list(self, room_id: str, user: UserID):
|
||||
async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
|
||||
"""Determine whether the user is allowed to edit the room's entry in the
|
||||
published room list.
|
||||
|
||||
|
@ -554,11 +554,11 @@ class Auth:
|
|||
return user_level >= send_level
|
||||
|
||||
@staticmethod
|
||||
def has_access_token(request: Request):
|
||||
def has_access_token(request: Request) -> bool:
|
||||
"""Checks if the request has an access_token.
|
||||
|
||||
Returns:
|
||||
bool: False if no access_token was given, True otherwise.
|
||||
False if no access_token was given, True otherwise.
|
||||
"""
|
||||
# This will always be set by the time Twisted calls us.
|
||||
assert request.args is not None
|
||||
|
@ -568,13 +568,13 @@ class Auth:
|
|||
return bool(query_params) or bool(auth_headers)
|
||||
|
||||
@staticmethod
|
||||
def get_access_token_from_request(request: Request):
|
||||
def get_access_token_from_request(request: Request) -> str:
|
||||
"""Extracts the access_token from the request.
|
||||
|
||||
Args:
|
||||
request: The http request.
|
||||
Returns:
|
||||
unicode: The access_token
|
||||
The access_token
|
||||
Raises:
|
||||
MissingClientTokenError: If there isn't a single access_token in the
|
||||
request
|
||||
|
@ -649,5 +649,5 @@ class Auth:
|
|||
% (user_id, room_id),
|
||||
)
|
||||
|
||||
def check_auth_blocking(self, *args, **kwargs):
|
||||
return self._auth_blocking.check_auth_blocking(*args, **kwargs)
|
||||
async def check_auth_blocking(self, *args, **kwargs) -> None:
|
||||
await self._auth_blocking.check_auth_blocking(*args, **kwargs)
|
||||
|
|
|
@ -13,18 +13,21 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.constants import LimitBlockingTypes, UserTypes
|
||||
from synapse.api.errors import Codes, ResourceLimitError
|
||||
from synapse.config.server import is_threepid_reserved
|
||||
from synapse.types import Requester
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthBlocking:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||
|
@ -43,7 +46,7 @@ class AuthBlocking:
|
|||
threepid: Optional[dict] = None,
|
||||
user_type: Optional[str] = None,
|
||||
requester: Optional[Requester] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Checks if the user should be rejected for some external reason,
|
||||
such as monthly active user limiting or global disable flag
|
||||
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
|
||||
"""Contains constants from the specification."""
|
||||
|
||||
# the max size of a (canonical-json-encoded) event
|
||||
MAX_PDU_SIZE = 65536
|
||||
|
||||
# the "depth" field on events is limited to 2**63 - 1
|
||||
MAX_DEPTH = 2 ** 63 - 1
|
||||
|
||||
|
|
|
@ -30,9 +30,10 @@ from twisted.internet import defer, error, reactor
|
|||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
|
||||
import synapse
|
||||
from synapse.api.constants import MAX_PDU_SIZE
|
||||
from synapse.app import check_bind_error
|
||||
from synapse.app.phone_stats_home import start_phone_stats_home
|
||||
from synapse.config.server import ListenerConfig
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.logging.context import PreserveLoggingContext
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
|
@ -288,7 +289,7 @@ def refresh_certificate(hs):
|
|||
logger.info("Context factories updated.")
|
||||
|
||||
|
||||
async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
|
||||
async def start(hs: "synapse.server.HomeServer"):
|
||||
"""
|
||||
Start a Synapse server or worker.
|
||||
|
||||
|
@ -300,7 +301,6 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
|
|||
|
||||
Args:
|
||||
hs: homeserver instance
|
||||
listeners: Listener configuration ('listeners' in homeserver.yaml)
|
||||
"""
|
||||
# Set up the SIGHUP machinery.
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
|
@ -336,7 +336,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
|
|||
synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
|
||||
|
||||
# It is now safe to start your Synapse.
|
||||
hs.start_listening(listeners)
|
||||
hs.start_listening()
|
||||
hs.get_datastore().db_pool.start_profiling()
|
||||
hs.get_pusherpool().start()
|
||||
|
||||
|
@ -530,3 +530,25 @@ def sdnotify(state):
|
|||
# this is a bit surprising, since we don't expect to have a NOTIFY_SOCKET
|
||||
# unless systemd is expecting us to notify it.
|
||||
logger.warning("Unable to send notification to systemd: %s", e)
|
||||
|
||||
|
||||
def max_request_body_size(config: HomeServerConfig) -> int:
|
||||
"""Get a suitable maximum size for incoming HTTP requests"""
|
||||
|
||||
# Other than media uploads, the biggest request we expect to see is a fully-loaded
|
||||
# /federation/v1/send request.
|
||||
#
|
||||
# The main thing in such a request is up to 50 PDUs, and up to 100 EDUs. PDUs are
|
||||
# limited to 65536 bytes (possibly slightly more if the sender didn't use canonical
|
||||
# json encoding); there is no specced limit to EDUs (see
|
||||
# https://github.com/matrix-org/matrix-doc/issues/3121).
|
||||
#
|
||||
# in short, we somewhat arbitrarily limit requests to 200 * 64K (about 12.5M)
|
||||
#
|
||||
max_request_size = 200 * MAX_PDU_SIZE
|
||||
|
||||
# if we have a media repo enabled, we may need to allow larger uploads than that
|
||||
if config.media.can_load_media_repo:
|
||||
max_request_size = max(max_request_size, config.media.max_upload_size)
|
||||
|
||||
return max_request_size
|
||||
|
|
|
@ -70,12 +70,6 @@ class AdminCmdSlavedStore(
|
|||
class AdminCmdServer(HomeServer):
|
||||
DATASTORE_CLASS = AdminCmdSlavedStore
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
pass
|
||||
|
||||
def start_listening(self, listeners):
|
||||
pass
|
||||
|
||||
|
||||
async def export_data_command(hs, args):
|
||||
"""Export data for a user.
|
||||
|
@ -232,7 +226,7 @@ def start(config_options):
|
|||
|
||||
async def run():
|
||||
with LoggingContext("command"):
|
||||
_base.start(ss, [])
|
||||
_base.start(ss)
|
||||
await args.func(ss, args)
|
||||
|
||||
_base.start_worker_reactor(
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, Iterable, Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from twisted.internet import address
|
||||
from twisted.web.resource import IResource
|
||||
|
@ -32,7 +32,7 @@ from synapse.api.urls import (
|
|||
SERVER_KEY_V2_PREFIX,
|
||||
)
|
||||
from synapse.app import _base
|
||||
from synapse.app._base import register_start
|
||||
from synapse.app._base import max_request_body_size, register_start
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
|
@ -367,14 +367,16 @@ class GenericWorkerServer(HomeServer):
|
|||
listener_config,
|
||||
root_resource,
|
||||
self.version_string,
|
||||
max_request_body_size=max_request_body_size(self.config),
|
||||
reactor=self.get_reactor(),
|
||||
),
|
||||
reactor=self.get_reactor(),
|
||||
)
|
||||
|
||||
logger.info("Synapse worker now listening on port %d", port)
|
||||
|
||||
def start_listening(self, listeners: Iterable[ListenerConfig]):
|
||||
for listener in listeners:
|
||||
def start_listening(self):
|
||||
for listener in self.config.worker_listeners:
|
||||
if listener.type == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener.type == "manhole":
|
||||
|
@ -468,7 +470,7 @@ def start(config_options):
|
|||
# streams. Will no-op if no streams can be written to by this worker.
|
||||
hs.get_replication_streamer()
|
||||
|
||||
register_start(_base.start, hs, config.worker_listeners)
|
||||
register_start(_base.start, hs)
|
||||
|
||||
_base.start_worker_reactor("synapse-generic-worker", config)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Iterable, Iterator
|
||||
from typing import Iterator
|
||||
|
||||
from twisted.internet import reactor
|
||||
from twisted.web.resource import EncodingResourceWrapper, IResource
|
||||
|
@ -36,7 +36,13 @@ from synapse.api.urls import (
|
|||
WEB_CLIENT_PREFIX,
|
||||
)
|
||||
from synapse.app import _base
|
||||
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
|
||||
from synapse.app._base import (
|
||||
listen_ssl,
|
||||
listen_tcp,
|
||||
max_request_body_size,
|
||||
quit_with_error,
|
||||
register_start,
|
||||
)
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.emailconfig import ThreepidBehaviour
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
|
@ -126,19 +132,21 @@ class SynapseHomeServer(HomeServer):
|
|||
else:
|
||||
root_resource = OptionsResource()
|
||||
|
||||
root_resource = create_resource_tree(resources, root_resource)
|
||||
site = SynapseSite(
|
||||
"synapse.access.%s.%s" % ("https" if tls else "http", site_tag),
|
||||
site_tag,
|
||||
listener_config,
|
||||
create_resource_tree(resources, root_resource),
|
||||
self.version_string,
|
||||
max_request_body_size=max_request_body_size(self.config),
|
||||
reactor=self.get_reactor(),
|
||||
)
|
||||
|
||||
if tls:
|
||||
ports = listen_ssl(
|
||||
bind_addresses,
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.https.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
self.version_string,
|
||||
),
|
||||
site,
|
||||
self.tls_server_context_factory,
|
||||
reactor=self.get_reactor(),
|
||||
)
|
||||
|
@ -148,13 +156,7 @@ class SynapseHomeServer(HomeServer):
|
|||
ports = listen_tcp(
|
||||
bind_addresses,
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
self.version_string,
|
||||
),
|
||||
site,
|
||||
reactor=self.get_reactor(),
|
||||
)
|
||||
logger.info("Synapse now listening on TCP port %d", port)
|
||||
|
@ -273,14 +275,14 @@ class SynapseHomeServer(HomeServer):
|
|||
|
||||
return resources
|
||||
|
||||
def start_listening(self, listeners: Iterable[ListenerConfig]):
|
||||
def start_listening(self):
|
||||
if self.config.redis_enabled:
|
||||
# If redis is enabled we connect via the replication command handler
|
||||
# in the same way as the workers (since we're effectively a client
|
||||
# rather than a server).
|
||||
self.get_tcp_replication().start_replication(self)
|
||||
|
||||
for listener in listeners:
|
||||
for listener in self.config.server.listeners:
|
||||
if listener.type == "http":
|
||||
self._listening_services.extend(
|
||||
self._listener_http(self.config, listener)
|
||||
|
@ -413,7 +415,7 @@ def setup(config_options):
|
|||
# Loading the provider metadata also ensures the provider config is valid.
|
||||
await oidc.load_metadata()
|
||||
|
||||
await _base.start(hs, config.listeners)
|
||||
await _base.start(hs)
|
||||
|
||||
hs.get_datastore().db_pool.updates.start_doing_background_updates()
|
||||
|
||||
|
|
|
@ -31,7 +31,6 @@ from twisted.logger import (
|
|||
)
|
||||
|
||||
import synapse
|
||||
from synapse.app import _base as appbase
|
||||
from synapse.logging._structured import setup_structured_logging
|
||||
from synapse.logging.context import LoggingContextFilter
|
||||
from synapse.logging.filter import MetadataFilter
|
||||
|
@ -318,6 +317,8 @@ def setup_logging(
|
|||
# Perform one-time logging configuration.
|
||||
_setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner)
|
||||
# Add a SIGHUP handler to reload the logging configuration, if one is available.
|
||||
from synapse.app import _base as appbase
|
||||
|
||||
appbase.register_sighup(_reload_logging_config, log_config_path)
|
||||
|
||||
# Log immediately so we can grep backwards.
|
||||
|
|
|
@ -235,7 +235,11 @@ class ServerConfig(Config):
|
|||
self.print_pidfile = config.get("print_pidfile")
|
||||
self.user_agent_suffix = config.get("user_agent_suffix")
|
||||
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
|
||||
|
||||
self.public_baseurl = config.get("public_baseurl")
|
||||
if self.public_baseurl is not None:
|
||||
if self.public_baseurl[-1] != "/":
|
||||
self.public_baseurl += "/"
|
||||
|
||||
# Whether to enable user presence.
|
||||
presence_config = config.get("presence") or {}
|
||||
|
@ -407,10 +411,6 @@ class ServerConfig(Config):
|
|||
config_path=("federation_ip_range_blacklist",),
|
||||
)
|
||||
|
||||
if self.public_baseurl is not None:
|
||||
if self.public_baseurl[-1] != "/":
|
||||
self.public_baseurl += "/"
|
||||
|
||||
# (undocumented) option for torturing the worker-mode replication a bit,
|
||||
# for testing. The value defines the number of milliseconds to pause before
|
||||
# sending out any replication updates.
|
||||
|
|
|
@ -14,14 +14,14 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import SignatureVerifyException, verify_signed_json
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.api.constants import MAX_PDU_SIZE, EventTypes, JoinRules, Membership
|
||||
from synapse.api.errors import AuthError, EventSizeError, SynapseError
|
||||
from synapse.api.room_versions import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
|
@ -205,7 +205,7 @@ def _check_size_limits(event: EventBase) -> None:
|
|||
too_big("type")
|
||||
if len(event.event_id) > 255:
|
||||
too_big("event_id")
|
||||
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
|
||||
if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE:
|
||||
too_big("event")
|
||||
|
||||
|
||||
|
@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
|
|||
return False
|
||||
|
||||
|
||||
def get_public_keys(invite_event):
|
||||
def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
|
||||
public_keys = []
|
||||
if "public_key" in invite_event.content:
|
||||
o = {"public_key": invite_event.content["public_key"]}
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import attr
|
||||
import pymacaroons
|
||||
|
@ -68,8 +68,8 @@ logger = logging.getLogger(__name__)
|
|||
#
|
||||
# Here we have the names of the cookies, and the options we use to set them.
|
||||
_SESSION_COOKIES = [
|
||||
(b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"),
|
||||
(b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"),
|
||||
(b"oidc_session", b"HttpOnly; Secure; SameSite=None"),
|
||||
(b"oidc_session_no_samesite", b"HttpOnly"),
|
||||
]
|
||||
|
||||
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
|
||||
|
@ -279,6 +279,13 @@ class OidcProvider:
|
|||
self._config = provider
|
||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||
|
||||
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
|
||||
# We'll insert this into the Path= parameter of any session cookies we set.
|
||||
public_baseurl_path = urlparse(hs.config.server.public_baseurl).path
|
||||
self._callback_path_prefix = (
|
||||
public_baseurl_path.encode("utf-8") + b"_synapse/client/oidc"
|
||||
)
|
||||
|
||||
self._oidc_attribute_requirements = provider.attribute_requirements
|
||||
self._scopes = provider.scopes
|
||||
self._user_profile_method = provider.user_profile_method
|
||||
|
@ -779,8 +786,13 @@ class OidcProvider:
|
|||
|
||||
for cookie_name, options in _SESSION_COOKIES:
|
||||
request.cookies.append(
|
||||
b"%s=%s; Max-Age=3600; %s"
|
||||
% (cookie_name, cookie.encode("utf-8"), options)
|
||||
b"%s=%s; Max-Age=3600; Path=%s; %s"
|
||||
% (
|
||||
cookie_name,
|
||||
cookie.encode("utf-8"),
|
||||
self._callback_path_prefix,
|
||||
options,
|
||||
)
|
||||
)
|
||||
|
||||
metadata = await self.load_metadata()
|
||||
|
|
|
@ -14,13 +14,14 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.interfaces import IAddress
|
||||
from twisted.internet.interfaces import IAddress, IReactorTime
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.resource import IResource
|
||||
from twisted.web.server import Request, Site
|
||||
|
||||
from synapse.config.server import ListenerConfig
|
||||
|
@ -49,6 +50,7 @@ class SynapseRequest(Request):
|
|||
* Redaction of access_token query-params in __repr__
|
||||
* Logging at start and end
|
||||
* Metrics to record CPU, wallclock and DB time by endpoint.
|
||||
* A limit to the size of request which will be accepted
|
||||
|
||||
It also provides a method `processing`, which returns a context manager. If this
|
||||
method is called, the request won't be logged until the context manager is closed;
|
||||
|
@ -59,8 +61,9 @@ class SynapseRequest(Request):
|
|||
logcontext: the log context for this request
|
||||
"""
|
||||
|
||||
def __init__(self, channel, *args, **kw):
|
||||
def __init__(self, channel, *args, max_request_body_size=1024, **kw):
|
||||
Request.__init__(self, channel, *args, **kw)
|
||||
self._max_request_body_size = max_request_body_size
|
||||
self.site = channel.site # type: SynapseSite
|
||||
self._channel = channel # this is used by the tests
|
||||
self.start_time = 0.0
|
||||
|
@ -97,6 +100,18 @@ class SynapseRequest(Request):
|
|||
self.site.site_tag,
|
||||
)
|
||||
|
||||
def handleContentChunk(self, data):
|
||||
# we should have a `content` by now.
|
||||
assert self.content, "handleContentChunk() called before gotLength()"
|
||||
if self.content.tell() + len(data) > self._max_request_body_size:
|
||||
logger.warning(
|
||||
"Aborting connection from %s because the request exceeds maximum size",
|
||||
self.client,
|
||||
)
|
||||
self.transport.abortConnection()
|
||||
return
|
||||
super().handleContentChunk(data)
|
||||
|
||||
@property
|
||||
def requester(self) -> Optional[Union[Requester, str]]:
|
||||
return self._requester
|
||||
|
@ -485,29 +500,55 @@ class _XForwardedForAddress:
|
|||
|
||||
class SynapseSite(Site):
|
||||
"""
|
||||
Subclass of a twisted http Site that does access logging with python's
|
||||
standard logging
|
||||
Synapse-specific twisted http Site
|
||||
|
||||
This does two main things.
|
||||
|
||||
First, it replaces the requestFactory in use so that we build SynapseRequests
|
||||
instead of regular t.w.server.Requests. All of the constructor params are really
|
||||
just parameters for SynapseRequest.
|
||||
|
||||
Second, it inhibits the log() method called by Request.finish, since SynapseRequest
|
||||
does its own logging.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger_name,
|
||||
site_tag,
|
||||
logger_name: str,
|
||||
site_tag: str,
|
||||
config: ListenerConfig,
|
||||
resource,
|
||||
resource: IResource,
|
||||
server_version_string,
|
||||
*args,
|
||||
**kwargs,
|
||||
max_request_body_size: int,
|
||||
reactor: IReactorTime,
|
||||
):
|
||||
Site.__init__(self, resource, *args, **kwargs)
|
||||
"""
|
||||
|
||||
Args:
|
||||
logger_name: The name of the logger to use for access logs.
|
||||
site_tag: A tag to use for this site - mostly in access logs.
|
||||
config: Configuration for the HTTP listener corresponding to this site
|
||||
resource: The base of the resource tree to be used for serving requests on
|
||||
this site
|
||||
server_version_string: A string to present for the Server header
|
||||
max_request_body_size: Maximum request body length to allow before
|
||||
dropping the connection
|
||||
reactor: reactor to be used to manage connection timeouts
|
||||
"""
|
||||
Site.__init__(self, resource, reactor=reactor)
|
||||
|
||||
self.site_tag = site_tag
|
||||
|
||||
assert config.http_options is not None
|
||||
proxied = config.http_options.x_forwarded
|
||||
self.requestFactory = (
|
||||
XForwardedForRequest if proxied else SynapseRequest
|
||||
) # type: Type[Request]
|
||||
request_class = XForwardedForRequest if proxied else SynapseRequest
|
||||
|
||||
def request_factory(channel, queued) -> Request:
|
||||
return request_class(
|
||||
channel, max_request_body_size=max_request_body_size, queued=queued
|
||||
)
|
||||
|
||||
self.requestFactory = request_factory # type: ignore
|
||||
self.access_logger = logging.getLogger(logger_name)
|
||||
self.server_version_string = server_version_string.encode("ascii")
|
||||
|
||||
|
|
|
@ -51,8 +51,6 @@ class UploadResource(DirectServeJsonResource):
|
|||
|
||||
async def _async_render_POST(self, request: SynapseRequest) -> None:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader("Content-Length")
|
||||
if content_length is None:
|
||||
raise SynapseError(msg="Request must specify a Content-Length", code=400)
|
||||
|
|
|
@ -287,6 +287,14 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
if self.config.run_background_tasks:
|
||||
self.setup_background_tasks()
|
||||
|
||||
def start_listening(self) -> None:
|
||||
"""Start the HTTP, manhole, metrics, etc listeners
|
||||
|
||||
Does nothing in this base class; overridden in derived classes to start the
|
||||
appropriate listeners.
|
||||
"""
|
||||
pass
|
||||
|
||||
def setup_background_tasks(self) -> None:
|
||||
"""
|
||||
Some handlers have side effects on instantiation (like registering
|
||||
|
|
|
@ -17,8 +17,10 @@ from functools import wraps
|
|||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
|
@ -83,15 +85,30 @@ class _Node:
|
|||
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
|
||||
|
||||
def __init__(
|
||||
self, prev_node, next_node, key, value, callbacks: Optional[set] = None
|
||||
self,
|
||||
prev_node,
|
||||
next_node,
|
||||
key,
|
||||
value,
|
||||
callbacks: Collection[Callable[[], None]] = (),
|
||||
):
|
||||
self.prev_node = prev_node
|
||||
self.next_node = next_node
|
||||
self.key = key
|
||||
self.value = value
|
||||
self.callbacks = callbacks or set()
|
||||
|
||||
self.memory = 0
|
||||
|
||||
# Set of callbacks to run when the node gets deleted. We store as a list
|
||||
# rather than a set to keep memory usage down (and since we expect few
|
||||
# entries per node the performance of checking for duplication in a list
|
||||
# vs using a set is negligible).
|
||||
#
|
||||
# Note that we store this as an optional list to keep the memory
|
||||
# footprint down. Empty lists are 56 bytes (and empty sets are 216 bytes).
|
||||
self.callbacks = None # type: Optional[List[Callable[[], None]]]
|
||||
|
||||
self.add_callbacks(callbacks)
|
||||
|
||||
if TRACK_MEMORY_USAGE:
|
||||
self.memory = (
|
||||
_get_size_of(key)
|
||||
|
@ -101,6 +118,32 @@ class _Node:
|
|||
)
|
||||
self.memory += _get_size_of(self.memory, recurse=False)
|
||||
|
||||
def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
|
||||
"""Add to stored list of callbacks, removing duplicates."""
|
||||
|
||||
if not callbacks:
|
||||
return
|
||||
|
||||
if not self.callbacks:
|
||||
self.callbacks = []
|
||||
|
||||
for callback in callbacks:
|
||||
if callback not in self.callbacks:
|
||||
self.callbacks.append(callback)
|
||||
|
||||
def run_and_clear_callbacks(self) -> None:
|
||||
"""Run all callbacks and clear the stored set of callbacks. Used when
|
||||
the node is being deleted.
|
||||
"""
|
||||
|
||||
if not self.callbacks:
|
||||
return
|
||||
|
||||
for callback in self.callbacks:
|
||||
callback()
|
||||
|
||||
self.callbacks = None
|
||||
|
||||
|
||||
class LruCache(Generic[KT, VT]):
|
||||
"""
|
||||
|
@ -213,10 +256,10 @@ class LruCache(Generic[KT, VT]):
|
|||
|
||||
self.len = synchronized(cache_len)
|
||||
|
||||
def add_node(key, value, callbacks: Optional[set] = None):
|
||||
def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
|
||||
prev_node = list_root
|
||||
next_node = prev_node.next_node
|
||||
node = _Node(prev_node, next_node, key, value, callbacks or set())
|
||||
node = _Node(prev_node, next_node, key, value, callbacks)
|
||||
prev_node.next_node = node
|
||||
next_node.prev_node = node
|
||||
cache[key] = node
|
||||
|
@ -250,9 +293,7 @@ class LruCache(Generic[KT, VT]):
|
|||
deleted_len = size_callback(node.value)
|
||||
cached_cache_len[0] -= deleted_len
|
||||
|
||||
for cb in node.callbacks:
|
||||
cb()
|
||||
node.callbacks.clear()
|
||||
node.run_and_clear_callbacks()
|
||||
|
||||
if TRACK_MEMORY_USAGE and metrics:
|
||||
metrics.dec_memory_usage(node.memory)
|
||||
|
@ -263,7 +304,7 @@ class LruCache(Generic[KT, VT]):
|
|||
def cache_get(
|
||||
key: KT,
|
||||
default: Literal[None] = None,
|
||||
callbacks: Iterable[Callable[[], None]] = ...,
|
||||
callbacks: Collection[Callable[[], None]] = ...,
|
||||
update_metrics: bool = ...,
|
||||
) -> Optional[VT]:
|
||||
...
|
||||
|
@ -272,7 +313,7 @@ class LruCache(Generic[KT, VT]):
|
|||
def cache_get(
|
||||
key: KT,
|
||||
default: T,
|
||||
callbacks: Iterable[Callable[[], None]] = ...,
|
||||
callbacks: Collection[Callable[[], None]] = ...,
|
||||
update_metrics: bool = ...,
|
||||
) -> Union[T, VT]:
|
||||
...
|
||||
|
@ -281,13 +322,13 @@ class LruCache(Generic[KT, VT]):
|
|||
def cache_get(
|
||||
key: KT,
|
||||
default: Optional[T] = None,
|
||||
callbacks: Iterable[Callable[[], None]] = (),
|
||||
callbacks: Collection[Callable[[], None]] = (),
|
||||
update_metrics: bool = True,
|
||||
):
|
||||
node = cache.get(key, None)
|
||||
if node is not None:
|
||||
move_node_to_front(node)
|
||||
node.callbacks.update(callbacks)
|
||||
node.add_callbacks(callbacks)
|
||||
if update_metrics and metrics:
|
||||
metrics.inc_hits()
|
||||
return node.value
|
||||
|
@ -303,10 +344,8 @@ class LruCache(Generic[KT, VT]):
|
|||
# We sometimes store large objects, e.g. dicts, which cause
|
||||
# the inequality check to take a long time. So let's only do
|
||||
# the check if we have some callbacks to call.
|
||||
if node.callbacks and value != node.value:
|
||||
for cb in node.callbacks:
|
||||
cb()
|
||||
node.callbacks.clear()
|
||||
if value != node.value:
|
||||
node.run_and_clear_callbacks()
|
||||
|
||||
# We don't bother to protect this by value != node.value as
|
||||
# generally size_callback will be cheap compared with equality
|
||||
|
@ -316,7 +355,7 @@ class LruCache(Generic[KT, VT]):
|
|||
cached_cache_len[0] -= size_callback(node.value)
|
||||
cached_cache_len[0] += size_callback(value)
|
||||
|
||||
node.callbacks.update(callbacks)
|
||||
node.add_callbacks(callbacks)
|
||||
|
||||
move_node_to_front(node)
|
||||
node.value = value
|
||||
|
@ -369,8 +408,7 @@ class LruCache(Generic[KT, VT]):
|
|||
list_root.next_node = list_root
|
||||
list_root.prev_node = list_root
|
||||
for node in cache.values():
|
||||
for cb in node.callbacks:
|
||||
cb()
|
||||
node.run_and_clear_callbacks()
|
||||
cache.clear()
|
||||
if size_callback:
|
||||
cached_cache_len[0] = 0
|
||||
|
|
83
tests/http/test_site.py
Normal file
83
tests/http/test_site.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet.address import IPv6Address
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
from synapse.app.homeserver import SynapseHomeServer
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
class SynapseRequestTestCase(HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
|
||||
|
||||
def test_large_request(self):
|
||||
"""overlarge HTTP requests should be rejected"""
|
||||
self.hs.start_listening()
|
||||
|
||||
# find the HTTP server which is configured to listen on port 0
|
||||
(port, factory, _backlog, interface) = self.reactor.tcpServers[0]
|
||||
self.assertEqual(interface, "::")
|
||||
self.assertEqual(port, 0)
|
||||
|
||||
# as a control case, first send a regular request.
|
||||
|
||||
# complete the connection and wire it up to a fake transport
|
||||
client_address = IPv6Address("TCP", "::1", "2345")
|
||||
protocol = factory.buildProtocol(client_address)
|
||||
transport = StringTransport()
|
||||
protocol.makeConnection(transport)
|
||||
|
||||
protocol.dataReceived(
|
||||
b"POST / HTTP/1.1\r\n"
|
||||
b"Connection: close\r\n"
|
||||
b"Transfer-Encoding: chunked\r\n"
|
||||
b"\r\n"
|
||||
b"0\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
while not transport.disconnecting:
|
||||
self.reactor.advance(1)
|
||||
|
||||
# we should get a 404
|
||||
self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")
|
||||
|
||||
# now send an oversized request
|
||||
protocol = factory.buildProtocol(client_address)
|
||||
transport = StringTransport()
|
||||
protocol.makeConnection(transport)
|
||||
|
||||
protocol.dataReceived(
|
||||
b"POST / HTTP/1.1\r\n"
|
||||
b"Connection: close\r\n"
|
||||
b"Transfer-Encoding: chunked\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
# we deliberately send all the data in one big chunk, to ensure that
|
||||
# twisted isn't buffering the data in the chunked transfer decoder.
|
||||
# we start with the chunk size, in hex. (We won't actually send this much)
|
||||
protocol.dataReceived(b"10000000\r\n")
|
||||
sent = 0
|
||||
while not transport.disconnected:
|
||||
self.assertLess(sent, 0x10000000, "connection did not drop")
|
||||
protocol.dataReceived(b"\0" * 1024)
|
||||
sent += 1024
|
||||
|
||||
# default max upload size is 50M, so it should drop on the next buffer after
|
||||
# that.
|
||||
self.assertEqual(sent, 50 * 1024 * 1024 + 1024)
|
|
@ -12,14 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.task import LoopingCall
|
||||
from twisted.web.http import HTTPChannel
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import Request, Site
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.http.server import JsonResource
|
||||
|
@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
|
|||
ServerReplicationStreamProtocol,
|
||||
)
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport
|
||||
|
@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
client_protocol = client_factory.buildProtocol(None)
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
|
||||
channel = self.site.buildProtocol(None)
|
||||
|
||||
# hook into the channel's request factory so that we can keep a record
|
||||
# of the requests
|
||||
requests: List[SynapseRequest] = []
|
||||
real_request_factory = channel.requestFactory
|
||||
|
||||
def request_factory(*args, **kwargs):
|
||||
request = real_request_factory(*args, **kwargs)
|
||||
requests.append(request)
|
||||
return request
|
||||
|
||||
channel.requestFactory = request_factory
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
|
@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
server_to_client_transport.loseConnection()
|
||||
client_to_server_transport.loseConnection()
|
||||
|
||||
return channel.request
|
||||
# there should have been exactly one request
|
||||
self.assertEqual(len(requests), 1)
|
||||
|
||||
return requests[0]
|
||||
|
||||
def assert_request_is_get_repl_stream_updates(
|
||||
self, request: SynapseRequest, stream_name: str
|
||||
|
@ -349,6 +359,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
config=worker_hs.config.server.listeners[0],
|
||||
resource=resource,
|
||||
server_version_string="1",
|
||||
max_request_body_size=4096,
|
||||
reactor=self.reactor,
|
||||
)
|
||||
|
||||
if worker_hs.config.redis.redis_enabled:
|
||||
|
@ -386,7 +398,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||
client_protocol = client_factory.buildProtocol(None)
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
|
||||
channel = self._hs_to_site[hs].buildProtocol(None)
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
|
@ -444,112 +456,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
|
|||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
|
||||
|
||||
class _PushHTTPChannel(HTTPChannel):
|
||||
"""A HTTPChannel that wraps pull producers to push producers.
|
||||
|
||||
This is a hack to get around the fact that HTTPChannel transparently wraps a
|
||||
pull producer (which is what Synapse uses to reply to requests) with
|
||||
`_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
|
||||
uses the standard reactor rather than letting us use our test reactor, which
|
||||
makes it very hard to test.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, reactor: IReactorTime, request_factory: Type[Request], site: Site
|
||||
):
|
||||
super().__init__()
|
||||
self.reactor = reactor
|
||||
self.requestFactory = request_factory
|
||||
self.site = site
|
||||
|
||||
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
# Convert pull producers to push producer.
|
||||
if not streaming:
|
||||
self._pull_to_push_producer = _PullToPushProducer(
|
||||
self.reactor, producer, self
|
||||
)
|
||||
producer = self._pull_to_push_producer
|
||||
|
||||
super().registerProducer(producer, True)
|
||||
|
||||
def unregisterProducer(self):
|
||||
if self._pull_to_push_producer:
|
||||
# We need to manually stop the _PullToPushProducer.
|
||||
self._pull_to_push_producer.stop()
|
||||
|
||||
def checkPersistence(self, request, version):
|
||||
"""Check whether the connection can be re-used"""
|
||||
# We hijack this to always say no for ease of wiring stuff up in
|
||||
# `handle_http_replication_attempt`.
|
||||
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
|
||||
return False
|
||||
|
||||
def requestDone(self, request):
|
||||
# Store the request for inspection.
|
||||
self.request = request
|
||||
super().requestDone(request)
|
||||
|
||||
|
||||
class _PullToPushProducer:
|
||||
"""A push producer that wraps a pull producer."""
|
||||
|
||||
def __init__(
|
||||
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
|
||||
):
|
||||
self._clock = Clock(reactor)
|
||||
self._producer = producer
|
||||
self._consumer = consumer
|
||||
|
||||
# While running we use a looping call with a zero delay to call
|
||||
# resumeProducing on given producer.
|
||||
self._looping_call = None # type: Optional[LoopingCall]
|
||||
|
||||
# We start writing next reactor tick.
|
||||
self._start_loop()
|
||||
|
||||
def _start_loop(self):
|
||||
"""Start the looping call to"""
|
||||
|
||||
if not self._looping_call:
|
||||
# Start a looping call which runs every tick.
|
||||
self._looping_call = self._clock.looping_call(self._run_once, 0)
|
||||
|
||||
def stop(self):
|
||||
"""Stops calling resumeProducing."""
|
||||
if self._looping_call:
|
||||
self._looping_call.stop()
|
||||
self._looping_call = None
|
||||
|
||||
def pauseProducing(self):
|
||||
"""Implements IPushProducer"""
|
||||
self.stop()
|
||||
|
||||
def resumeProducing(self):
|
||||
"""Implements IPushProducer"""
|
||||
self._start_loop()
|
||||
|
||||
def stopProducing(self):
|
||||
"""Implements IPushProducer"""
|
||||
self.stop()
|
||||
self._producer.stopProducing()
|
||||
|
||||
def _run_once(self):
|
||||
"""Calls resumeProducing on producer once."""
|
||||
|
||||
try:
|
||||
self._producer.resumeProducing()
|
||||
except Exception:
|
||||
logger.exception("Failed to call resumeProducing")
|
||||
try:
|
||||
self._consumer.unregisterProducer()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.stopProducing()
|
||||
|
||||
|
||||
class FakeRedisPubSubServer:
|
||||
"""A fake Redis server for pub/sub."""
|
||||
|
||||
|
|
|
@ -603,12 +603,6 @@ class FakeTransport:
|
|||
if self.disconnected:
|
||||
return
|
||||
|
||||
if not hasattr(self.other, "transport"):
|
||||
# the other has no transport yet; reschedule
|
||||
if self.autoflush:
|
||||
self._reactor.callLater(0.0, self.flush)
|
||||
return
|
||||
|
||||
if maxbytes is not None:
|
||||
to_write = self.buffer[:maxbytes]
|
||||
else:
|
||||
|
|
|
@ -202,6 +202,8 @@ class OptionsResourceTests(unittest.TestCase):
|
|||
parse_listener_def({"type": "http", "port": 0}),
|
||||
self.resource,
|
||||
"1.0",
|
||||
max_request_body_size=1234,
|
||||
reactor=self.reactor,
|
||||
)
|
||||
|
||||
# render the request and return the channel
|
||||
|
|
|
@ -247,6 +247,8 @@ class HomeserverTestCase(TestCase):
|
|||
config=self.hs.config.server.listeners[0],
|
||||
resource=self.resource,
|
||||
server_version_string="1",
|
||||
max_request_body_size=1234,
|
||||
reactor=self.reactor,
|
||||
)
|
||||
|
||||
from tests.rest.client.v1.utils import RestHelper
|
||||
|
|
Loading…
Reference in a new issue