Merge branch 'erikj/reduce_size_of_cache' into erikj/merge_cache_prs

This commit is contained in:
Erik Johnston 2021-04-26 16:30:42 +01:00
commit a99c692906
25 changed files with 364 additions and 248 deletions

1
changelog.d/9726.bugfix Normal file
View file

@ -0,0 +1 @@
Fixes the OIDC SSO flow when using a `public_baseurl` value including a non-root URL path.

1
changelog.d/9817.misc Normal file
View file

@ -0,0 +1 @@
Fix a long-standing bug which caused `max_upload_size` to not be correctly enforced.

1
changelog.d/9874.misc Normal file
View file

@ -0,0 +1 @@
Pass a reactor into `SynapseSite` to make testing easier.

1
changelog.d/9876.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `synapse.api.auth` and `synapse.api.auth_blocking` modules.

1
changelog.d/9878.misc Normal file
View file

@ -0,0 +1 @@
Remove redundant `_PushHTTPChannel` test class.

View file

@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import pymacaroons
from netaddr import IPAddress
from twisted.web.server import Request
import synapse.types
from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
@ -36,11 +35,14 @@ from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -68,7 +70,7 @@ class Auth:
The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@ -88,13 +90,13 @@ class Auth:
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
):
) -> None:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
auth_events_by_id = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
event_auth.check(
@ -151,17 +153,11 @@ class Auth:
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
async def check_host_in_room(self, room_id, host):
async def check_host_in_room(self, room_id: str, host: str) -> bool:
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = await self.store.is_host_joined(room_id, host)
return latest_event_ids
return await self.store.is_host_joined(room_id, host)
def can_federate(self, event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
return creation_event.content.get("m.federate", True) is True
def get_public_keys(self, invite_event):
def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
return event_auth.get_public_keys(invite_event)
async def get_user_by_req(
@ -170,7 +166,7 @@ class Auth:
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
) -> synapse.types.Requester:
) -> Requester:
"""Get a registered user's ID.
Args:
@ -196,7 +192,7 @@ class Auth:
access_token = self.get_access_token_from_request(request)
user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
if user_id and app_service:
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
user_id=user_id,
@ -206,9 +202,7 @@ class Auth:
device_id="dummy-device", # stubbed
)
requester = synapse.types.create_requester(
user_id, app_service=app_service
)
requester = create_requester(user_id, app_service=app_service)
request.requester = user_id
opentracing.set_tag("authenticated_entity", user_id)
@ -251,7 +245,7 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
requester = synapse.types.create_requester(
requester = create_requester(
user_info.user_id,
token_id,
is_guest,
@ -271,7 +265,9 @@ class Auth:
except KeyError:
raise MissingClientTokenError()
async def _get_appservice_user_id(self, request):
async def _get_appservice_user_id(
self, request: Request
) -> Tuple[Optional[str], Optional[ApplicationService]]:
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
@ -283,6 +279,9 @@ class Auth:
if ip_address not in app_service.ip_range_whitelist:
return None, None
# This will always be set by the time Twisted calls us.
assert request.args is not None
if b"user_id" not in request.args:
return app_service.sender, app_service
@ -387,7 +386,9 @@ class Auth:
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
raise InvalidClientTokenError("Invalid macaroon passed.")
def _parse_and_validate_macaroon(self, token, rights="access"):
def _parse_and_validate_macaroon(
self, token: str, rights: str = "access"
) -> Tuple[str, bool]:
"""Takes a macaroon and tries to parse and validate it. This is cached
if and only if rights == access and there isn't an expiry.
@ -432,15 +433,16 @@ class Auth:
return user_id, guest
def validate_macaroon(self, macaroon, type_string, user_id):
def validate_macaroon(
self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
) -> None:
"""
validate that a Macaroon is understood by and was signed by this server.
Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token required (e.g. "access",
"delete_pusher")
user_id (str): The user_id required
macaroon: The macaroon to validate
type_string: The kind of token required (e.g. "access", "delete_pusher")
user_id: The user_id required
"""
v = pymacaroons.Verifier()
@ -465,9 +467,7 @@ class Auth:
if not service:
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.requester = synapse.types.create_requester(
service.sender, app_service=service
)
request.requester = create_requester(service.sender, app_service=service)
return service
async def is_server_admin(self, user: UserID) -> bool:
@ -519,7 +519,7 @@ class Auth:
return auth_ids
async def check_can_change_room_list(self, room_id: str, user: UserID):
async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
@ -554,11 +554,11 @@ class Auth:
return user_level >= send_level
@staticmethod
def has_access_token(request: Request):
def has_access_token(request: Request) -> bool:
"""Checks if the request has an access_token.
Returns:
bool: False if no access_token was given, True otherwise.
False if no access_token was given, True otherwise.
"""
# This will always be set by the time Twisted calls us.
assert request.args is not None
@ -568,13 +568,13 @@ class Auth:
return bool(query_params) or bool(auth_headers)
@staticmethod
def get_access_token_from_request(request: Request):
def get_access_token_from_request(request: Request) -> str:
"""Extracts the access_token from the request.
Args:
request: The http request.
Returns:
unicode: The access_token
The access_token
Raises:
MissingClientTokenError: If there isn't a single access_token in the
request
@ -649,5 +649,5 @@ class Auth:
% (user_id, room_id),
)
def check_auth_blocking(self, *args, **kwargs):
return self._auth_blocking.check_auth_blocking(*args, **kwargs)
async def check_auth_blocking(self, *args, **kwargs) -> None:
await self._auth_blocking.check_auth_blocking(*args, **kwargs)

View file

@ -13,18 +13,21 @@
# limitations under the License.
import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
from synapse.types import Requester
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class AuthBlocking:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self._server_notices_mxid = hs.config.server_notices_mxid
@ -43,7 +46,7 @@ class AuthBlocking:
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
):
) -> None:
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag

View file

@ -17,6 +17,9 @@
"""Contains constants from the specification."""
# the max size of a (canonical-json-encoded) event
MAX_PDU_SIZE = 65536
# the "depth" field on events is limited to 2**63 - 1
MAX_DEPTH = 2 ** 63 - 1

View file

@ -30,9 +30,10 @@ from twisted.internet import defer, error, reactor
from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.api.constants import MAX_PDU_SIZE
from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.server import ListenerConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -288,7 +289,7 @@ def refresh_certificate(hs):
logger.info("Context factories updated.")
async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
async def start(hs: "synapse.server.HomeServer"):
"""
Start a Synapse server or worker.
@ -300,7 +301,6 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
Args:
hs: homeserver instance
listeners: Listener configuration ('listeners' in homeserver.yaml)
"""
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
@ -336,7 +336,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
# It is now safe to start your Synapse.
hs.start_listening(listeners)
hs.start_listening()
hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start()
@ -530,3 +530,25 @@ def sdnotify(state):
# this is a bit surprising, since we don't expect to have a NOTIFY_SOCKET
# unless systemd is expecting us to notify it.
logger.warning("Unable to send notification to systemd: %s", e)
def max_request_body_size(config: HomeServerConfig) -> int:
"""Get a suitable maximum size for incoming HTTP requests"""
# Other than media uploads, the biggest request we expect to see is a fully-loaded
# /federation/v1/send request.
#
# The main thing in such a request is up to 50 PDUs, and up to 100 EDUs. PDUs are
# limited to 65536 bytes (possibly slightly more if the sender didn't use canonical
# json encoding); there is no specced limit to EDUs (see
# https://github.com/matrix-org/matrix-doc/issues/3121).
#
# in short, we somewhat arbitrarily limit requests to 200 * 64K (about 12.5M)
#
max_request_size = 200 * MAX_PDU_SIZE
# if we have a media repo enabled, we may need to allow larger uploads than that
if config.media.can_load_media_repo:
max_request_size = max(max_request_size, config.media.max_upload_size)
return max_request_size

View file

@ -70,12 +70,6 @@ class AdminCmdSlavedStore(
class AdminCmdServer(HomeServer):
DATASTORE_CLASS = AdminCmdSlavedStore
def _listen_http(self, listener_config):
pass
def start_listening(self, listeners):
pass
async def export_data_command(hs, args):
"""Export data for a user.
@ -232,7 +226,7 @@ def start(config_options):
async def run():
with LoggingContext("command"):
_base.start(ss, [])
_base.start(ss)
await args.func(ss, args)
_base.start_worker_reactor(

View file

@ -15,7 +15,7 @@
# limitations under the License.
import logging
import sys
from typing import Dict, Iterable, Optional
from typing import Dict, Optional
from twisted.internet import address
from twisted.web.resource import IResource
@ -32,7 +32,7 @@ from synapse.api.urls import (
SERVER_KEY_V2_PREFIX,
)
from synapse.app import _base
from synapse.app._base import register_start
from synapse.app._base import max_request_body_size, register_start
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
@ -367,14 +367,16 @@ class GenericWorkerServer(HomeServer):
listener_config,
root_resource,
self.version_string,
max_request_body_size=max_request_body_size(self.config),
reactor=self.get_reactor(),
),
reactor=self.get_reactor(),
)
logger.info("Synapse worker now listening on port %d", port)
def start_listening(self, listeners: Iterable[ListenerConfig]):
for listener in listeners:
def start_listening(self):
for listener in self.config.worker_listeners:
if listener.type == "http":
self._listen_http(listener)
elif listener.type == "manhole":
@ -468,7 +470,7 @@ def start(config_options):
# streams. Will no-op if no streams can be written to by this worker.
hs.get_replication_streamer()
register_start(_base.start, hs, config.worker_listeners)
register_start(_base.start, hs)
_base.start_worker_reactor("synapse-generic-worker", config)

