mirror of
https://github.com/element-hq/synapse
synced 2024-10-01 21:32:40 +00:00
Merge branch 'develop' into madlittlemods/msc3575-sliding-sync-0.0.1
This commit is contained in:
commit
703cdc9c3b
16 changed files with 476 additions and 20 deletions
1
changelog.d/17254.bugfix
Normal file
1
changelog.d/17254.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix searching for users with their exact localpart whose ID includes a hyphen.
|
1
changelog.d/17256.feature
Normal file
1
changelog.d/17256.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve ratelimiting in Synapse (#17256).
|
|
@ -1946,6 +1946,24 @@ Example configuration:
|
||||||
max_image_pixels: 35M
|
max_image_pixels: 35M
|
||||||
```
|
```
|
||||||
---
|
---
|
||||||
|
### `remote_media_download_burst_count`
|
||||||
|
|
||||||
|
Remote media downloads are ratelimited using a [leaky bucket algorithm](https://en.wikipedia.org/wiki/Leaky_bucket), where a given "bucket" is keyed to the IP address of the requester when requesting remote media downloads. This configuration option sets the size of the bucket against which the size in bytes of downloads is penalized - if the bucket is full, i.e. a given number of bytes have already been downloaded, further downloads will be denied until the bucket drains. Defaults to 500MiB. See also `remote_media_download_per_second`, which determines the rate at which the "bucket" is emptied and thus has available space to authorize new requests.
|
||||||
|
|
||||||
|
Example configuration:
|
||||||
|
```yaml
|
||||||
|
remote_media_download_burst_count: 200M
|
||||||
|
```
|
||||||
|
---
|
||||||
|
### `remote_media_download_per_second`
|
||||||
|
|
||||||
|
Works in conjunction with `remote_media_download_burst_count` to ratelimit remote media downloads - this configuration option determines the rate at which the "bucket" (see above) leaks in bytes per second. As requests are made to download remote media, the size of those requests in bytes is added to the bucket, and once the bucket has reached its capacity, no more requests will be allowed until a number of bytes have "drained" from the bucket. This setting determines the rate at which bytes drain from the bucket, with the practical effect that the larger the number, the faster the bucket leaks, allowing for more bytes downloaded over a shorter period of time. Defaults to 87KiB per second. See also `remote_media_download_burst_count`.
|
||||||
|
|
||||||
|
Example configuration:
|
||||||
|
```yaml
|
||||||
|
remote_media_download_per_second: 40K
|
||||||
|
```
|
||||||
|
---
|
||||||
### `prevent_media_downloads_from`
|
### `prevent_media_downloads_from`
|
||||||
|
|
||||||
A list of domains to never download media from. Media from these
|
A list of domains to never download media from. Media from these
|
||||||
|
|
|
@ -218,3 +218,13 @@ class RatelimitConfig(Config):
|
||||||
"rc_media_create",
|
"rc_media_create",
|
||||||
defaults={"per_second": 10, "burst_count": 50},
|
defaults={"per_second": 10, "burst_count": 50},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.remote_media_downloads = RatelimitSettings(
|
||||||
|
key="rc_remote_media_downloads",
|
||||||
|
per_second=self.parse_size(
|
||||||
|
config.get("remote_media_download_per_second", "87K")
|
||||||
|
),
|
||||||
|
burst_count=self.parse_size(
|
||||||
|
config.get("remote_media_download_burst_count", "500M")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
|
@ -56,6 +56,7 @@ from synapse.api.errors import (
|
||||||
SynapseError,
|
SynapseError,
|
||||||
UnsupportedRoomVersionError,
|
UnsupportedRoomVersionError,
|
||||||
)
|
)
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.api.room_versions import (
|
from synapse.api.room_versions import (
|
||||||
KNOWN_ROOM_VERSIONS,
|
KNOWN_ROOM_VERSIONS,
|
||||||
EventFormatVersions,
|
EventFormatVersions,
|
||||||
|
@ -1877,6 +1878,8 @@ class FederationClient(FederationBase):
|
||||||
output_stream: BinaryIO,
|
output_stream: BinaryIO,
|
||||||
max_size: int,
|
max_size: int,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||||
try:
|
try:
|
||||||
return await self.transport_layer.download_media_v3(
|
return await self.transport_layer.download_media_v3(
|
||||||
|
@ -1885,6 +1888,8 @@ class FederationClient(FederationBase):
|
||||||
output_stream=output_stream,
|
output_stream=output_stream,
|
||||||
max_size=max_size,
|
max_size=max_size,
|
||||||
max_timeout_ms=max_timeout_ms,
|
max_timeout_ms=max_timeout_ms,
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
# If an error is received that is due to an unrecognised endpoint,
|
# If an error is received that is due to an unrecognised endpoint,
|
||||||
|
@ -1905,6 +1910,8 @@ class FederationClient(FederationBase):
|
||||||
output_stream=output_stream,
|
output_stream=output_stream,
|
||||||
max_size=max_size,
|
max_size=max_size,
|
||||||
max_timeout_ms=max_timeout_ms,
|
max_timeout_ms=max_timeout_ms,
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,7 @@ import ijson
|
||||||
|
|
||||||
from synapse.api.constants import Direction, Membership
|
from synapse.api.constants import Direction, Membership
|
||||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
from synapse.api.urls import (
|
from synapse.api.urls import (
|
||||||
FEDERATION_UNSTABLE_PREFIX,
|
FEDERATION_UNSTABLE_PREFIX,
|
||||||
|
@ -819,6 +820,8 @@ class TransportLayerClient:
|
||||||
output_stream: BinaryIO,
|
output_stream: BinaryIO,
|
||||||
max_size: int,
|
max_size: int,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||||
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
|
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
|
||||||
|
|
||||||
|
@ -834,6 +837,8 @@ class TransportLayerClient:
|
||||||
"allow_remote": "false",
|
"allow_remote": "false",
|
||||||
"timeout_ms": str(max_timeout_ms),
|
"timeout_ms": str(max_timeout_ms),
|
||||||
},
|
},
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def download_media_v3(
|
async def download_media_v3(
|
||||||
|
@ -843,6 +848,8 @@ class TransportLayerClient:
|
||||||
output_stream: BinaryIO,
|
output_stream: BinaryIO,
|
||||||
max_size: int,
|
max_size: int,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||||
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
|
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
|
||||||
|
|
||||||
|
@ -862,6 +869,8 @@ class TransportLayerClient:
|
||||||
"allow_redirect": "true",
|
"allow_redirect": "true",
|
||||||
},
|
},
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ from twisted.internet.interfaces import IReactorTime
|
||||||
from twisted.internet.task import Cooperator
|
from twisted.internet.task import Cooperator
|
||||||
from twisted.web.client import ResponseFailed
|
from twisted.web.client import ResponseFailed
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web.iweb import IAgent, IBodyProducer, IResponse
|
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
|
||||||
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
import synapse.util.retryutils
|
import synapse.util.retryutils
|
||||||
|
@ -68,6 +68,7 @@ from synapse.api.errors import (
|
||||||
RequestSendFailed,
|
RequestSendFailed,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.crypto.context_factory import FederationPolicyForHTTPS
|
from synapse.crypto.context_factory import FederationPolicyForHTTPS
|
||||||
from synapse.http import QuieterFileBodyProducer
|
from synapse.http import QuieterFileBodyProducer
|
||||||
from synapse.http.client import (
|
from synapse.http.client import (
|
||||||
|
@ -1411,9 +1412,11 @@ class MatrixFederationHttpClient:
|
||||||
destination: str,
|
destination: str,
|
||||||
path: str,
|
path: str,
|
||||||
output_stream: BinaryIO,
|
output_stream: BinaryIO,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
|
max_size: int,
|
||||||
args: Optional[QueryParams] = None,
|
args: Optional[QueryParams] = None,
|
||||||
retry_on_dns_fail: bool = True,
|
retry_on_dns_fail: bool = True,
|
||||||
max_size: Optional[int] = None,
|
|
||||||
ignore_backoff: bool = False,
|
ignore_backoff: bool = False,
|
||||||
follow_redirects: bool = False,
|
follow_redirects: bool = False,
|
||||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||||
|
@ -1422,6 +1425,10 @@ class MatrixFederationHttpClient:
|
||||||
destination: The remote server to send the HTTP request to.
|
destination: The remote server to send the HTTP request to.
|
||||||
path: The HTTP path to GET.
|
path: The HTTP path to GET.
|
||||||
output_stream: File to write the response body to.
|
output_stream: File to write the response body to.
|
||||||
|
download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
|
||||||
|
requester IP
|
||||||
|
ip_address: IP address of the requester
|
||||||
|
max_size: maximum allowable size in bytes of the file
|
||||||
args: Optional dictionary used to create the query string.
|
args: Optional dictionary used to create the query string.
|
||||||
ignore_backoff: true to ignore the historical backoff data
|
ignore_backoff: true to ignore the historical backoff data
|
||||||
and try the request anyway.
|
and try the request anyway.
|
||||||
|
@ -1441,11 +1448,27 @@ class MatrixFederationHttpClient:
|
||||||
federation whitelist
|
federation whitelist
|
||||||
RequestSendFailed: If there were problems connecting to the
|
RequestSendFailed: If there were problems connecting to the
|
||||||
remote, due to e.g. DNS failures, connection timeouts etc.
|
remote, due to e.g. DNS failures, connection timeouts etc.
|
||||||
|
SynapseError: If the requested file exceeds ratelimits
|
||||||
"""
|
"""
|
||||||
request = MatrixFederationRequest(
|
request = MatrixFederationRequest(
|
||||||
method="GET", destination=destination, path=path, query=args
|
method="GET", destination=destination, path=path, query=args
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# check for a minimum balance of 1MiB in ratelimiter before initiating request
|
||||||
|
send_req, _ = await download_ratelimiter.can_do_action(
|
||||||
|
requester=None, key=ip_address, n_actions=1048576, update=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not send_req:
|
||||||
|
msg = "Requested file size exceeds ratelimits"
|
||||||
|
logger.warning(
|
||||||
|
"{%s} [%s] %s",
|
||||||
|
request.txn_id,
|
||||||
|
request.destination,
|
||||||
|
msg,
|
||||||
|
)
|
||||||
|
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
|
||||||
|
|
||||||
response = await self._send_request(
|
response = await self._send_request(
|
||||||
request,
|
request,
|
||||||
retry_on_dns_fail=retry_on_dns_fail,
|
retry_on_dns_fail=retry_on_dns_fail,
|
||||||
|
@ -1455,12 +1478,36 @@ class MatrixFederationHttpClient:
|
||||||
|
|
||||||
headers = dict(response.headers.getAllRawHeaders())
|
headers = dict(response.headers.getAllRawHeaders())
|
||||||
|
|
||||||
|
expected_size = response.length
|
||||||
|
# if we don't get an expected length then use the max length
|
||||||
|
if expected_size == UNKNOWN_LENGTH:
|
||||||
|
expected_size = max_size
|
||||||
|
logger.debug(
|
||||||
|
f"File size unknown, assuming file is max allowable size: {max_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
read_body, _ = await download_ratelimiter.can_do_action(
|
||||||
|
requester=None,
|
||||||
|
key=ip_address,
|
||||||
|
n_actions=expected_size,
|
||||||
|
)
|
||||||
|
if not read_body:
|
||||||
|
msg = "Requested file size exceeds ratelimits"
|
||||||
|
logger.warning(
|
||||||
|
"{%s} [%s] %s",
|
||||||
|
request.txn_id,
|
||||||
|
request.destination,
|
||||||
|
msg,
|
||||||
|
)
|
||||||
|
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
d = read_body_with_max_size(response, output_stream, max_size)
|
# add a byte of headroom to max size as function errs at >=
|
||||||
|
d = read_body_with_max_size(response, output_stream, expected_size + 1)
|
||||||
d.addTimeout(self.default_timeout_seconds, self.reactor)
|
d.addTimeout(self.default_timeout_seconds, self.reactor)
|
||||||
length = await make_deferred_yieldable(d)
|
length = await make_deferred_yieldable(d)
|
||||||
except BodyExceededMaxSize:
|
except BodyExceededMaxSize:
|
||||||
msg = "Requested file is too large > %r bytes" % (max_size,)
|
msg = "Requested file is too large > %r bytes" % (expected_size,)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"{%s} [%s] %s",
|
"{%s} [%s] %s",
|
||||||
request.txn_id,
|
request.txn_id,
|
||||||
|
|
|
@ -42,6 +42,7 @@ from synapse.api.errors import (
|
||||||
SynapseError,
|
SynapseError,
|
||||||
cs_error,
|
cs_error,
|
||||||
)
|
)
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.config.repository import ThumbnailRequirement
|
from synapse.config.repository import ThumbnailRequirement
|
||||||
from synapse.http.server import respond_with_json
|
from synapse.http.server import respond_with_json
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
|
@ -111,6 +112,12 @@ class MediaRepository:
|
||||||
)
|
)
|
||||||
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
|
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
|
||||||
|
|
||||||
|
self.download_ratelimiter = Ratelimiter(
|
||||||
|
store=hs.get_storage_controllers().main,
|
||||||
|
clock=hs.get_clock(),
|
||||||
|
cfg=hs.config.ratelimiting.remote_media_downloads,
|
||||||
|
)
|
||||||
|
|
||||||
# List of StorageProviders where we should search for media and
|
# List of StorageProviders where we should search for media and
|
||||||
# potentially upload to.
|
# potentially upload to.
|
||||||
storage_providers = []
|
storage_providers = []
|
||||||
|
@ -464,6 +471,7 @@ class MediaRepository:
|
||||||
media_id: str,
|
media_id: str,
|
||||||
name: Optional[str],
|
name: Optional[str],
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
ip_address: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Respond to requests for remote media.
|
"""Respond to requests for remote media.
|
||||||
|
|
||||||
|
@ -475,6 +483,7 @@ class MediaRepository:
|
||||||
the filename in the Content-Disposition header of the response.
|
the filename in the Content-Disposition header of the response.
|
||||||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
media to be uploaded.
|
media to be uploaded.
|
||||||
|
ip_address: the IP address of the requester
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Resolves once a response has successfully been written to request
|
Resolves once a response has successfully been written to request
|
||||||
|
@ -500,7 +509,11 @@ class MediaRepository:
|
||||||
key = (server_name, media_id)
|
key = (server_name, media_id)
|
||||||
async with self.remote_media_linearizer.queue(key):
|
async with self.remote_media_linearizer.queue(key):
|
||||||
responder, media_info = await self._get_remote_media_impl(
|
responder, media_info = await self._get_remote_media_impl(
|
||||||
server_name, media_id, max_timeout_ms
|
server_name,
|
||||||
|
media_id,
|
||||||
|
max_timeout_ms,
|
||||||
|
self.download_ratelimiter,
|
||||||
|
ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We deliberately stream the file outside the lock
|
# We deliberately stream the file outside the lock
|
||||||
|
@ -517,7 +530,7 @@ class MediaRepository:
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
|
|
||||||
async def get_remote_media_info(
|
async def get_remote_media_info(
|
||||||
self, server_name: str, media_id: str, max_timeout_ms: int
|
self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str
|
||||||
) -> RemoteMedia:
|
) -> RemoteMedia:
|
||||||
"""Gets the media info associated with the remote file, downloading
|
"""Gets the media info associated with the remote file, downloading
|
||||||
if necessary.
|
if necessary.
|
||||||
|
@ -527,6 +540,7 @@ class MediaRepository:
|
||||||
media_id: The media ID of the content (as defined by the remote server).
|
media_id: The media ID of the content (as defined by the remote server).
|
||||||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
media to be uploaded.
|
media to be uploaded.
|
||||||
|
ip_address: IP address of the requester
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The media info of the file
|
The media info of the file
|
||||||
|
@ -542,7 +556,11 @@ class MediaRepository:
|
||||||
key = (server_name, media_id)
|
key = (server_name, media_id)
|
||||||
async with self.remote_media_linearizer.queue(key):
|
async with self.remote_media_linearizer.queue(key):
|
||||||
responder, media_info = await self._get_remote_media_impl(
|
responder, media_info = await self._get_remote_media_impl(
|
||||||
server_name, media_id, max_timeout_ms
|
server_name,
|
||||||
|
media_id,
|
||||||
|
max_timeout_ms,
|
||||||
|
self.download_ratelimiter,
|
||||||
|
ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure we actually use the responder so that it releases resources
|
# Ensure we actually use the responder so that it releases resources
|
||||||
|
@ -553,7 +571,12 @@ class MediaRepository:
|
||||||
return media_info
|
return media_info
|
||||||
|
|
||||||
async def _get_remote_media_impl(
|
async def _get_remote_media_impl(
|
||||||
self, server_name: str, media_id: str, max_timeout_ms: int
|
self,
|
||||||
|
server_name: str,
|
||||||
|
media_id: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
) -> Tuple[Optional[Responder], RemoteMedia]:
|
) -> Tuple[Optional[Responder], RemoteMedia]:
|
||||||
"""Looks for media in local cache, if not there then attempt to
|
"""Looks for media in local cache, if not there then attempt to
|
||||||
download from remote server.
|
download from remote server.
|
||||||
|
@ -564,6 +587,9 @@ class MediaRepository:
|
||||||
remote server).
|
remote server).
|
||||||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
media to be uploaded.
|
media to be uploaded.
|
||||||
|
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
|
||||||
|
requester IP.
|
||||||
|
ip_address: the IP address of the requester
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of responder and the media info of the file.
|
A tuple of responder and the media info of the file.
|
||||||
|
@ -596,7 +622,7 @@ class MediaRepository:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
media_info = await self._download_remote_file(
|
media_info = await self._download_remote_file(
|
||||||
server_name, media_id, max_timeout_ms
|
server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
|
||||||
)
|
)
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
raise
|
raise
|
||||||
|
@ -630,6 +656,8 @@ class MediaRepository:
|
||||||
server_name: str,
|
server_name: str,
|
||||||
media_id: str,
|
media_id: str,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
) -> RemoteMedia:
|
) -> RemoteMedia:
|
||||||
"""Attempt to download the remote file from the given server name,
|
"""Attempt to download the remote file from the given server name,
|
||||||
using the given file_id as the local id.
|
using the given file_id as the local id.
|
||||||
|
@ -641,6 +669,9 @@ class MediaRepository:
|
||||||
locally generated.
|
locally generated.
|
||||||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
media to be uploaded.
|
media to be uploaded.
|
||||||
|
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
|
||||||
|
requester IP
|
||||||
|
ip_address: the IP address of the requester
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The media info of the file.
|
The media info of the file.
|
||||||
|
@ -658,6 +689,8 @@ class MediaRepository:
|
||||||
output_stream=f,
|
output_stream=f,
|
||||||
max_size=self.max_upload_size,
|
max_size=self.max_upload_size,
|
||||||
max_timeout_ms=max_timeout_ms,
|
max_timeout_ms=max_timeout_ms,
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
except RequestSendFailed as e:
|
except RequestSendFailed as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
@ -359,9 +359,10 @@ class ThumbnailProvider:
|
||||||
desired_method: str,
|
desired_method: str,
|
||||||
desired_type: str,
|
desired_type: str,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
ip_address: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
media_info = await self.media_repo.get_remote_media_info(
|
media_info = await self.media_repo.get_remote_media_info(
|
||||||
server_name, media_id, max_timeout_ms
|
server_name, media_id, max_timeout_ms, ip_address
|
||||||
)
|
)
|
||||||
if not media_info:
|
if not media_info:
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
|
@ -422,12 +423,13 @@ class ThumbnailProvider:
|
||||||
method: str,
|
method: str,
|
||||||
m_type: str,
|
m_type: str,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
ip_address: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: Don't download the whole remote file
|
# TODO: Don't download the whole remote file
|
||||||
# We should proxy the thumbnail from the remote server instead of
|
# We should proxy the thumbnail from the remote server instead of
|
||||||
# downloading the remote file and generating our own thumbnails.
|
# downloading the remote file and generating our own thumbnails.
|
||||||
media_info = await self.media_repo.get_remote_media_info(
|
media_info = await self.media_repo.get_remote_media_info(
|
||||||
server_name, media_id, max_timeout_ms
|
server_name, media_id, max_timeout_ms, ip_address
|
||||||
)
|
)
|
||||||
if not media_info:
|
if not media_info:
|
||||||
return
|
return
|
||||||
|
|
|
@ -174,6 +174,7 @@ class UnstableThumbnailResource(RestServlet):
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
ip_address = request.getClientAddress().host
|
||||||
remote_resp_function = (
|
remote_resp_function = (
|
||||||
self.thumbnailer.select_or_generate_remote_thumbnail
|
self.thumbnailer.select_or_generate_remote_thumbnail
|
||||||
if self.dynamic_thumbnails
|
if self.dynamic_thumbnails
|
||||||
|
@ -188,6 +189,7 @@ class UnstableThumbnailResource(RestServlet):
|
||||||
method,
|
method,
|
||||||
m_type,
|
m_type,
|
||||||
max_timeout_ms,
|
max_timeout_ms,
|
||||||
|
ip_address,
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
||||||
|
|
|
@ -97,6 +97,12 @@ class DownloadResource(RestServlet):
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
ip_address = request.getClientAddress().host
|
||||||
await self.media_repo.get_remote_media(
|
await self.media_repo.get_remote_media(
|
||||||
request, server_name, media_id, file_name, max_timeout_ms
|
request,
|
||||||
|
server_name,
|
||||||
|
media_id,
|
||||||
|
file_name,
|
||||||
|
max_timeout_ms,
|
||||||
|
ip_address,
|
||||||
)
|
)
|
||||||
|
|
|
@ -104,6 +104,7 @@ class ThumbnailResource(RestServlet):
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
ip_address = request.getClientAddress().host
|
||||||
remote_resp_function = (
|
remote_resp_function = (
|
||||||
self.thumbnail_provider.select_or_generate_remote_thumbnail
|
self.thumbnail_provider.select_or_generate_remote_thumbnail
|
||||||
if self.dynamic_thumbnails
|
if self.dynamic_thumbnails
|
||||||
|
@ -118,5 +119,6 @@ class ThumbnailResource(RestServlet):
|
||||||
method,
|
method,
|
||||||
m_type,
|
m_type,
|
||||||
max_timeout_ms,
|
max_timeout_ms,
|
||||||
|
ip_address,
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
|
@ -1281,7 +1281,7 @@ def _parse_words_with_regex(search_term: str) -> List[str]:
|
||||||
Break down search term into words, when we don't have ICU available.
|
Break down search term into words, when we don't have ICU available.
|
||||||
See: `_parse_words`
|
See: `_parse_words`
|
||||||
"""
|
"""
|
||||||
return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
|
return re.findall(r"([\w-]+)", search_term, re.UNICODE)
|
||||||
|
|
||||||
|
|
||||||
def _parse_words_with_icu(search_term: str) -> List[str]:
|
def _parse_words_with_icu(search_term: str) -> List[str]:
|
||||||
|
@ -1303,15 +1303,69 @@ def _parse_words_with_icu(search_term: str) -> List[str]:
|
||||||
if j < 0:
|
if j < 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
result = search_term[i:j]
|
# We want to make sure that we split on `@` and `:` specifically, as
|
||||||
|
# they occur in user IDs.
|
||||||
|
for result in re.split(r"[@:]+", search_term[i:j]):
|
||||||
|
results.append(result.strip())
|
||||||
|
|
||||||
|
i = j
|
||||||
|
|
||||||
|
# libicu will break up words that have punctuation in them, but to handle
|
||||||
|
# cases where user IDs have '-', '.' and '_' in them we want to *not* break
|
||||||
|
# those into words and instead allow the DB to tokenise them how it wants.
|
||||||
|
#
|
||||||
|
# In particular, user-71 in postgres gets tokenised to "user, -71", and this
|
||||||
|
# will not match a query for "user, 71".
|
||||||
|
new_results: List[str] = []
|
||||||
|
i = 0
|
||||||
|
while i < len(results):
|
||||||
|
curr = results[i]
|
||||||
|
|
||||||
|
prev = None
|
||||||
|
next = None
|
||||||
|
if i > 0:
|
||||||
|
prev = results[i - 1]
|
||||||
|
if i + 1 < len(results):
|
||||||
|
next = results[i + 1]
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
# libicu considers spaces and punctuation between words as words, but we don't
|
# libicu considers spaces and punctuation between words as words, but we don't
|
||||||
# want to include those in results as they would result in syntax errors in SQL
|
# want to include those in results as they would result in syntax errors in SQL
|
||||||
# queries (e.g. "foo bar" would result in the search query including "foo & &
|
# queries (e.g. "foo bar" would result in the search query including "foo & &
|
||||||
# bar").
|
# bar").
|
||||||
if len(re.findall(r"([\w\-]+)", result, re.UNICODE)):
|
if not curr:
|
||||||
results.append(result)
|
continue
|
||||||
|
|
||||||
i = j
|
if curr in ["-", ".", "_"]:
|
||||||
|
prefix = ""
|
||||||
|
suffix = ""
|
||||||
|
|
||||||
return results
|
# Check if the next item is a word, and if so use it as the suffix.
|
||||||
|
# We check whether it's a word as we don't want to concatenate
|
||||||
|
# multiple punctuation marks.
|
||||||
|
if next is not None and re.match(r"\w", next):
|
||||||
|
suffix = next
|
||||||
|
i += 1 # We're using next, so we skip it in the outer loop.
|
||||||
|
else:
|
||||||
|
# We want to avoid creating terms like "user-", as we should
|
||||||
|
# strip trailing punctuation.
|
||||||
|
continue
|
||||||
|
|
||||||
|
if prev and re.match(r"\w", prev) and new_results:
|
||||||
|
prefix = new_results[-1]
|
||||||
|
new_results.pop()
|
||||||
|
|
||||||
|
# We might not have a prefix here, but that's fine as we want to
|
||||||
|
# ensure that we don't strip preceding punctuation e.g. '-71'
|
||||||
|
# shouldn't be converted to '71'.
|
||||||
|
|
||||||
|
new_results.append(f"{prefix}{curr}{suffix}")
|
||||||
|
continue
|
||||||
|
elif not re.match(r"\w", curr):
|
||||||
|
# Ignore other punctuation
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_results.append(curr)
|
||||||
|
|
||||||
|
return new_results
|
||||||
|
|
|
@ -1061,6 +1061,45 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
{alice: ProfileInfo(display_name=None, avatar_url=MXC_DUMMY)},
|
{alice: ProfileInfo(display_name=None, avatar_url=MXC_DUMMY)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_search_punctuation(self) -> None:
|
||||||
|
"""Test that you can search for a user that includes punctuation"""
|
||||||
|
|
||||||
|
searching_user = self.register_user("searcher", "password")
|
||||||
|
searching_user_tok = self.login("searcher", "password")
|
||||||
|
|
||||||
|
room_id = self.helper.create_room_as(
|
||||||
|
searching_user,
|
||||||
|
room_version=RoomVersions.V1.identifier,
|
||||||
|
tok=searching_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We want to test searching for users of the form e.g. "user-1", with
|
||||||
|
# various punctuation. We also test both where the prefix is numeric and
|
||||||
|
# alphanumeric, as e.g. postgres tokenises "user-1" as "user" and "-1".
|
||||||
|
i = 1
|
||||||
|
for char in ["-", ".", "_"]:
|
||||||
|
for use_numeric in [False, True]:
|
||||||
|
if use_numeric:
|
||||||
|
prefix1 = f"{i}"
|
||||||
|
prefix2 = f"{i+1}"
|
||||||
|
else:
|
||||||
|
prefix1 = f"a{i}"
|
||||||
|
prefix2 = f"a{i+1}"
|
||||||
|
|
||||||
|
local_user_1 = self.register_user(f"user{char}{prefix1}", "password")
|
||||||
|
local_user_2 = self.register_user(f"user{char}{prefix2}", "password")
|
||||||
|
|
||||||
|
self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
|
||||||
|
self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
|
||||||
|
|
||||||
|
results = self.get_success(
|
||||||
|
self.handler.search_users(searching_user, local_user_1, 20)
|
||||||
|
)["results"]
|
||||||
|
received_user_id_ordering = [result["user_id"] for result in results]
|
||||||
|
self.assertSequenceEqual(received_user_id_ordering[:1], [local_user_1])
|
||||||
|
|
||||||
|
i += 2
|
||||||
|
|
||||||
|
|
||||||
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
|
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
|
|
|
@ -25,7 +25,7 @@ import tempfile
|
||||||
from binascii import unhexlify
|
from binascii import unhexlify
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
|
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -37,9 +37,12 @@ from twisted.internet import defer
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
from twisted.web.http_headers import Headers
|
||||||
|
from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from synapse.api.errors import Codes, HttpResponseException
|
from synapse.api.errors import Codes, HttpResponseException
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.http.types import QueryParams
|
from synapse.http.types import QueryParams
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
|
@ -59,6 +62,7 @@ from synapse.util import Clock
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import FakeChannel
|
from tests.server import FakeChannel
|
||||||
from tests.test_utils import SMALL_PNG
|
from tests.test_utils import SMALL_PNG
|
||||||
|
from tests.unittest import override_config
|
||||||
from tests.utils import default_config
|
from tests.utils import default_config
|
||||||
|
|
||||||
|
|
||||||
|
@ -251,9 +255,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
destination: str,
|
destination: str,
|
||||||
path: str,
|
path: str,
|
||||||
output_stream: BinaryIO,
|
output_stream: BinaryIO,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: Any,
|
||||||
|
max_size: int,
|
||||||
args: Optional[QueryParams] = None,
|
args: Optional[QueryParams] = None,
|
||||||
retry_on_dns_fail: bool = True,
|
retry_on_dns_fail: bool = True,
|
||||||
max_size: Optional[int] = None,
|
|
||||||
ignore_backoff: bool = False,
|
ignore_backoff: bool = False,
|
||||||
follow_redirects: bool = False,
|
follow_redirects: bool = False,
|
||||||
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
|
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
|
||||||
|
@ -878,3 +884,218 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
|
||||||
tok=self.tok,
|
tok=self.tok,
|
||||||
expect_code=400,
|
expect_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
    """Tests for the ratelimiting of remote media downloads.

    Each test stubs out the federation client so that no real request is
    sent and no body is read: `_send_request` returns a canned response with
    a chosen `length`, and `read_body_with_max_size` is patched to report a
    fixed number of bytes read.
    """

    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
        """Build a homeserver whose media store lives in temporary directories."""
        config = self.default_config()

        self.storage_path = self.mktemp()
        self.media_store_path = self.mktemp()
        os.mkdir(self.storage_path)
        os.mkdir(self.media_store_path)
        config["media_store_path"] = self.media_store_path

        provider_config = {
            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
            "store_local": True,
            "store_synchronous": False,
            "store_remote": True,
            "config": {"directory": self.storage_path},
        }

        config["media_storage_providers"] = [provider_config]

        return self.setup_test_homeserver(config=config)

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        """Keep handles on the pieces of the homeserver the tests poke at."""
        self.repo = hs.get_media_repository()
        self.client = hs.get_federation_http_client()
        self.store = hs.get_datastores().main

    def create_resource_dict(self) -> Dict[str, Resource]:
        # We need to manually set the resource tree to include media, the
        # default only does `/_matrix/client` APIs.
        return {"/_matrix/media": self.hs.get_media_repository_resource()}

    # Stand-ins for `read_body_with_max_size`, so that no body is actually
    # read. They are deliberately plain functions (no `self`): they are handed
    # to `@patch` below while the class body is still being evaluated, and are
    # then called in place of the module-level function.
    def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred:
        """Pretend that 30MiB (31,457,280 bytes) were read from the response."""
        return defer.succeed(30 * 1024 * 1024)

    def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred:
        """Pretend that 50MiB (52,428,800 bytes) were read from the response."""
        return defer.succeed(50 * 1024 * 1024)

    def _mock_remote_response(self, length: Any) -> None:
        """Stub out sending the federation request.

        After this call every request made through the federation client gets
        a 200 response with an `application/octet-stream` content type.

        Args:
            length: the value to expose as the response's `length`, i.e. the
                advertised body size in bytes, or `UNKNOWN_LENGTH` when the
                remote does not advertise one.
        """

        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
            resp = MagicMock(spec=IResponse)
            resp.code = 200
            resp.length = length
            resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
            resp.phrase = b"OK"
            return resp

        self.client._send_request = _send_request  # type: ignore

    def _download(self, media_id: str, **kwargs: Any) -> FakeChannel:
        """Request `media_id` from `remote.org` via the v3 download endpoint.

        Args:
            media_id: the media ID to put in the request path.
            **kwargs: passed through to `make_request` (e.g. `client_ip`).

        Returns:
            The channel holding the response.
        """
        return self.make_request(
            "GET",
            f"/_matrix/media/v3/download/remote.org/{media_id}",
            shorthand=False,
            **kwargs,
        )

    @patch(
        "synapse.http.matrixfederationclient.read_body_with_max_size",
        read_body_with_max_size_30MiB,
    )
    def test_download_ratelimit_default(self) -> None:
        """
        Test remote media download ratelimiting against default configuration - 500MiB bucket
        and 87KiB/second drain rate
        """

        # mock out actually sending the request, returns a 30MiB response
        self._mock_remote_response(30 * 1024 * 1024)

        # first request should go through
        channel = self._download("abcdefghijklmnopqrstuvwxyz")
        self.assertEqual(channel.code, 200)

        # next 15 should go through
        for i in range(15):
            channel2 = self._download(f"abcdefghijklmnopqrstuvwxy{i}")
            self.assertEqual(channel2.code, 200)

        # 17th will hit ratelimit
        channel3 = self._download("abcdefghijklmnopqrstuvwxyx")
        self.assertEqual(channel3.code, 429)

        # however, a request from a different IP will go through
        # NOTE(review): this reuses the media ID of the very first request; if
        # that media ends up being served from the local cache, this request
        # may not exercise the ratelimiter at all - confirm, and consider using
        # a fresh media ID here.
        channel4 = self._download(
            "abcdefghijklmnopqrstuvwxyz", client_ip="187.233.230.159"
        )
        self.assertEqual(channel4.code, 200)

        # At 87KiB/s (89,088 bytes/s) it should take about 2 minutes for enough
        # to drain from the bucket that another 30MiB download is authorized.
        # The last download was blocked with the bucket at 503,316,480 bytes
        # (16 * 31,457,280). The next download will be authorized when the
        # bucket hits 492,830,720 (524,288,000 total capacity - 31,457,280
        # download size), so 503,316,480 - 492,830,720 = 10,485,760 bytes need
        # to drain first, which takes ~= 2 minutes (10,485,760/89,088/60).
        self.reactor.pump([2.0 * 60.0])

        # enough has drained and next request goes through
        channel5 = self._download("abcdefghijklmnopqrstuvwxyb")
        self.assertEqual(channel5.code, 200)

    @override_config(
        {
            "remote_media_download_per_second": "50M",
            "remote_media_download_burst_count": "50M",
        }
    )
    @patch(
        "synapse.http.matrixfederationclient.read_body_with_max_size",
        read_body_with_max_size_50MiB,
    )
    def test_download_rate_limit_config(self) -> None:
        """
        Test that download rate limit config options are correctly picked up and applied
        """

        # mock out actually sending the request, returns a 50MiB response
        self._mock_remote_response(50 * 1024 * 1024)

        # first request should go through
        channel = self._download("abcdefghijklmnopqrstuvwxyz")
        self.assertEqual(channel.code, 200)

        # immediate second request should fail
        channel = self._download("abcdefghijklmnopqrstuvwxy1")
        self.assertEqual(channel.code, 429)

        # advance half a second
        self.reactor.pump([0.5])

        # request still fails
        channel = self._download("abcdefghijklmnopqrstuvwxy2")
        self.assertEqual(channel.code, 429)

        # advance another half second
        self.reactor.pump([0.5])

        # enough has drained from bucket and request is successful
        channel = self._download("abcdefghijklmnopqrstuvwxy3")
        self.assertEqual(channel.code, 200)

    @patch(
        "synapse.http.matrixfederationclient.read_body_with_max_size",
        read_body_with_max_size_30MiB,
    )
    def test_download_ratelimit_max_size_sub(self) -> None:
        """
        Test that if no content-length is provided, the default max size is applied instead
        """

        # mock out actually sending the request; the response does not
        # advertise its length
        self._mock_remote_response(UNKNOWN_LENGTH)

        # ten requests should go through using the max size
        # (500MiB bucket / 50MiB max size)
        for i in range(10):
            channel2 = self._download(f"abcdefghijklmnopqrstuvwxy{i}")
            self.assertEqual(channel2.code, 200)

        # eleventh will hit ratelimit
        channel3 = self._download("abcdefghijklmnopqrstuvwxyx")
        self.assertEqual(channel3.code, 429)
|
||||||
|
|
|
@ -711,6 +711,10 @@ class UserDirectoryICUTestCase(HomeserverTestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.assertEqual(_parse_words_with_icu("user-1"), ["user-1"])
|
||||||
|
self.assertEqual(_parse_words_with_icu("user-ab"), ["user-ab"])
|
||||||
|
self.assertEqual(_parse_words_with_icu("user.--1"), ["user", "-1"])
|
||||||
|
|
||||||
def test_regex_word_boundary_punctuation(self) -> None:
|
def test_regex_word_boundary_punctuation(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests the behaviour of punctuation with the non-ICU tokeniser
|
Tests the behaviour of punctuation with the non-ICU tokeniser
|
||||||
|
|
Loading…
Reference in a new issue