From aabf577166546d98353ab9bdb6f0648193a94b85 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Jun 2024 10:40:34 +0100 Subject: [PATCH 1/2] Handle hyphens in user dir search porperly (#17254) c.f. #16675 --- changelog.d/17254.bugfix | 1 + .../storage/databases/main/user_directory.py | 66 +++++++++++++++++-- tests/handlers/test_user_directory.py | 39 +++++++++++ tests/storage/test_user_directory.py | 4 ++ 4 files changed, 104 insertions(+), 6 deletions(-) create mode 100644 changelog.d/17254.bugfix diff --git a/changelog.d/17254.bugfix b/changelog.d/17254.bugfix new file mode 100644 index 0000000000..b0d61309e2 --- /dev/null +++ b/changelog.d/17254.bugfix @@ -0,0 +1 @@ +Fix searching for users with their exact localpart whose ID includes a hyphen. diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 0513e7dc06..6e18f714d7 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -1281,7 +1281,7 @@ def _parse_words_with_regex(search_term: str) -> List[str]: Break down search term into words, when we don't have ICU available. See: `_parse_words` """ - return re.findall(r"([\w\-]+)", search_term, re.UNICODE) + return re.findall(r"([\w-]+)", search_term, re.UNICODE) def _parse_words_with_icu(search_term: str) -> List[str]: @@ -1303,15 +1303,69 @@ def _parse_words_with_icu(search_term: str) -> List[str]: if j < 0: break - result = search_term[i:j] + # We want to make sure that we split on `@` and `:` specifically, as + # they occur in user IDs. + for result in re.split(r"[@:]+", search_term[i:j]): + results.append(result.strip()) + + i = j + + # libicu will break up words that have punctuation in them, but to handle + # cases where user IDs have '-', '.' and '_' in them we want to *not* break + # those into words and instead allow the DB to tokenise them how it wants. + # + # In particular, user-71 in postgres gets tokenised to "user, -71", and this + # will not match a query for "user, 71". + new_results: List[str] = [] + i = 0 + while i < len(results): + curr = results[i] + + prev = None + next = None + if i > 0: + prev = results[i - 1] + if i + 1 < len(results): + next = results[i + 1] + + i += 1 # libicu considers spaces and punctuation between words as words, but we don't # want to include those in results as they would result in syntax errors in SQL # queries (e.g. "foo bar" would result in the search query including "foo & & # bar"). - if len(re.findall(r"([\w\-]+)", result, re.UNICODE)): - results.append(result) + if not curr: + continue - i = j + if curr in ["-", ".", "_"]: + prefix = "" + suffix = "" - return results + # Check if the next item is a word, and if so use it as the suffix. + # We check for if its a word as we don't want to concatenate + # multiple punctuation marks. + if next is not None and re.match(r"\w", next): + suffix = next + i += 1 # We're using next, so we skip it in the outer loop. + else: + # We want to avoid creating terms like "user-", as we should + # strip trailing punctuation. + continue + + if prev and re.match(r"\w", prev) and new_results: + prefix = new_results[-1] + new_results.pop() + + # We might not have a prefix here, but that's fine as we want to + # ensure that we don't strip preceding punctuation e.g. '-71' + # shouldn't be converted to '71'. 
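As a rough guide to the merge pass above, the following is a minimal, self-contained sketch of the intended behaviour. The `merge_user_id_tokens` name and the `str.isalnum` checks are illustrative assumptions only (the real code operates on the `results` list built from the ICU break iterator and tests word characters with `re.match(r"\w", ...)`); the asserts show how tokens such as `user-71` are meant to come out.

```python
from typing import List


def merge_user_id_tokens(tokens: List[str]) -> List[str]:
    """Re-join '-', '.' and '_' with the words around them (illustrative sketch)."""
    merged: List[str] = []
    i = 0
    while i < len(tokens):
        curr = tokens[i]
        i += 1
        if not curr:
            continue
        if curr in ("-", ".", "_"):
            # Only use the next token as a suffix if it is a word, so runs of
            # punctuation are not concatenated together.
            if i < len(tokens) and tokens[i] and tokens[i][0].isalnum():
                suffix = tokens[i]
                i += 1
            else:
                # Trailing punctuation is stripped: "user-" becomes "user".
                continue
            # Re-attach the preceding word if there is one; a bare leading mark
            # is kept, so "-71" stays "-71" rather than becoming "71".
            prefix = merged.pop() if merged and merged[-1][0].isalnum() else ""
            merged.append(f"{prefix}{curr}{suffix}")
        elif curr[0].isalnum():
            merged.append(curr)
        # Other lone punctuation (spaces, commas, ...) is dropped.
    return merged


assert merge_user_id_tokens(["user", "-", "71"]) == ["user-71"]
assert merge_user_id_tokens(["user", "-"]) == ["user"]
assert merge_user_id_tokens(["-", "71"]) == ["-71"]
```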
+ + new_results.append(f"{prefix}{curr}{suffix}") + continue + elif not re.match(r"\w", curr): + # Ignore other punctuation + continue + + new_results.append(curr) + + return new_results diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 77c6cac449..878d9683b6 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -1061,6 +1061,45 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): {alice: ProfileInfo(display_name=None, avatar_url=MXC_DUMMY)}, ) + def test_search_punctuation(self) -> None: + """Test that you can search for a user that includes punctuation""" + + searching_user = self.register_user("searcher", "password") + searching_user_tok = self.login("searcher", "password") + + room_id = self.helper.create_room_as( + searching_user, + room_version=RoomVersions.V1.identifier, + tok=searching_user_tok, + ) + + # We want to test searching for users of the form e.g. "user-1", with + # various punctuation. We also test both where the prefix is numeric and + # alphanumeric, as e.g. postgres tokenises "user-1" as "user" and "-1". + i = 1 + for char in ["-", ".", "_"]: + for use_numeric in [False, True]: + if use_numeric: + prefix1 = f"{i}" + prefix2 = f"{i+1}" + else: + prefix1 = f"a{i}" + prefix2 = f"a{i+1}" + + local_user_1 = self.register_user(f"user{char}{prefix1}", "password") + local_user_2 = self.register_user(f"user{char}{prefix2}", "password") + + self._add_user_to_room(room_id, RoomVersions.V1, local_user_1) + self._add_user_to_room(room_id, RoomVersions.V1, local_user_2) + + results = self.get_success( + self.handler.search_users(searching_user, local_user_1, 20) + )["results"] + received_user_id_ordering = [result["user_id"] for result in results] + self.assertSequenceEqual(received_user_id_ordering[:1], [local_user_1]) + + i += 2 + class TestUserDirSearchDisabled(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 156a610faa..c26932069f 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -711,6 +711,10 @@ class UserDirectoryICUTestCase(HomeserverTestCase): ), ) + self.assertEqual(_parse_words_with_icu("user-1"), ["user-1"]) + self.assertEqual(_parse_words_with_icu("user-ab"), ["user-ab"]) + self.assertEqual(_parse_words_with_icu("user.--1"), ["user", "-1"]) + def test_regex_word_boundary_punctuation(self) -> None: """ Tests the behaviour of punctuation with the non-ICU tokeniser From fcbc79bb87d08147e86dafa0fee5a9aec4d3fc23 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 5 Jun 2024 05:43:36 -0700 Subject: [PATCH 2/2] Ratelimiting of remote media downloads (#17256) --- changelog.d/17256.feature | 1 + .../configuration/config_documentation.md | 18 ++ synapse/config/ratelimiting.py | 10 + synapse/federation/federation_client.py | 7 + synapse/federation/transport/client.py | 9 + synapse/http/matrixfederationclient.py | 55 ++++- synapse/media/media_repository.py | 43 +++- synapse/media/thumbnailer.py | 6 +- synapse/rest/client/media.py | 2 + synapse/rest/media/download_resource.py | 8 +- synapse/rest/media/thumbnail_resource.py | 2 + tests/media/test_media_storage.py | 225 +++++++++++++++++- 12 files changed, 372 insertions(+), 14 deletions(-) create mode 100644 changelog.d/17256.feature diff --git a/changelog.d/17256.feature b/changelog.d/17256.feature new file mode 100644 index 0000000000..6ec4cb7a31 --- /dev/null +++ b/changelog.d/17256.feature @@ -0,0 +1 @@ 
+ Improve ratelimiting in Synapse (#17256). \ No newline at end of file diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 2c917d1f8e..d23f8c4c4f 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -1946,6 +1946,24 @@ Example configuration: max_image_pixels: 35M ``` --- +### `remote_media_download_burst_count` + +Remote media downloads are ratelimited using a [leaky bucket algorithm](https://en.wikipedia.org/wiki/Leaky_bucket), where a given "bucket" is keyed to the IP address of the requester when requesting remote media downloads. This configuration option sets the size of the bucket against which the size in bytes of downloads are penalized - if the bucket is full, ie a given number of bytes have already been downloaded, further downloads will be denied until the bucket drains. Defaults to 500MiB. See also `remote_media_download_per_second` which determines the rate at which the "bucket" is emptied and thus has available space to authorize new requests. + +Example configuration: +```yaml +remote_media_download_burst_count: 200M +``` +--- +### `remote_media_download_per_second` + +Works in conjunction with `remote_media_download_burst_count` to ratelimit remote media downloads - this configuration option determines the rate at which the "bucket" (see above) leaks in bytes per second. As requests are made to download remote media, the size of those requests in bytes is added to the bucket, and once the bucket has reached it's capacity, no more requests will be allowed until a number of bytes has "drained" from the bucket. This setting determines the rate at which bytes drain from the bucket, with the practical effect that the larger the number, the faster the bucket leaks, allowing for more bytes downloaded over a shorter period of time. Defaults to 87KiB per second. See also `remote_media_download_burst_count`. + +Example configuration: +```yaml +remote_media_download_per_second: 40K +``` +--- ### `prevent_media_downloads_from` A list of domains to never download media from. 
Media from these diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index d2cb4576df..3fa33f5373 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -218,3 +218,13 @@ class RatelimitConfig(Config): "rc_media_create", defaults={"per_second": 10, "burst_count": 50}, ) + + self.remote_media_downloads = RatelimitSettings( + key="rc_remote_media_downloads", + per_second=self.parse_size( + config.get("remote_media_download_per_second", "87K") + ), + burst_count=self.parse_size( + config.get("remote_media_download_burst_count", "500M") + ), + ) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index e613eb87a6..f0f5a37a57 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -56,6 +56,7 @@ from synapse.api.errors import ( SynapseError, UnsupportedRoomVersionError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, @@ -1877,6 +1878,8 @@ class FederationClient(FederationBase): output_stream: BinaryIO, max_size: int, max_timeout_ms: int, + download_ratelimiter: Ratelimiter, + ip_address: str, ) -> Tuple[int, Dict[bytes, List[bytes]]]: try: return await self.transport_layer.download_media_v3( @@ -1885,6 +1888,8 @@ class FederationClient(FederationBase): output_stream=output_stream, max_size=max_size, max_timeout_ms=max_timeout_ms, + download_ratelimiter=download_ratelimiter, + ip_address=ip_address, ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, @@ -1905,6 +1910,8 @@ class FederationClient(FederationBase): output_stream=output_stream, max_size=max_size, max_timeout_ms=max_timeout_ms, + download_ratelimiter=download_ratelimiter, + ip_address=ip_address, ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index de408f7f8d..af1336fe5f 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -43,6 +43,7 @@ import ijson from synapse.api.constants import Direction, Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import RoomVersion from synapse.api.urls import ( FEDERATION_UNSTABLE_PREFIX, @@ -819,6 +820,8 @@ class TransportLayerClient: output_stream: BinaryIO, max_size: int, max_timeout_ms: int, + download_ratelimiter: Ratelimiter, + ip_address: str, ) -> Tuple[int, Dict[bytes, List[bytes]]]: path = f"/_matrix/media/r0/download/{destination}/{media_id}" @@ -834,6 +837,8 @@ class TransportLayerClient: "allow_remote": "false", "timeout_ms": str(max_timeout_ms), }, + download_ratelimiter=download_ratelimiter, + ip_address=ip_address, ) async def download_media_v3( @@ -843,6 +848,8 @@ class TransportLayerClient: output_stream: BinaryIO, max_size: int, max_timeout_ms: int, + download_ratelimiter: Ratelimiter, + ip_address: str, ) -> Tuple[int, Dict[bytes, List[bytes]]]: path = f"/_matrix/media/v3/download/{destination}/{media_id}" @@ -862,6 +869,8 @@ class TransportLayerClient: "allow_redirect": "true", }, follow_redirects=True, + download_ratelimiter=download_ratelimiter, + ip_address=ip_address, ) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c73a589e6c..104b803b0f 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -57,7 
+57,7 @@ from twisted.internet.interfaces import IReactorTime from twisted.internet.task import Cooperator from twisted.web.client import ResponseFailed from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer, IResponse +from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse import synapse.metrics import synapse.util.retryutils @@ -68,6 +68,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http import QuieterFileBodyProducer from synapse.http.client import ( @@ -1411,9 +1412,11 @@ class MatrixFederationHttpClient: destination: str, path: str, output_stream: BinaryIO, + download_ratelimiter: Ratelimiter, + ip_address: str, + max_size: int, args: Optional[QueryParams] = None, retry_on_dns_fail: bool = True, - max_size: Optional[int] = None, ignore_backoff: bool = False, follow_redirects: bool = False, ) -> Tuple[int, Dict[bytes, List[bytes]]]: @@ -1422,6 +1425,10 @@ class MatrixFederationHttpClient: destination: The remote server to send the HTTP request to. path: The HTTP path to GET. output_stream: File to write the response body to. + download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to + requester IP + ip_address: IP address of the requester + max_size: maximum allowable size in bytes of the file args: Optional dictionary used to create the query string. ignore_backoff: true to ignore the historical backoff data and try the request anyway. @@ -1441,11 +1448,27 @@ class MatrixFederationHttpClient: federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. 
+ SynapseError: If the requested file exceeds ratelimits """ request = MatrixFederationRequest( method="GET", destination=destination, path=path, query=args ) + # check for a minimum balance of 1MiB in ratelimiter before initiating request + send_req, _ = await download_ratelimiter.can_do_action( + requester=None, key=ip_address, n_actions=1048576, update=False + ) + + if not send_req: + msg = "Requested file size exceeds ratelimits" + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED) + response = await self._send_request( request, retry_on_dns_fail=retry_on_dns_fail, @@ -1455,12 +1478,36 @@ class MatrixFederationHttpClient: headers = dict(response.headers.getAllRawHeaders()) + expected_size = response.length + # if we don't get an expected length then use the max length + if expected_size == UNKNOWN_LENGTH: + expected_size = max_size + logger.debug( + f"File size unknown, assuming file is max allowable size: {max_size}" + ) + + read_body, _ = await download_ratelimiter.can_do_action( + requester=None, + key=ip_address, + n_actions=expected_size, + ) + if not read_body: + msg = "Requested file size exceeds ratelimits" + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED) + try: - d = read_body_with_max_size(response, output_stream, max_size) + # add a byte of headroom to max size as function errs at >= + d = read_body_with_max_size(response, output_stream, expected_size + 1) d.addTimeout(self.default_timeout_seconds, self.reactor) length = await make_deferred_yieldable(d) except BodyExceededMaxSize: - msg = "Requested file is too large > %r bytes" % (max_size,) + msg = "Requested file is too large > %r bytes" % (expected_size,) logger.warning( "{%s} [%s] %s", request.txn_id, diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 9c29e09653..6ed56099ca 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -42,6 +42,7 @@ from synapse.api.errors import ( SynapseError, cs_error, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config.repository import ThumbnailRequirement from synapse.http.server import respond_with_json from synapse.http.site import SynapseRequest @@ -111,6 +112,12 @@ class MediaRepository: ) self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from + self.download_ratelimiter = Ratelimiter( + store=hs.get_storage_controllers().main, + clock=hs.get_clock(), + cfg=hs.config.ratelimiting.remote_media_downloads, + ) + # List of StorageProviders where we should search for media and # potentially upload to. storage_providers = [] @@ -464,6 +471,7 @@ class MediaRepository: media_id: str, name: Optional[str], max_timeout_ms: int, + ip_address: str, ) -> None: """Respond to requests for remote media. @@ -475,6 +483,7 @@ class MediaRepository: the filename in the Content-Disposition header of the response. max_timeout_ms: the maximum number of milliseconds to wait for the media to be uploaded. 
+ ip_address: the IP address of the requester Returns: Resolves once a response has successfully been written to request @@ -500,7 +509,11 @@ class MediaRepository: key = (server_name, media_id) async with self.remote_media_linearizer.queue(key): responder, media_info = await self._get_remote_media_impl( - server_name, media_id, max_timeout_ms + server_name, + media_id, + max_timeout_ms, + self.download_ratelimiter, + ip_address, ) # We deliberately stream the file outside the lock @@ -517,7 +530,7 @@ class MediaRepository: respond_404(request) async def get_remote_media_info( - self, server_name: str, media_id: str, max_timeout_ms: int + self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str ) -> RemoteMedia: """Gets the media info associated with the remote file, downloading if necessary. @@ -527,6 +540,7 @@ class MediaRepository: media_id: The media ID of the content (as defined by the remote server). max_timeout_ms: the maximum number of milliseconds to wait for the media to be uploaded. + ip_address: IP address of the requester Returns: The media info of the file @@ -542,7 +556,11 @@ class MediaRepository: key = (server_name, media_id) async with self.remote_media_linearizer.queue(key): responder, media_info = await self._get_remote_media_impl( - server_name, media_id, max_timeout_ms + server_name, + media_id, + max_timeout_ms, + self.download_ratelimiter, + ip_address, ) # Ensure we actually use the responder so that it releases resources @@ -553,7 +571,12 @@ class MediaRepository: return media_info async def _get_remote_media_impl( - self, server_name: str, media_id: str, max_timeout_ms: int + self, + server_name: str, + media_id: str, + max_timeout_ms: int, + download_ratelimiter: Ratelimiter, + ip_address: str, ) -> Tuple[Optional[Responder], RemoteMedia]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -564,6 +587,9 @@ class MediaRepository: remote server). max_timeout_ms: the maximum number of milliseconds to wait for the media to be uploaded. + download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to + requester IP. + ip_address: the IP address of the requester Returns: A tuple of responder and the media info of the file. @@ -596,7 +622,7 @@ class MediaRepository: try: media_info = await self._download_remote_file( - server_name, media_id, max_timeout_ms + server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address ) except SynapseError: raise @@ -630,6 +656,8 @@ class MediaRepository: server_name: str, media_id: str, max_timeout_ms: int, + download_ratelimiter: Ratelimiter, + ip_address: str, ) -> RemoteMedia: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -641,6 +669,9 @@ class MediaRepository: locally generated. max_timeout_ms: the maximum number of milliseconds to wait for the media to be uploaded. + download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to + requester IP + ip_address: the IP address of the requester Returns: The media info of the file. 
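To help review the plumbing above, here is a toy model of how the byte-based ratelimiter is intended to behave with the default settings (`remote_media_download_per_second: 87K`, `remote_media_download_burst_count: 500M`, both run through `parse_size`). The `LeakyBucket` class is an illustrative stand-in, not Synapse's `Ratelimiter`; it only mirrors the two checks `get_file` now performs: a 1 MiB pre-flight peek with `update=False`, then charging the expected body size.

```python
# Toy leaky bucket using the assumed default values; not Synapse's Ratelimiter.
PER_SECOND = 87 * 1024           # "87K"  -> 89,088 bytes drain per second
BURST_COUNT = 500 * 1024 * 1024  # "500M" -> 524,288,000 bytes of capacity


class LeakyBucket:
    def __init__(self, rate: int, capacity: int) -> None:
        self.rate = rate
        self.capacity = capacity
        self.level = 0.0

    def advance(self, seconds: float) -> None:
        # Time passing drains the bucket.
        self.level = max(0.0, self.level - seconds * self.rate)

    def can_do_action(self, n_bytes: int, update: bool = True) -> bool:
        allowed = self.level + n_bytes <= self.capacity
        if allowed and update:
            self.level += n_bytes
        return allowed


bucket = LeakyBucket(PER_SECOND, BURST_COUNT)
thirty_mib = 30 * 1024 * 1024

# get_file first peeks for 1 MiB of headroom without charging anything...
assert bucket.can_do_action(1024 * 1024, update=False)
# ...then charges the expected size once the response headers are known.
assert bucket.can_do_action(thirty_mib)

# Fifteen more 30 MiB downloads fit in the default bucket; the seventeenth
# request is rejected, matching the test further down in this patch.
for _ in range(15):
    assert bucket.can_do_action(thirty_mib)
assert not bucket.can_do_action(thirty_mib)

# At ~87 KiB/s, roughly two minutes of draining frees enough space again.
bucket.advance(2 * 60)
assert bucket.can_do_action(thirty_mib)
```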
@@ -658,6 +689,8 @@ class MediaRepository: output_stream=f, max_size=self.max_upload_size, max_timeout_ms=max_timeout_ms, + download_ratelimiter=download_ratelimiter, + ip_address=ip_address, ) except RequestSendFailed as e: logger.warning( diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index cc3acf51e1..f8a9560784 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -359,9 +359,10 @@ class ThumbnailProvider: desired_method: str, desired_type: str, max_timeout_ms: int, + ip_address: str, ) -> None: media_info = await self.media_repo.get_remote_media_info( - server_name, media_id, max_timeout_ms + server_name, media_id, max_timeout_ms, ip_address ) if not media_info: respond_404(request) @@ -422,12 +423,13 @@ class ThumbnailProvider: method: str, m_type: str, max_timeout_ms: int, + ip_address: str, ) -> None: # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. media_info = await self.media_repo.get_remote_media_info( - server_name, media_id, max_timeout_ms + server_name, media_id, max_timeout_ms, ip_address ) if not media_info: return diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 172d240783..0c089163c1 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -174,6 +174,7 @@ class UnstableThumbnailResource(RestServlet): respond_404(request) return + ip_address = request.getClientAddress().host remote_resp_function = ( self.thumbnailer.select_or_generate_remote_thumbnail if self.dynamic_thumbnails @@ -188,6 +189,7 @@ class UnstableThumbnailResource(RestServlet): method, m_type, max_timeout_ms, + ip_address, ) self.media_repo.mark_recently_accessed(server_name, media_id) diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py index 8ba723c8d4..1628d58926 100644 --- a/synapse/rest/media/download_resource.py +++ b/synapse/rest/media/download_resource.py @@ -97,6 +97,12 @@ class DownloadResource(RestServlet): respond_404(request) return + ip_address = request.getClientAddress().host await self.media_repo.get_remote_media( - request, server_name, media_id, file_name, max_timeout_ms + request, + server_name, + media_id, + file_name, + max_timeout_ms, + ip_address, ) diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py index fe8fbb06e4..ce511c6dce 100644 --- a/synapse/rest/media/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py @@ -104,6 +104,7 @@ class ThumbnailResource(RestServlet): respond_404(request) return + ip_address = request.getClientAddress().host remote_resp_function = ( self.thumbnail_provider.select_or_generate_remote_thumbnail if self.dynamic_thumbnails @@ -118,5 +119,6 @@ class ThumbnailResource(RestServlet): method, m_type, max_timeout_ms, + ip_address, ) self.media_repo.mark_recently_accessed(server_name, media_id) diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 1bd51ceba2..46d20ce775 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -25,7 +25,7 @@ import tempfile from binascii import unhexlify from io import BytesIO from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock, patch from urllib import parse import attr @@ -37,9 +37,12 @@ from twisted.internet import defer from 
twisted.internet.defer import Deferred from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor +from twisted.web.http_headers import Headers +from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from twisted.web.resource import Resource from synapse.api.errors import Codes, HttpResponseException +from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable @@ -59,6 +62,7 @@ from synapse.util import Clock from tests import unittest from tests.server import FakeChannel from tests.test_utils import SMALL_PNG +from tests.unittest import override_config from tests.utils import default_config @@ -251,9 +255,11 @@ class MediaRepoTests(unittest.HomeserverTestCase): destination: str, path: str, output_stream: BinaryIO, + download_ratelimiter: Ratelimiter, + ip_address: Any, + max_size: int, args: Optional[QueryParams] = None, retry_on_dns_fail: bool = True, - max_size: Optional[int] = None, ignore_backoff: bool = False, follow_redirects: bool = False, ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": @@ -878,3 +884,218 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): tok=self.tok, expect_code=400, ) + + +class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + config["media_storage_providers"] = [provider_config] + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.repo = hs.get_media_repository() + self.client = hs.get_federation_http_client() + self.store = hs.get_datastores().main + + def create_resource_dict(self) -> Dict[str, Resource]: + # We need to manually set the resource tree to include media, the + # default only does `/_matrix/client` APIs. 
+ return {"/_matrix/media": self.hs.get_media_repository_resource()} + + # mock actually reading file body + def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred: + d: Deferred = defer.Deferred() + d.callback(31457280) + return d + + def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred: + d: Deferred = defer.Deferred() + d.callback(52428800) + return d + + @patch( + "synapse.http.matrixfederationclient.read_body_with_max_size", + read_body_with_max_size_30MiB, + ) + def test_download_ratelimit_default(self) -> None: + """ + Test remote media download ratelimiting against default configuration - 500MB bucket + and 87kb/second drain rate + """ + + # mock out actually sending the request, returns a 30MiB response + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = 31457280 + resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + # first request should go through + channel = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz", + shorthand=False, + ) + assert channel.code == 200 + + # next 15 should go through + for i in range(15): + channel2 = self.make_request( + "GET", + f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}", + shorthand=False, + ) + assert channel2.code == 200 + + # 17th will hit ratelimit + channel3 = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx", + shorthand=False, + ) + assert channel3.code == 429 + + # however, a request from a different IP will go through + channel4 = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz", + shorthand=False, + client_ip="187.233.230.159", + ) + assert channel4.code == 200 + + # at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another + # 30MiB download is authorized - The last download was blocked at 503,316,480. 
+ # The next download will be authorized when bucket hits 492,830,720 + # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760 + # needs to drain before another download will be authorized, that will take ~= + # 2 minutes (10,485,760/89,088/60) + self.reactor.pump([2.0 * 60.0]) + + # enough has drained and next request goes through + channel5 = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb", + shorthand=False, + ) + assert channel5.code == 200 + + @override_config( + { + "remote_media_download_per_second": "50M", + "remote_media_download_burst_count": "50M", + } + ) + @patch( + "synapse.http.matrixfederationclient.read_body_with_max_size", + read_body_with_max_size_50MiB, + ) + def test_download_rate_limit_config(self) -> None: + """ + Test that download rate limit config options are correctly picked up and applied + """ + + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = 52428800 + resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + # first request should go through + channel = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz", + shorthand=False, + ) + assert channel.code == 200 + + # immediate second request should fail + channel = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1", + shorthand=False, + ) + assert channel.code == 429 + + # advance half a second + self.reactor.pump([0.5]) + + # request still fails + channel = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2", + shorthand=False, + ) + assert channel.code == 429 + + # advance another half second + self.reactor.pump([0.5]) + + # enough has drained from bucket and request is successful + channel = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3", + shorthand=False, + ) + assert channel.code == 200 + + @patch( + "synapse.http.matrixfederationclient.read_body_with_max_size", + read_body_with_max_size_30MiB, + ) + def test_download_ratelimit_max_size_sub(self) -> None: + """ + Test that if no content-length is provided, the default max size is applied instead + """ + + # mock out actually sending the request + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = UNKNOWN_LENGTH + resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + # ten requests should go through using the max size (500MB/50MB) + for i in range(10): + channel2 = self.make_request( + "GET", + f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}", + shorthand=False, + ) + assert channel2.code == 200 + + # eleventh will hit ratelimit + channel3 = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx", + shorthand=False, + ) + assert channel3.code == 429
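A quick sanity check of the arithmetic the last two tests rely on. The constants below restate the assumed defaults (`max_upload_size: 50M`, `remote_media_download_burst_count: 500M`) and the overridden 50M/50M settings; they are assumptions written out for illustration, not values taken from the test output.

```python
MIB = 1024 * 1024

# test_download_rate_limit_config: 50M burst, 50M per second, 50 MiB responses.
burst = 50 * MIB
drain_per_second = 50 * MIB
response_size = 50 * MIB

# The first download fills the bucket exactly; after half a second only half
# has drained, so the next request is still rejected...
assert response_size <= burst
assert (response_size - 0.5 * drain_per_second) + response_size > burst
# ...but after a full second the bucket is empty again and a request succeeds.
assert (response_size - 1.0 * drain_per_second) + response_size <= burst

# test_download_ratelimit_max_size_sub: with no Content-Length, each request is
# charged max_upload_size (50 MiB by default), so ten requests fit in the
# default 500 MiB bucket and the eleventh is rejected.
default_burst = 500 * MIB
max_upload_size = 50 * MIB
assert 10 * max_upload_size <= default_burst
assert 11 * max_upload_size > default_burst
```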