View file

@ -17,7 +17,7 @@
import logging
import os
import sys
from typing import Iterable, Iterator
from typing import Iterator
from twisted.internet import reactor
from twisted.web.resource import EncodingResourceWrapper, IResource
@ -36,7 +36,13 @@ from synapse.api.urls import (
WEB_CLIENT_PREFIX,
)
from synapse.app import _base
from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
from synapse.app._base import (
listen_ssl,
listen_tcp,
max_request_body_size,
quit_with_error,
register_start,
)
from synapse.config._base import ConfigError
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
@ -126,19 +132,21 @@ class SynapseHomeServer(HomeServer):
else:
root_resource = OptionsResource()
root_resource = create_resource_tree(resources, root_resource)
site = SynapseSite(
"synapse.access.%s.%s" % ("https" if tls else "http", site_tag),
site_tag,
listener_config,
create_resource_tree(resources, root_resource),
self.version_string,
max_request_body_size=max_request_body_size(self.config),
reactor=self.get_reactor(),
)
if tls:
ports = listen_ssl(
bind_addresses,
port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
self.version_string,
),
site,
self.tls_server_context_factory,
reactor=self.get_reactor(),
)
@ -148,13 +156,7 @@ class SynapseHomeServer(HomeServer):
ports = listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
self.version_string,
),
site,
reactor=self.get_reactor(),
)
logger.info("Synapse now listening on TCP port %d", port)
@ -273,14 +275,14 @@ class SynapseHomeServer(HomeServer):
return resources
def start_listening(self, listeners: Iterable[ListenerConfig]):
def start_listening(self):
if self.config.redis_enabled:
# If redis is enabled we connect via the replication command handler
# in the same way as the workers (since we're effectively a client
# rather than a server).
self.get_tcp_replication().start_replication(self)
for listener in listeners:
for listener in self.config.server.listeners:
if listener.type == "http":
self._listening_services.extend(
self._listener_http(self.config, listener)
@ -413,7 +415,7 @@ def setup(config_options):
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
await _base.start(hs, config.listeners)
await _base.start(hs)
hs.get_datastore().db_pool.updates.start_doing_background_updates()

View file

@ -31,7 +31,6 @@ from twisted.logger import (
)
import synapse
from synapse.app import _base as appbase
from synapse.logging._structured import setup_structured_logging
from synapse.logging.context import LoggingContextFilter
from synapse.logging.filter import MetadataFilter
@ -318,6 +317,8 @@ def setup_logging(
# Perform one-time logging configuration.
_setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner)
# Add a SIGHUP handler to reload the logging configuration, if one is available.
from synapse.app import _base as appbase
appbase.register_sighup(_reload_logging_config, log_config_path)
# Log immediately so we can grep backwards.

View file

@ -235,7 +235,11 @@ class ServerConfig(Config):
self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl")
if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
# Whether to enable user presence.
presence_config = config.get("presence") or {}
@ -407,10 +411,6 @@ class ServerConfig(Config):
config_path=("federation_ip_range_blacklist",),
)
if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
# (undocumented) option for torturing the worker-mode replication a bit,
# for testing. The value defines the number of milliseconds to pause before
# sending out any replication updates.

