Merge branch 'erikj/reduce_size_of_cache' into erikj/merge_cache_prs

Erik Johnston 2021-04-26 16:30:42 +01:00
commit a99c692906
25 changed files with 364 additions and 248 deletions

changelog.d/9726.bugfix (new file)

@@ -0,0 +1 @@
+Fixes the OIDC SSO flow when using a `public_baseurl` value including a non-root URL path.

changelog.d/9817.misc (new file)

@@ -0,0 +1 @@
+Fix a long-standing bug which caused `max_upload_size` to not be correctly enforced.

changelog.d/9874.misc (new file)

@@ -0,0 +1 @@
+Pass a reactor into `SynapseSite` to make testing easier.

changelog.d/9876.misc (new file)

@@ -0,0 +1 @@
+Add type hints to `synapse.api.auth` and `synapse.api.auth_blocking` modules.

changelog.d/9878.misc (new file)

@@ -0,0 +1 @@
+Remove redundant `_PushHTTPChannel` test class.

synapse/api/auth.py

@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 import pymacaroons
 from netaddr import IPAddress

 from twisted.web.server import Request

-import synapse.types
 from synapse import event_auth
 from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
@@ -36,11 +35,14 @@ from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import StateMap, UserID
+from synapse.types import Requester, StateMap, UserID, create_requester
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.metrics import Measure

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
@@ -68,7 +70,7 @@ class Auth:
     The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
     """

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
@@ -88,13 +90,13 @@ class Auth:
     async def check_from_context(
         self, room_version: str, event, context, do_sig_check=True
-    ):
+    ) -> None:
         prev_state_ids = await context.get_prev_state_ids()
         auth_events_ids = self.compute_auth_events(
             event, prev_state_ids, for_verification=True
         )
-        auth_events = await self.store.get_events(auth_events_ids)
-        auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+        auth_events_by_id = await self.store.get_events(auth_events_ids)
+        auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}

         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
         event_auth.check(
@@ -151,17 +153,11 @@ class Auth:
             raise AuthError(403, "User %s not in room %s" % (user_id, room_id))

-    async def check_host_in_room(self, room_id, host):
+    async def check_host_in_room(self, room_id: str, host: str) -> bool:
         with Measure(self.clock, "check_host_in_room"):
-            latest_event_ids = await self.store.is_host_joined(room_id, host)
-            return latest_event_ids
+            return await self.store.is_host_joined(room_id, host)

-    def can_federate(self, event, auth_events):
-        creation_event = auth_events.get((EventTypes.Create, ""))
-        return creation_event.content.get("m.federate", True) is True
-
-    def get_public_keys(self, invite_event):
+    def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
         return event_auth.get_public_keys(invite_event)

     async def get_user_by_req(
@@ -170,7 +166,7 @@ class Auth:
         allow_guest: bool = False,
         rights: str = "access",
         allow_expired: bool = False,
-    ) -> synapse.types.Requester:
+    ) -> Requester:
         """Get a registered user's ID.

         Args:
@@ -196,7 +192,7 @@ class Auth:
             access_token = self.get_access_token_from_request(request)

             user_id, app_service = await self._get_appservice_user_id(request)
-            if user_id:
+            if user_id and app_service:
                 if ip_addr and self._track_appservice_user_ips:
                     await self.store.insert_client_ip(
                         user_id=user_id,
@@ -206,9 +202,7 @@ class Auth:
                         device_id="dummy-device",  # stubbed
                     )

-                requester = synapse.types.create_requester(
-                    user_id, app_service=app_service
-                )
+                requester = create_requester(user_id, app_service=app_service)

                 request.requester = user_id
                 opentracing.set_tag("authenticated_entity", user_id)
@@ -251,7 +245,7 @@ class Auth:
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )

-            requester = synapse.types.create_requester(
+            requester = create_requester(
                 user_info.user_id,
                 token_id,
                 is_guest,
@@ -271,7 +265,9 @@ class Auth:
         except KeyError:
             raise MissingClientTokenError()

-    async def _get_appservice_user_id(self, request):
+    async def _get_appservice_user_id(
+        self, request: Request
+    ) -> Tuple[Optional[str], Optional[ApplicationService]]:
         app_service = self.store.get_app_service_by_token(
             self.get_access_token_from_request(request)
         )
@@ -283,6 +279,9 @@ class Auth:
             if ip_address not in app_service.ip_range_whitelist:
                 return None, None

+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
         if b"user_id" not in request.args:
             return app_service.sender, app_service
@@ -387,7 +386,9 @@ class Auth:
             logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
             raise InvalidClientTokenError("Invalid macaroon passed.")

-    def _parse_and_validate_macaroon(self, token, rights="access"):
+    def _parse_and_validate_macaroon(
+        self, token: str, rights: str = "access"
+    ) -> Tuple[str, bool]:
         """Takes a macaroon and tries to parse and validate it. This is cached
         if and only if rights == access and there isn't an expiry.
@@ -432,15 +433,16 @@ class Auth:
         return user_id, guest

-    def validate_macaroon(self, macaroon, type_string, user_id):
+    def validate_macaroon(
+        self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
+    ) -> None:
         """
         validate that a Macaroon is understood by and was signed by this server.

         Args:
-            macaroon(pymacaroons.Macaroon): The macaroon to validate
-            type_string(str): The kind of token required (e.g. "access",
-                "delete_pusher")
-            user_id (str): The user_id required
+            macaroon: The macaroon to validate
+            type_string: The kind of token required (e.g. "access", "delete_pusher")
+            user_id: The user_id required
         """
         v = pymacaroons.Verifier()
@@ -465,9 +467,7 @@ class Auth:
         if not service:
             logger.warning("Unrecognised appservice access token.")
             raise InvalidClientTokenError()
-        request.requester = synapse.types.create_requester(
-            service.sender, app_service=service
-        )
+        request.requester = create_requester(service.sender, app_service=service)
         return service

     async def is_server_admin(self, user: UserID) -> bool:
@@ -519,7 +519,7 @@ class Auth:
         return auth_ids

-    async def check_can_change_room_list(self, room_id: str, user: UserID):
+    async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
         """Determine whether the user is allowed to edit the room's entry in the
         published room list.
@@ -554,11 +554,11 @@ class Auth:
         return user_level >= send_level

     @staticmethod
-    def has_access_token(request: Request):
+    def has_access_token(request: Request) -> bool:
         """Checks if the request has an access_token.

         Returns:
-            bool: False if no access_token was given, True otherwise.
+            False if no access_token was given, True otherwise.
         """
         # This will always be set by the time Twisted calls us.
         assert request.args is not None
@@ -568,13 +568,13 @@ class Auth:
         return bool(query_params) or bool(auth_headers)

     @staticmethod
-    def get_access_token_from_request(request: Request):
+    def get_access_token_from_request(request: Request) -> str:
         """Extracts the access_token from the request.

         Args:
             request: The http request.
         Returns:
-            unicode: The access_token
+            The access_token
         Raises:
             MissingClientTokenError: If there isn't a single access_token in the
                 request
@@ -649,5 +649,5 @@ class Auth:
                 % (user_id, room_id),
             )

-    def check_auth_blocking(self, *args, **kwargs):
-        return self._auth_blocking.check_auth_blocking(*args, **kwargs)
+    async def check_auth_blocking(self, *args, **kwargs) -> None:
+        await self._auth_blocking.check_auth_blocking(*args, **kwargs)

synapse/api/auth_blocking.py

@@ -13,18 +13,21 @@
 # limitations under the License.

 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 from synapse.api.constants import LimitBlockingTypes, UserTypes
 from synapse.api.errors import Codes, ResourceLimitError
 from synapse.config.server import is_threepid_reserved
 from synapse.types import Requester

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)


 class AuthBlocking:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()

         self._server_notices_mxid = hs.config.server_notices_mxid
@@ -43,7 +46,7 @@ class AuthBlocking:
         threepid: Optional[dict] = None,
         user_type: Optional[str] = None,
         requester: Optional[Requester] = None,
-    ):
+    ) -> None:
         """Checks if the user should be rejected for some external reason,
         such as monthly active user limiting or global disable flag

synapse/api/constants.py

@@ -17,6 +17,9 @@

 """Contains constants from the specification."""

+# the max size of a (canonical-json-encoded) event
+MAX_PDU_SIZE = 65536
+
 # the "depth" field on events is limited to 2**63 - 1
 MAX_DEPTH = 2 ** 63 - 1

synapse/app/_base.py

@@ -30,9 +30,10 @@ from twisted.internet import defer, error, reactor
 from twisted.protocols.tls import TLSMemoryBIOFactory

 import synapse
+from synapse.api.constants import MAX_PDU_SIZE
 from synapse.app import check_bind_error
 from synapse.app.phone_stats_home import start_phone_stats_home
-from synapse.config.server import ListenerConfig
+from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
 from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -288,7 +289,7 @@ def refresh_certificate(hs):
         logger.info("Context factories updated.")


-async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
+async def start(hs: "synapse.server.HomeServer"):
     """
     Start a Synapse server or worker.
@@ -300,7 +301,6 @@ async def start(hs: "synapse.server.HomeServer"):
     Args:
         hs: homeserver instance
-        listeners: Listener configuration ('listeners' in homeserver.yaml)
     """
     # Set up the SIGHUP machinery.
     if hasattr(signal, "SIGHUP"):
@@ -336,7 +336,7 @@ async def start(hs: "synapse.server.HomeServer"):
         synapse.logging.opentracing.init_tracer(hs)  # type: ignore[attr-defined] # noqa

     # It is now safe to start your Synapse.
-    hs.start_listening(listeners)
+    hs.start_listening()
     hs.get_datastore().db_pool.start_profiling()
     hs.get_pusherpool().start()
@@ -530,3 +530,25 @@ def sdnotify(state):
         # this is a bit surprising, since we don't expect to have a NOTIFY_SOCKET
         # unless systemd is expecting us to notify it.
         logger.warning("Unable to send notification to systemd: %s", e)
+
+
+def max_request_body_size(config: HomeServerConfig) -> int:
+    """Get a suitable maximum size for incoming HTTP requests"""
+
+    # Other than media uploads, the biggest request we expect to see is a fully-loaded
+    # /federation/v1/send request.
+    #
+    # The main thing in such a request is up to 50 PDUs, and up to 100 EDUs. PDUs are
+    # limited to 65536 bytes (possibly slightly more if the sender didn't use canonical
+    # json encoding); there is no specced limit to EDUs (see
+    # https://github.com/matrix-org/matrix-doc/issues/3121).
+    #
+    # in short, we somewhat arbitrarily limit requests to 200 * 64K (about 12.5M)
+    #
+    max_request_size = 200 * MAX_PDU_SIZE
+
+    # if we have a media repo enabled, we may need to allow larger uploads than that
+    if config.media.can_load_media_repo:
+        max_request_size = max(max_request_size, config.media.max_upload_size)
+
+    return max_request_size
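For reference, the arithmetic behind the new limit works out as below. The values come straight from this diff; the 50M figure is Synapse's default `max_upload_size`, per the comment in the new test further down.

# Reproducing the limit computed by max_request_body_size() above.
MAX_PDU_SIZE = 65536                       # added to synapse/api/constants.py
max_request_size = 200 * MAX_PDU_SIZE      # 13107200 bytes
print(max_request_size / (1024 * 1024))    # 12.5 MiB, matching "about 12.5M"

# With the media repo enabled, the limit grows to max_upload_size when that
# is larger; with the default 50M upload limit:
print(max(max_request_size, 50 * 1024 * 1024))  # 52428800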

synapse/app/admin_cmd.py

@@ -70,12 +70,6 @@ class AdminCmdSlavedStore(
 class AdminCmdServer(HomeServer):
     DATASTORE_CLASS = AdminCmdSlavedStore

-    def _listen_http(self, listener_config):
-        pass
-
-    def start_listening(self, listeners):
-        pass
-

 async def export_data_command(hs, args):
     """Export data for a user.
@@ -232,7 +226,7 @@ def start(config_options):

     async def run():
         with LoggingContext("command"):
-            _base.start(ss, [])
+            _base.start(ss)
             await args.func(ss, args)

     _base.start_worker_reactor(

synapse/app/generic_worker.py

@@ -15,7 +15,7 @@
 # limitations under the License.
 import logging
 import sys
-from typing import Dict, Iterable, Optional
+from typing import Dict, Optional

 from twisted.internet import address
 from twisted.web.resource import IResource
@@ -32,7 +32,7 @@ from synapse.api.urls import (
     SERVER_KEY_V2_PREFIX,
 )
 from synapse.app import _base
-from synapse.app._base import register_start
+from synapse.app._base import max_request_body_size, register_start
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
@@ -367,14 +367,16 @@ class GenericWorkerServer(HomeServer):
                     listener_config,
                     root_resource,
                     self.version_string,
+                    max_request_body_size=max_request_body_size(self.config),
+                    reactor=self.get_reactor(),
                 ),
                 reactor=self.get_reactor(),
             )

         logger.info("Synapse worker now listening on port %d", port)

-    def start_listening(self, listeners: Iterable[ListenerConfig]):
-        for listener in listeners:
+    def start_listening(self):
+        for listener in self.config.worker_listeners:
             if listener.type == "http":
                 self._listen_http(listener)
             elif listener.type == "manhole":
@@ -468,7 +470,7 @@ def start(config_options):
     # streams. Will no-op if no streams can be written to by this worker.
     hs.get_replication_streamer()

-    register_start(_base.start, hs, config.worker_listeners)
+    register_start(_base.start, hs)

     _base.start_worker_reactor("synapse-generic-worker", config)

synapse/app/homeserver.py

@@ -17,7 +17,7 @@
 import logging
 import os
 import sys
-from typing import Iterable, Iterator
+from typing import Iterator

 from twisted.internet import reactor
 from twisted.web.resource import EncodingResourceWrapper, IResource
@@ -36,7 +36,13 @@ from synapse.api.urls import (
     WEB_CLIENT_PREFIX,
 )
 from synapse.app import _base
-from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
+from synapse.app._base import (
+    listen_ssl,
+    listen_tcp,
+    max_request_body_size,
+    quit_with_error,
+    register_start,
+)
 from synapse.config._base import ConfigError
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
@@ -126,19 +132,21 @@ class SynapseHomeServer(HomeServer):
         else:
             root_resource = OptionsResource()

-        root_resource = create_resource_tree(resources, root_resource)
+        site = SynapseSite(
+            "synapse.access.%s.%s" % ("https" if tls else "http", site_tag),
+            site_tag,
+            listener_config,
+            create_resource_tree(resources, root_resource),
+            self.version_string,
+            max_request_body_size=max_request_body_size(self.config),
+            reactor=self.get_reactor(),
+        )

         if tls:
             ports = listen_ssl(
                 bind_addresses,
                 port,
-                SynapseSite(
-                    "synapse.access.https.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                    self.version_string,
-                ),
+                site,
                 self.tls_server_context_factory,
                 reactor=self.get_reactor(),
             )
@@ -148,13 +156,7 @@ class SynapseHomeServer(HomeServer):
             ports = listen_tcp(
                 bind_addresses,
                 port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                    self.version_string,
-                ),
+                site,
                 reactor=self.get_reactor(),
             )
             logger.info("Synapse now listening on TCP port %d", port)
@@ -273,14 +275,14 @@ class SynapseHomeServer(HomeServer):

         return resources

-    def start_listening(self, listeners: Iterable[ListenerConfig]):
+    def start_listening(self):
         if self.config.redis_enabled:
             # If redis is enabled we connect via the replication command handler
             # in the same way as the workers (since we're effectively a client
             # rather than a server).
             self.get_tcp_replication().start_replication(self)

-        for listener in listeners:
+        for listener in self.config.server.listeners:
             if listener.type == "http":
                 self._listening_services.extend(
                     self._listener_http(self.config, listener)
@@ -413,7 +415,7 @@ def setup(config_options):
             # Loading the provider metadata also ensures the provider config is valid.
             await oidc.load_metadata()

-        await _base.start(hs, config.listeners)
+        await _base.start(hs)

         hs.get_datastore().db_pool.updates.start_doing_background_updates()

synapse/config/logger.py

@@ -31,7 +31,6 @@ from twisted.logger import (
 )

 import synapse
-from synapse.app import _base as appbase
 from synapse.logging._structured import setup_structured_logging
 from synapse.logging.context import LoggingContextFilter
 from synapse.logging.filter import MetadataFilter
@@ -318,6 +317,8 @@ def setup_logging(
     # Perform one-time logging configuration.
     _setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner)
     # Add a SIGHUP handler to reload the logging configuration, if one is available.
+    from synapse.app import _base as appbase
+
    appbase.register_sighup(_reload_logging_config, log_config_path)

     # Log immediately so we can grep backwards.
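Moving this import from module level into `setup_logging` looks like an import-cycle fix: this same commit makes `synapse.app._base` import `HomeServerConfig` from the config package (see the `synapse/app/_base.py` hunk above), so a module-level import of `_base` from inside `synapse.config.logger` would close a config → app → config loop. That reading is an inference from the diff, not stated in it. A runnable illustration of the deferred-import pattern:

# The (potentially cyclic or heavy) dependency is imported only when the
# function actually runs, by which time both modules are fully initialised.
def setup_sighup_handlers():
    import signal  # imported lazily, as synapse.config.logger now imports _base

    return hasattr(signal, "SIGHUP")

print(setup_sighup_handlers())  # True on Unix, False on Windows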

synapse/config/server.py

@@ -235,7 +235,11 @@ class ServerConfig(Config):
         self.print_pidfile = config.get("print_pidfile")
         self.user_agent_suffix = config.get("user_agent_suffix")
         self.use_frozen_dicts = config.get("use_frozen_dicts", False)
+
         self.public_baseurl = config.get("public_baseurl")
+        if self.public_baseurl is not None:
+            if self.public_baseurl[-1] != "/":
+                self.public_baseurl += "/"

         # Whether to enable user presence.
         presence_config = config.get("presence") or {}
@@ -407,10 +411,6 @@ class ServerConfig(Config):
             config_path=("federation_ip_range_blacklist",),
         )

-        if self.public_baseurl is not None:
-            if self.public_baseurl[-1] != "/":
-                self.public_baseurl += "/"
-
         # (undocumented) option for torturing the worker-mode replication a bit,
         # for testing. The value defines the number of milliseconds to pause before
         # sending out any replication updates.
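Hoisting the trailing-slash normalization to the point where `public_baseurl` is first read means every later consumer, including the new OIDC cookie-path code in this commit, can rely on the value ending in `/`. The effect, with sample values assumed for illustration:

# Sample values only; both end up identical after normalization.
for url in ("https://example.com/matrix", "https://example.com/matrix/"):
    public_baseurl = url
    if public_baseurl[-1] != "/":
        public_baseurl += "/"
    print(public_baseurl)  # "https://example.com/matrix/" in both cases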

synapse/event_auth.py

@@ -14,14 +14,14 @@
 # limitations under the License.

 import logging
-from typing import List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple

 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import SignatureVerifyException, verify_signed_json
 from unpaddedbase64 import decode_base64

-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import MAX_PDU_SIZE, EventTypes, JoinRules, Membership
 from synapse.api.errors import AuthError, EventSizeError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
@@ -205,7 +205,7 @@ def _check_size_limits(event: EventBase) -> None:
         too_big("type")
     if len(event.event_id) > 255:
         too_big("event_id")
-    if len(encode_canonical_json(event.get_pdu_json())) > 65536:
+    if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE:
         too_big("event")
@@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
         return False


-def get_public_keys(invite_event):
+def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
     public_keys = []
     if "public_key" in invite_event.content:
         o = {"public_key": invite_event.content["public_key"]}

synapse/handlers/oidc.py

@@ -15,7 +15,7 @@
 import inspect
 import logging
 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
-from urllib.parse import urlencode
+from urllib.parse import urlencode, urlparse

 import attr
 import pymacaroons
@@ -68,8 +68,8 @@ logger = logging.getLogger(__name__)
 #
 # Here we have the names of the cookies, and the options we use to set them.
 _SESSION_COOKIES = [
-    (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"),
-    (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"),
+    (b"oidc_session", b"HttpOnly; Secure; SameSite=None"),
+    (b"oidc_session_no_samesite", b"HttpOnly"),
 ]

 #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
@@ -279,6 +279,13 @@ class OidcProvider:
         self._config = provider
         self._callback_url = hs.config.oidc_callback_url  # type: str

+        # Calculate the prefix for OIDC callback paths based on the public_baseurl.
+        # We'll insert this into the Path= parameter of any session cookies we set.
+        public_baseurl_path = urlparse(hs.config.server.public_baseurl).path
+        self._callback_path_prefix = (
+            public_baseurl_path.encode("utf-8") + b"_synapse/client/oidc"
+        )
+
         self._oidc_attribute_requirements = provider.attribute_requirements
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
@@ -779,8 +786,13 @@ class OidcProvider:

         for cookie_name, options in _SESSION_COOKIES:
             request.cookies.append(
-                b"%s=%s; Max-Age=3600; %s"
-                % (cookie_name, cookie.encode("utf-8"), options)
+                b"%s=%s; Max-Age=3600; Path=%s; %s"
+                % (
+                    cookie_name,
+                    cookie.encode("utf-8"),
+                    self._callback_path_prefix,
+                    options,
+                )
             )

         metadata = await self.load_metadata()
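Putting the two halves of the #9726 fix together: because `public_baseurl` is now normalized to end in `/` as soon as it is read (previous file), the cookie path computed in `OidcProvider.__init__` lands on the right prefix even for a non-root deployment. A sketch with an assumed baseurl:

# Assumed example baseurl; any value ending in "/" behaves the same way.
from urllib.parse import urlparse

public_baseurl = "https://example.com/matrix/"
public_baseurl_path = urlparse(public_baseurl).path  # "/matrix/"
callback_path_prefix = public_baseurl_path.encode("utf-8") + b"_synapse/client/oidc"
print(callback_path_prefix)  # b'/matrix/_synapse/client/oidc'
# ...which ends up in the session cookie as "Path=/matrix/_synapse/client/oidc"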

synapse/http/site.py

@@ -14,13 +14,14 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Tuple, Type, Union
+from typing import Optional, Tuple, Union

 import attr
 from zope.interface import implementer

-from twisted.internet.interfaces import IAddress
+from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
+from twisted.web.resource import IResource
 from twisted.web.server import Request, Site

 from synapse.config.server import ListenerConfig
@@ -49,6 +50,7 @@ class SynapseRequest(Request):
     * Redaction of access_token query-params in __repr__
     * Logging at start and end
     * Metrics to record CPU, wallclock and DB time by endpoint.
+    * A limit to the size of request which will be accepted

     It also provides a method `processing`, which returns a context manager. If this
     method is called, the request won't be logged until the context manager is closed;
@@ -59,8 +61,9 @@ class SynapseRequest(Request):
         logcontext: the log context for this request
     """

-    def __init__(self, channel, *args, **kw):
+    def __init__(self, channel, *args, max_request_body_size=1024, **kw):
         Request.__init__(self, channel, *args, **kw)
+        self._max_request_body_size = max_request_body_size
         self.site = channel.site  # type: SynapseSite
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
@@ -97,6 +100,18 @@ class SynapseRequest(Request):
             self.site.site_tag,
         )

+    def handleContentChunk(self, data):
+        # we should have a `content` by now.
+        assert self.content, "handleContentChunk() called before gotLength()"
+        if self.content.tell() + len(data) > self._max_request_body_size:
+            logger.warning(
+                "Aborting connection from %s because the request exceeds maximum size",
+                self.client,
+            )
+            self.transport.abortConnection()
+            return
+        super().handleContentChunk(data)
+
     @property
     def requester(self) -> Optional[Union[Requester, str]]:
         return self._requester
@@ -485,29 +500,55 @@ class _XForwardedForAddress:
 class SynapseSite(Site):
     """
-    Subclass of a twisted http Site that does access logging with python's
-    standard logging
+    Synapse-specific twisted http Site
+
+    This does two main things.
+
+    First, it replaces the requestFactory in use so that we build SynapseRequests
+    instead of regular t.w.server.Requests. All of the constructor params are really
+    just parameters for SynapseRequest.
+
+    Second, it inhibits the log() method called by Request.finish, since SynapseRequest
+    does its own logging.
     """

     def __init__(
         self,
-        logger_name,
-        site_tag,
+        logger_name: str,
+        site_tag: str,
         config: ListenerConfig,
-        resource,
+        resource: IResource,
         server_version_string,
-        *args,
-        **kwargs,
+        max_request_body_size: int,
+        reactor: IReactorTime,
     ):
-        Site.__init__(self, resource, *args, **kwargs)
+        """
+        Args:
+            logger_name: The name of the logger to use for access logs.
+            site_tag: A tag to use for this site - mostly in access logs.
+            config: Configuration for the HTTP listener corresponding to this site
+            resource: The base of the resource tree to be used for serving requests on
+                this site
+            server_version_string: A string to present for the Server header
+            max_request_body_size: Maximum request body length to allow before
+                dropping the connection
+            reactor: reactor to be used to manage connection timeouts
+        """
+        Site.__init__(self, resource, reactor=reactor)
+
         self.site_tag = site_tag

         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
-        self.requestFactory = (
-            XForwardedForRequest if proxied else SynapseRequest
-        )  # type: Type[Request]
+        request_class = XForwardedForRequest if proxied else SynapseRequest
+
+        def request_factory(channel, queued) -> Request:
+            return request_class(
+                channel, max_request_body_size=max_request_body_size, queued=queued
+            )
+
+        self.requestFactory = request_factory  # type: ignore
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")
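The enforcement in `SynapseRequest.handleContentChunk` boils down to "buffered + incoming > limit, so abort". A self-contained sketch of that logic (the `BodyGuard` class and `abort` callback are illustrative stand-ins, not Synapse APIs):

import io

class BodyGuard:
    """Buffers body chunks, aborting once a size limit would be exceeded."""

    def __init__(self, max_size, abort):
        self.content = io.BytesIO()  # stands in for Request.content
        self.max_size = max_size
        self.abort = abort           # stands in for transport.abortConnection()

    def handle_chunk(self, data):
        if self.content.tell() + len(data) > self.max_size:
            self.abort()
            return
        self.content.write(data)

guard = BodyGuard(10, abort=lambda: print("aborting connection"))
guard.handle_chunk(b"x" * 8)  # within the limit: buffered
guard.handle_chunk(b"x" * 8)  # 8 + 8 > 10: aborts instead of buffering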

synapse/rest/media/v1/upload_resource.py

@@ -51,8 +51,6 @@ class UploadResource(DirectServeJsonResource):
     async def _async_render_POST(self, request: SynapseRequest) -> None:
         requester = await self.auth.get_user_by_req(request)

-        # TODO: The checks here are a bit late. The content will have
-        # already been uploaded to a tmp file at this point
         content_length = request.getHeader("Content-Length")
         if content_length is None:
             raise SynapseError(msg="Request must specify a Content-Length", code=400)

synapse/server.py

@@ -287,6 +287,14 @@ class HomeServer(metaclass=abc.ABCMeta):
         if self.config.run_background_tasks:
             self.setup_background_tasks()

+    def start_listening(self) -> None:
+        """Start the HTTP, manhole, metrics, etc listeners
+
+        Does nothing in this base class; overridden in derived classes to start the
+        appropriate listeners.
+        """
+        pass
+
     def setup_background_tasks(self) -> None:
         """
         Some handlers have side effects on instantiation (like registering

synapse/util/caches/lrucache.py

@@ -17,8 +17,10 @@ from functools import wraps
 from typing import (
     Any,
     Callable,
+    Collection,
     Generic,
     Iterable,
+    List,
     Optional,
     Type,
     TypeVar,
@@ -83,15 +85,30 @@ class _Node:
     __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]

     def __init__(
-        self, prev_node, next_node, key, value, callbacks: Optional[set] = None
+        self,
+        prev_node,
+        next_node,
+        key,
+        value,
+        callbacks: Collection[Callable[[], None]] = (),
     ):
         self.prev_node = prev_node
         self.next_node = next_node
         self.key = key
         self.value = value
-        self.callbacks = callbacks or set()
+
         self.memory = 0

+        # Set of callbacks to run when the node gets deleted. We store as a list
+        # rather than a set to keep memory usage down (and since we expect few
+        # entries per node the performance of checking for duplication in a list
+        # vs using a set is negligible).
+        #
+        # Note that we store this as an optional list to keep the memory
+        # footprint down. Empty lists are 56 bytes (and empty sets are 216 bytes).
+        self.callbacks = None  # type: Optional[List[Callable[[], None]]]
+
+        self.add_callbacks(callbacks)
+
         if TRACK_MEMORY_USAGE:
             self.memory = (
                 _get_size_of(key)
@@ -101,6 +118,32 @@ class _Node:
             )
             self.memory += _get_size_of(self.memory, recurse=False)

+    def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
+        """Add to stored list of callbacks, removing duplicates."""
+
+        if not callbacks:
+            return
+
+        if not self.callbacks:
+            self.callbacks = []
+
+        for callback in callbacks:
+            if callback not in self.callbacks:
+                self.callbacks.append(callback)
+
+    def run_and_clear_callbacks(self) -> None:
+        """Run all callbacks and clear the stored set of callbacks. Used when
+        the node is being deleted.
+        """
+
+        if not self.callbacks:
+            return
+
+        for callback in self.callbacks:
+            callback()
+
+        self.callbacks = None
+

 class LruCache(Generic[KT, VT]):
     """
@@ -213,10 +256,10 @@ class LruCache(Generic[KT, VT]):

         self.len = synchronized(cache_len)

-        def add_node(key, value, callbacks: Optional[set] = None):
+        def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
             prev_node = list_root
             next_node = prev_node.next_node
-            node = _Node(prev_node, next_node, key, value, callbacks or set())
+            node = _Node(prev_node, next_node, key, value, callbacks)
             prev_node.next_node = node
             next_node.prev_node = node
             cache[key] = node
@@ -250,9 +293,7 @@ class LruCache(Generic[KT, VT]):
                 deleted_len = size_callback(node.value)
                 cached_cache_len[0] -= deleted_len

-            for cb in node.callbacks:
-                cb()
-            node.callbacks.clear()
+            node.run_and_clear_callbacks()

             if TRACK_MEMORY_USAGE and metrics:
                 metrics.dec_memory_usage(node.memory)
@@ -263,7 +304,7 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: Literal[None] = None,
-            callbacks: Iterable[Callable[[], None]] = ...,
+            callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
         ) -> Optional[VT]:
             ...
@@ -272,7 +313,7 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: T,
-            callbacks: Iterable[Callable[[], None]] = ...,
+            callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
         ) -> Union[T, VT]:
             ...
@@ -281,13 +322,13 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: Optional[T] = None,
-            callbacks: Iterable[Callable[[], None]] = (),
+            callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
         ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
-                node.callbacks.update(callbacks)
+                node.add_callbacks(callbacks)
                 if update_metrics and metrics:
                     metrics.inc_hits()
                 return node.value
@@ -303,10 +344,8 @@ class LruCache(Generic[KT, VT]):
                 # We sometimes store large objects, e.g. dicts, which cause
                 # the inequality check to take a long time. So let's only do
                 # the check if we have some callbacks to call.
-                if node.callbacks and value != node.value:
-                    for cb in node.callbacks:
-                        cb()
-                    node.callbacks.clear()
+                if value != node.value:
+                    node.run_and_clear_callbacks()

                 # We don't bother to protect this by value != node.value as
                 # generally size_callback will be cheap compared with equality
@@ -316,7 +355,7 @@ class LruCache(Generic[KT, VT]):
                     cached_cache_len[0] -= size_callback(node.value)
                     cached_cache_len[0] += size_callback(value)

-                node.callbacks.update(callbacks)
+                node.add_callbacks(callbacks)

                 move_node_to_front(node)
                 node.value = value
@@ -369,8 +408,7 @@ class LruCache(Generic[KT, VT]):
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
-                for cb in node.callbacks:
-                    cb()
+                node.run_and_clear_callbacks()
             cache.clear()
             if size_callback:
                 cached_cache_len[0] = 0
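The 56-vs-216-byte figures in the `_Node` comment above are easy to verify; on a 64-bit CPython build (exact numbers vary slightly by version):

import sys

print(sys.getsizeof([]))     # 56  - an empty list
print(sys.getsizeof(set()))  # 216 - an empty set
print(sys.getsizeof(None))   # 16  - and None is a shared singleton anyway

# So nodes with no callbacks now hold None (no per-node allocation at all),
# and nodes with callbacks hold a list rather than a set - a saving of
# roughly 160 bytes per populated cache entry.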

tests/http/test_site.py (new file)

@@ -0,0 +1,83 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet.address import IPv6Address
+from twisted.test.proto_helpers import StringTransport
+
+from synapse.app.homeserver import SynapseHomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class SynapseRequestTestCase(HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
+
+    def test_large_request(self):
+        """overlarge HTTP requests should be rejected"""
+        self.hs.start_listening()
+
+        # find the HTTP server which is configured to listen on port 0
+        (port, factory, _backlog, interface) = self.reactor.tcpServers[0]
+        self.assertEqual(interface, "::")
+        self.assertEqual(port, 0)
+
+        # as a control case, first send a regular request.
+
+        # complete the connection and wire it up to a fake transport
+        client_address = IPv6Address("TCP", "::1", "2345")
+        protocol = factory.buildProtocol(client_address)
+        transport = StringTransport()
+        protocol.makeConnection(transport)
+
+        protocol.dataReceived(
+            b"POST / HTTP/1.1\r\n"
+            b"Connection: close\r\n"
+            b"Transfer-Encoding: chunked\r\n"
+            b"\r\n"
+            b"0\r\n"
+            b"\r\n"
+        )
+
+        while not transport.disconnecting:
+            self.reactor.advance(1)
+
+        # we should get a 404
+        self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")
+
+        # now send an oversized request
+        protocol = factory.buildProtocol(client_address)
+        transport = StringTransport()
+        protocol.makeConnection(transport)
+
+        protocol.dataReceived(
+            b"POST / HTTP/1.1\r\n"
+            b"Connection: close\r\n"
+            b"Transfer-Encoding: chunked\r\n"
+            b"\r\n"
+        )
+
+        # we deliberately send all the data in one big chunk, to ensure that
+        # twisted isn't buffering the data in the chunked transfer decoder.
+        # we start with the chunk size, in hex. (We won't actually send this much)
+        protocol.dataReceived(b"10000000\r\n")
+        sent = 0
+        while not transport.disconnected:
+            self.assertLess(sent, 0x10000000, "connection did not drop")
+            protocol.dataReceived(b"\0" * 1024)
+            sent += 1024
+
+        # default max upload size is 50M, so it should drop on the next buffer after
+        # that.
+        self.assertEqual(sent, 50 * 1024 * 1024 + 1024)

tests/replication/_base.py

@@ -12,14 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple

-from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
-from twisted.internet.task import LoopingCall
-from twisted.web.http import HTTPChannel
 from twisted.web.resource import Resource
-from twisted.web.server import Request, Site

 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.http.server import JsonResource
@@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
     ServerReplicationStreamProtocol,
 )
 from synapse.server import HomeServer
-from synapse.util import Clock

 from tests import unittest
 from tests.server import FakeTransport
@@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         client_protocol = client_factory.buildProtocol(None)

         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
+        channel = self.site.buildProtocol(None)
+
+        # hook into the channel's request factory so that we can keep a record
+        # of the requests
+        requests: List[SynapseRequest] = []
+        real_request_factory = channel.requestFactory
+
+        def request_factory(*args, **kwargs):
+            request = real_request_factory(*args, **kwargs)
+            requests.append(request)
+            return request
+
+        channel.requestFactory = request_factory

         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         server_to_client_transport.loseConnection()
         client_to_server_transport.loseConnection()

-        return channel.request
+        # there should have been exactly one request
+        self.assertEqual(len(requests), 1)
+
+        return requests[0]

     def assert_request_is_get_repl_stream_updates(
         self, request: SynapseRequest, stream_name: str
@@ -349,6 +359,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             config=worker_hs.config.server.listeners[0],
             resource=resource,
             server_version_string="1",
+            max_request_body_size=4096,
+            reactor=self.reactor,
         )

         if worker_hs.config.redis.redis_enabled:
@@ -386,7 +398,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         client_protocol = client_factory.buildProtocol(None)

         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
+        channel = self._hs_to_site[hs].buildProtocol(None)

         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -444,112 +456,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
             self.received_rdata_rows.append((stream_name, token, r))


-class _PushHTTPChannel(HTTPChannel):
-    """A HTTPChannel that wraps pull producers to push producers.
-
-    This is a hack to get around the fact that HTTPChannel transparently wraps a
-    pull producer (which is what Synapse uses to reply to requests) with
-    `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
-    uses the standard reactor rather than letting us use our test reactor, which
-    makes it very hard to test.
-    """
-
-    def __init__(
-        self, reactor: IReactorTime, request_factory: Type[Request], site: Site
-    ):
-        super().__init__()
-        self.reactor = reactor
-        self.requestFactory = request_factory
-        self.site = site
-
-        self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]
-
-    def registerProducer(self, producer, streaming):
-        # Convert pull producers to push producer.
-        if not streaming:
-            self._pull_to_push_producer = _PullToPushProducer(
-                self.reactor, producer, self
-            )
-            producer = self._pull_to_push_producer
-
-        super().registerProducer(producer, True)
-
-    def unregisterProducer(self):
-        if self._pull_to_push_producer:
-            # We need to manually stop the _PullToPushProducer.
-            self._pull_to_push_producer.stop()
-
-    def checkPersistence(self, request, version):
-        """Check whether the connection can be re-used"""
-        # We hijack this to always say no for ease of wiring stuff up in
-        # `handle_http_replication_attempt`.
-        request.responseHeaders.setRawHeaders(b"connection", [b"close"])
-        return False
-
-    def requestDone(self, request):
-        # Store the request for inspection.
-        self.request = request
-        super().requestDone(request)
-
-
-class _PullToPushProducer:
-    """A push producer that wraps a pull producer."""
-
-    def __init__(
-        self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
-    ):
-        self._clock = Clock(reactor)
-        self._producer = producer
-        self._consumer = consumer
-
-        # While running we use a looping call with a zero delay to call
-        # resumeProducing on given producer.
-        self._looping_call = None  # type: Optional[LoopingCall]
-
-        # We start writing next reactor tick.
-        self._start_loop()
-
-    def _start_loop(self):
-        """Start the looping call to"""
-        if not self._looping_call:
-            # Start a looping call which runs every tick.
-            self._looping_call = self._clock.looping_call(self._run_once, 0)
-
-    def stop(self):
-        """Stops calling resumeProducing."""
-        if self._looping_call:
-            self._looping_call.stop()
-            self._looping_call = None
-
-    def pauseProducing(self):
-        """Implements IPushProducer"""
-        self.stop()
-
-    def resumeProducing(self):
-        """Implements IPushProducer"""
-        self._start_loop()
-
-    def stopProducing(self):
-        """Implements IPushProducer"""
-        self.stop()
-        self._producer.stopProducing()
-
-    def _run_once(self):
-        """Calls resumeProducing on producer once."""
-        try:
-            self._producer.resumeProducing()
-        except Exception:
-            logger.exception("Failed to call resumeProducing")
-            try:
-                self._consumer.unregisterProducer()
-            except Exception:
-                pass
-
-            self.stopProducing()
-
-
 class FakeRedisPubSubServer:
     """A fake Redis server for pub/sub."""

tests/server.py

@@ -603,12 +603,6 @@ class FakeTransport:
         if self.disconnected:
             return

-        if not hasattr(self.other, "transport"):
-            # the other has no transport yet; reschedule
-            if self.autoflush:
-                self._reactor.callLater(0.0, self.flush)
-            return
-
         if maxbytes is not None:
             to_write = self.buffer[:maxbytes]
         else:

tests/test_server.py

@@ -202,6 +202,8 @@ class OptionsResourceTests(unittest.TestCase):
             parse_listener_def({"type": "http", "port": 0}),
             self.resource,
             "1.0",
+            max_request_body_size=1234,
+            reactor=self.reactor,
         )

         # render the request and return the channel

tests/unittest.py

@@ -247,6 +247,8 @@ class HomeserverTestCase(TestCase):
             config=self.hs.config.server.listeners[0],
             resource=self.resource,
             server_version_string="1",
+            max_request_body_size=1234,
+            reactor=self.reactor,
         )

         from tests.rest.client.v1.utils import RestHelper