View file

@ -14,14 +14,14 @@
# limitations under the License.
import logging
from typing import List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.constants import MAX_PDU_SIZE, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, EventSizeError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
@ -205,7 +205,7 @@ def _check_size_limits(event: EventBase) -> None:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE:
too_big("event")
@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
return False
def get_public_keys(invite_event):
def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
public_keys = []
if "public_key" in invite_event.content:
o = {"public_key": invite_event.content["public_key"]}

View file

@ -15,7 +15,7 @@
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode
from urllib.parse import urlencode, urlparse
import attr
import pymacaroons
@ -68,8 +68,8 @@ logger = logging.getLogger(__name__)
#
# Here we have the names of the cookies, and the options we use to set them.
_SESSION_COOKIES = [
(b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"),
(b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"),
(b"oidc_session", b"HttpOnly; Secure; SameSite=None"),
(b"oidc_session_no_samesite", b"HttpOnly"),
]
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
@ -279,6 +279,13 @@ class OidcProvider:
self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
# We'll insert this into the Path= parameter of any session cookies we set.
public_baseurl_path = urlparse(hs.config.server.public_baseurl).path
self._callback_path_prefix = (
public_baseurl_path.encode("utf-8") + b"_synapse/client/oidc"
)
self._oidc_attribute_requirements = provider.attribute_requirements
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
@ -779,8 +786,13 @@ class OidcProvider:
for cookie_name, options in _SESSION_COOKIES:
request.cookies.append(
b"%s=%s; Max-Age=3600; %s"
% (cookie_name, cookie.encode("utf-8"), options)
b"%s=%s; Max-Age=3600; Path=%s; %s"
% (
cookie_name,
cookie.encode("utf-8"),
self._callback_path_prefix,
options,
)
)
metadata = await self.load_metadata()

View file

@ -14,13 +14,14 @@
import contextlib
import logging
import time
from typing import Optional, Tuple, Type, Union
from typing import Optional, Tuple, Union
import attr
from zope.interface import implementer
from twisted.internet.interfaces import IAddress
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
@ -49,6 +50,7 @@ class SynapseRequest(Request):
* Redaction of access_token query-params in __repr__
* Logging at start and end
* Metrics to record CPU, wallclock and DB time by endpoint.
* A limit to the size of request which will be accepted
It also provides a method `processing`, which returns a context manager. If this
method is called, the request won't be logged until the context manager is closed;
@ -59,8 +61,9 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
def __init__(self, channel, *args, **kw):
def __init__(self, channel, *args, max_request_body_size=1024, **kw):
Request.__init__(self, channel, *args, **kw)
self._max_request_body_size = max_request_body_size
self.site = channel.site # type: SynapseSite
self._channel = channel # this is used by the tests
self.start_time = 0.0
@ -97,6 +100,18 @@ class SynapseRequest(Request):
self.site.site_tag,
)
def handleContentChunk(self, data):
# we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size:
logger.warning(
"Aborting connection from %s because the request exceeds maximum size",
self.client,
)
self.transport.abortConnection()
return
super().handleContentChunk(data)
@property
def requester(self) -> Optional[Union[Requester, str]]:
return self._requester
@ -485,29 +500,55 @@ class _XForwardedForAddress:
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
standard logging
Synapse-specific twisted http Site
This does two main things.
First, it replaces the requestFactory in use so that we build SynapseRequests
instead of regular t.w.server.Requests. All of the constructor params are really
just parameters for SynapseRequest.
Second, it inhibits the log() method called by Request.finish, since SynapseRequest
does its own logging.
"""
def __init__(
self,
logger_name,
site_tag,
logger_name: str,
site_tag: str,
config: ListenerConfig,
resource,
resource: IResource,
server_version_string,
*args,
**kwargs,
max_request_body_size: int,
reactor: IReactorTime,
):
Site.__init__(self, resource, *args, **kwargs)
"""
Args:
logger_name: The name of the logger to use for access logs.
site_tag: A tag to use for this site - mostly in access logs.
config: Configuration for the HTTP listener corresponding to this site
resource: The base of the resource tree to be used for serving requests on
this site
server_version_string: A string to present for the Server header
max_request_body_size: Maximum request body length to allow before
dropping the connection
reactor: reactor to be used to manage connection timeouts
"""
Site.__init__(self, resource, reactor=reactor)
self.site_tag = site_tag
assert config.http_options is not None
proxied = config.http_options.x_forwarded
self.requestFactory = (
XForwardedForRequest if proxied else SynapseRequest
) # type: Type[Request]
request_class = XForwardedForRequest if proxied else SynapseRequest
def request_factory(channel, queued) -> Request:
return request_class(
channel, max_request_body_size=max_request_body_size, queued=queued
)
self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")

View file

@ -51,8 +51,6 @@ class UploadResource(DirectServeJsonResource):
async def _async_render_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)

View file

@ -287,6 +287,14 @@ class HomeServer(metaclass=abc.ABCMeta):
if self.config.run_background_tasks:
self.setup_background_tasks()
def start_listening(self) -> None:
"""Start the HTTP, manhole, metrics, etc listeners
Does nothing in this base class; overridden in derived classes to start the
appropriate listeners.
"""
pass
def setup_background_tasks(self) -> None:
"""
Some handlers have side effects on instantiation (like registering

View file

@ -17,8 +17,10 @@ from functools import wraps
from typing import (
Any,
Callable,
Collection,
Generic,
Iterable,
List,
Optional,
Type,
TypeVar,
@ -83,15 +85,30 @@ class _Node:
__slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
def __init__(
self, prev_node, next_node, key, value, callbacks: Optional[set] = None
self,
prev_node,
next_node,
key,
value,
callbacks: Collection[Callable[[], None]] = (),
):
self.prev_node = prev_node
self.next_node = next_node
self.key = key
self.value = value
self.callbacks = callbacks or set()
self.memory = 0
# Set of callbacks to run when the node gets deleted. We store as a list
# rather than a set to keep memory usage down (and since we expect few
# entries per node the performance of checking for duplication in a list
# vs using a set is negligible).
#
# Note that we store this as an optional list to keep the memory
# footprint down. Empty lists are 56 bytes (and empty sets are 216 bytes).
self.callbacks = None # type: Optional[List[Callable[[], None]]]
self.add_callbacks(callbacks)
if TRACK_MEMORY_USAGE:
self.memory = (
_get_size_of(key)
@ -101,6 +118,32 @@ class _Node:
)
self.memory += _get_size_of(self.memory, recurse=False)
def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
"""Add to stored list of callbacks, removing duplicates."""
if not callbacks:
return
if not self.callbacks:
self.callbacks = []
for callback in callbacks:
if callback not in self.callbacks:
self.callbacks.append(callback)
def run_and_clear_callbacks(self) -> None:
"""Run all callbacks and clear the stored set of callbacks. Used when
the node is being deleted.
"""
if not self.callbacks:
return
for callback in self.callbacks:
callback()
self.callbacks = None
class LruCache(Generic[KT, VT]):
"""
@ -213,10 +256,10 @@ class LruCache(Generic[KT, VT]):
self.len = synchronized(cache_len)
def add_node(key, value, callbacks: Optional[set] = None):
def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
prev_node = list_root
next_node = prev_node.next_node
node = _Node(prev_node, next_node, key, value, callbacks or set())
node = _Node(prev_node, next_node, key, value, callbacks)
prev_node.next_node = node
next_node.prev_node = node
cache[key] = node
@ -250,9 +293,7 @@ class LruCache(Generic[KT, VT]):
deleted_len = size_callback(node.value)
cached_cache_len[0] -= deleted_len
for cb in node.callbacks:
cb()
node.callbacks.clear()
node.run_and_clear_callbacks()
if TRACK_MEMORY_USAGE and metrics:
metrics.dec_memory_usage(node.memory)
@ -263,7 +304,7 @@ class LruCache(Generic[KT, VT]):
def cache_get(
key: KT,
default: Literal[None] = None,
callbacks: Iterable[Callable[[], None]] = ...,
callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Optional[VT]:
...
@ -272,7 +313,7 @@ class LruCache(Generic[KT, VT]):
def cache_get(
key: KT,
default: T,
callbacks: Iterable[Callable[[], None]] = ...,
callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Union[T, VT]:
...
@ -281,13 +322,13 @@ class LruCache(Generic[KT, VT]):
def cache_get(
key: KT,
default: Optional[T] = None,
callbacks: Iterable[Callable[[], None]] = (),
callbacks: Collection[Callable[[], None]] = (),
update_metrics: bool = True,
):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
node.callbacks.update(callbacks)
node.add_callbacks(callbacks)
if update_metrics and metrics:
metrics.inc_hits()
return node.value
@ -303,10 +344,8 @@ class LruCache(Generic[KT, VT]):
# We sometimes store large objects, e.g. dicts, which cause
# the inequality check to take a long time. So let's only do
# the check if we have some callbacks to call.
if node.callbacks and value != node.value:
for cb in node.callbacks:
cb()
node.callbacks.clear()
if value != node.value:
node.run_and_clear_callbacks()
# We don't bother to protect this by value != node.value as
# generally size_callback will be cheap compared with equality
@ -316,7 +355,7 @@ class LruCache(Generic[KT, VT]):
cached_cache_len[0] -= size_callback(node.value)
cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
node.add_callbacks(callbacks)
move_node_to_front(node)
node.value = value
@ -369,8 +408,7 @@ class LruCache(Generic[KT, VT]):
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
for cb in node.callbacks:
cb()
node.run_and_clear_callbacks()
cache.clear()
if size_callback:
cached_cache_len[0] = 0

83
tests/http/test_site.py Normal file
View file

@ -0,0 +1,83 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet.address import IPv6Address
from twisted.test.proto_helpers import StringTransport
from synapse.app.homeserver import SynapseHomeServer
from tests.unittest import HomeserverTestCase
class SynapseRequestTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
def test_large_request(self):
"""overlarge HTTP requests should be rejected"""
self.hs.start_listening()
# find the HTTP server which is configured to listen on port 0
(port, factory, _backlog, interface) = self.reactor.tcpServers[0]
self.assertEqual(interface, "::")
self.assertEqual(port, 0)
# as a control case, first send a regular request.
# complete the connection and wire it up to a fake transport
client_address = IPv6Address("TCP", "::1", "2345")
protocol = factory.buildProtocol(client_address)
transport = StringTransport()
protocol.makeConnection(transport)
protocol.dataReceived(
b"POST / HTTP/1.1\r\n"
b"Connection: close\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"0\r\n"
b"\r\n"
)
while not transport.disconnecting:
self.reactor.advance(1)
# we should get a 404
self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")
# now send an oversized request
protocol = factory.buildProtocol(client_address)
transport = StringTransport()
protocol.makeConnection(transport)
protocol.dataReceived(
b"POST / HTTP/1.1\r\n"
b"Connection: close\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)
# we deliberately send all the data in one big chunk, to ensure that
# twisted isn't buffering the data in the chunked transfer decoder.
# we start with the chunk size, in hex. (We won't actually send this much)
protocol.dataReceived(b"10000000\r\n")
sent = 0
while not transport.disconnected:
self.assertLess(sent, 0x10000000, "connection did not drop")
protocol.dataReceived(b"\0" * 1024)
sent += 1024
# default max upload size is 50M, so it should drop on the next buffer after
# that.
self.assertEqual(sent, 50 * 1024 * 1024 + 1024)

View file

@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
from twisted.web.server import Request, Site
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
ServerReplicationStreamProtocol,
)
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
channel = self.site.buildProtocol(None)
# hook into the channel's request factory so that we can keep a record
# of the requests
requests: List[SynapseRequest] = []
real_request_factory = channel.requestFactory
def request_factory(*args, **kwargs):
request = real_request_factory(*args, **kwargs)
requests.append(request)
return request
channel.requestFactory = request_factory
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
return channel.request
# there should have been exactly one request
self.assertEqual(len(requests), 1)
return requests[0]
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@ -349,6 +359,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config=worker_hs.config.server.listeners[0],
resource=resource,
server_version_string="1",
max_request_body_size=4096,
reactor=self.reactor,
)
if worker_hs.config.redis.redis_enabled:
@ -386,7 +398,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
channel = self._hs_to_site[hs].buildProtocol(None)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@ -444,112 +456,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
self.received_rdata_rows.append((stream_name, token, r))
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
This is a hack to get around the fact that HTTPChannel transparently wraps a
pull producer (which is what Synapse uses to reply to requests) with
`_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
uses the standard reactor rather than letting us use our test reactor, which
makes it very hard to test.
"""
def __init__(
self, reactor: IReactorTime, request_factory: Type[Request], site: Site
):
super().__init__()
self.reactor = reactor
self.requestFactory = request_factory
self.site = site
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
def registerProducer(self, producer, streaming):
# Convert pull producers to push producer.
if not streaming:
self._pull_to_push_producer = _PullToPushProducer(
self.reactor, producer, self
)
producer = self._pull_to_push_producer
super().registerProducer(producer, True)
def unregisterProducer(self):
if self._pull_to_push_producer:
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()
def checkPersistence(self, request, version):
"""Check whether the connection can be re-used"""
# We hijack this to always say no for ease of wiring stuff up in
# `handle_http_replication_attempt`.
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
def requestDone(self, request):
# Store the request for inspection.
self.request = request
super().requestDone(request)
class _PullToPushProducer:
"""A push producer that wraps a pull producer."""
def __init__(
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
):
self._clock = Clock(reactor)
self._producer = producer
self._consumer = consumer
# While running we use a looping call with a zero delay to call
# resumeProducing on given producer.
self._looping_call = None # type: Optional[LoopingCall]
# We start writing next reactor tick.
self._start_loop()
def _start_loop(self):
"""Start the looping call to"""
if not self._looping_call:
# Start a looping call which runs every tick.
self._looping_call = self._clock.looping_call(self._run_once, 0)
def stop(self):
"""Stops calling resumeProducing."""
if self._looping_call:
self._looping_call.stop()
self._looping_call = None
def pauseProducing(self):
"""Implements IPushProducer"""
self.stop()
def resumeProducing(self):
"""Implements IPushProducer"""
self._start_loop()
def stopProducing(self):
"""Implements IPushProducer"""
self.stop()
self._producer.stopProducing()
def _run_once(self):
"""Calls resumeProducing on producer once."""
try:
self._producer.resumeProducing()
except Exception:
logger.exception("Failed to call resumeProducing")
try:
self._consumer.unregisterProducer()
except Exception:
pass
self.stopProducing()
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""

View file

@ -603,12 +603,6 @@ class FakeTransport:
if self.disconnected:
return
if not hasattr(self.other, "transport"):
# the other has no transport yet; reschedule
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
return
if maxbytes is not None:
to_write = self.buffer[:maxbytes]
else:

View file

@ -202,6 +202,8 @@ class OptionsResourceTests(unittest.TestCase):
parse_listener_def({"type": "http", "port": 0}),
self.resource,
"1.0",
max_request_body_size=1234,
reactor=self.reactor,
)
# render the request and return the channel

View file

@ -247,6 +247,8 @@ class HomeserverTestCase(TestCase):
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
max_request_body_size=1234,
reactor=self.reactor,
)
from tests.rest.client.v1.utils import RestHelper