mirror of
https://github.com/element-hq/synapse
synced 2024-10-05 18:52:45 +00:00
Merge branch 'develop' into madlittlemods/sliding-sync-room-data
This commit is contained in:
commit
126ce1e7ac
48 changed files with 1998 additions and 115 deletions
1
changelog.d/17365.feature
Normal file
1
changelog.d/17365.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Support [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md) by adding _matrix/client/v1/media/download endpoint.
|
1
changelog.d/17386.bugfix
Normal file
1
changelog.d/17386.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix bug where `/sync` requests could get blocked indefinitely after an upgrade from Synapse versions before v1.109.0.
|
1
changelog.d/17389.misc
Normal file
1
changelog.d/17389.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix building debian package for debian sid.
|
1
changelog.d/17390.misc
Normal file
1
changelog.d/17390.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix building debian packages on non-clean checkouts.
|
1
changelog.d/17391.bugfix
Normal file
1
changelog.d/17391.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix bug where `/sync` requests could get blocked indefinitely after an upgrade from Synapse versions before v1.109.0.
|
|
@ -73,6 +73,8 @@ RUN apt-get update -qq -o Acquire::Languages=none \
|
|||
curl \
|
||||
debhelper \
|
||||
devscripts \
|
||||
# Required for building cffi from source.
|
||||
libffi-dev \
|
||||
libsystemd-dev \
|
||||
lsb-release \
|
||||
pkg-config \
|
||||
|
|
|
@ -11,6 +11,9 @@ DIST=$(cut -d ':' -f2 <<< "${distro:?}")
|
|||
cp -aT /synapse/source /synapse/build
|
||||
cd /synapse/build
|
||||
|
||||
# Delete any existing `.so` files to ensure a clean build.
|
||||
rm -f /synapse/build/synapse/*.so
|
||||
|
||||
# if this is a prerelease, set the Section accordingly.
|
||||
#
|
||||
# When the package is later added to the package repo, reprepro will use the
|
||||
|
|
|
@ -117,7 +117,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
|
|||
},
|
||||
"media_repository": {
|
||||
"app": "synapse.app.generic_worker",
|
||||
"listener_resources": ["media"],
|
||||
"listener_resources": ["media", "client"],
|
||||
"endpoint_patterns": [
|
||||
"^/_matrix/media/",
|
||||
"^/_synapse/admin/v1/purge_media_cache$",
|
||||
|
@ -125,6 +125,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
|
|||
"^/_synapse/admin/v1/user/.*/media.*$",
|
||||
"^/_synapse/admin/v1/media/.*$",
|
||||
"^/_synapse/admin/v1/quarantine_media/.*$",
|
||||
"^/_matrix/client/v1/media/.*$",
|
||||
],
|
||||
# The first configured media worker will run the media background jobs
|
||||
"shared_extra_conf": {
|
||||
|
|
|
@ -117,6 +117,19 @@ each upgrade are complete before moving on to the next upgrade, to avoid
|
|||
stacking them up. You can monitor the currently running background updates with
|
||||
[the Admin API](usage/administration/admin_api/background_updates.html#status).
|
||||
|
||||
# Upgrading to v1.111.0
|
||||
|
||||
## New worker endpoints for authenticated client media
|
||||
|
||||
[Media repository workers](./workers.md#synapseappmedia_repository) handling
|
||||
Media APIs can now handle the following endpoint pattern:
|
||||
|
||||
```
|
||||
^/_matrix/client/v1/media/.*$
|
||||
```
|
||||
|
||||
Please update your reverse proxy configuration.
|
||||
|
||||
# Upgrading to v1.106.0
|
||||
|
||||
## Minimum supported Rust version
|
||||
|
|
|
@ -739,6 +739,7 @@ An example for a federation sender instance:
|
|||
Handles the media repository. It can handle all endpoints starting with:
|
||||
|
||||
/_matrix/media/
|
||||
/_matrix/client/v1/media/
|
||||
|
||||
... and the following regular expressions matching media-specific administration APIs:
|
||||
|
||||
|
|
3
mypy.ini
3
mypy.ini
|
@ -96,3 +96,6 @@ ignore_missing_imports = True
|
|||
# https://github.com/twisted/treq/pull/366
|
||||
[mypy-treq.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-multipart.*]
|
||||
ignore_missing_imports = True
|
||||
|
|
18
poetry.lock
generated
18
poetry.lock
generated
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
|
@ -2039,6 +2039,20 @@ files = [
|
|||
[package.dependencies]
|
||||
six = ">=1.5"
|
||||
|
||||
[[package]]
|
||||
name = "python-multipart"
|
||||
version = "0.0.9"
|
||||
description = "A streaming multipart parser for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"},
|
||||
{file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytz"
|
||||
version = "2022.7.1"
|
||||
|
@ -3187,4 +3201,4 @@ user-search = ["pyicu"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.8.0"
|
||||
content-hash = "107c8fb5c67360340854fbdba3c085fc5f9c7be24bcb592596a914eea621faea"
|
||||
content-hash = "e8d5806e10eb69bc06900fde18ea3df38f38490ab6baa73fe4a563dfb6abacba"
|
||||
|
|
|
@ -224,6 +224,8 @@ pydantic = ">=1.7.4, <3"
|
|||
# needed.
|
||||
setuptools_rust = ">=1.3"
|
||||
|
||||
# This is used for parsing multipart responses
|
||||
python-multipart = ">=0.0.9"
|
||||
|
||||
# Optional Dependencies
|
||||
# ---------------------
|
||||
|
|
|
@ -130,7 +130,8 @@ class Ratelimiter:
|
|||
Overrides the value set during instantiation if set.
|
||||
burst_count: How many actions that can be performed before being limited.
|
||||
Overrides the value set during instantiation if set.
|
||||
update: Whether to count this check as performing the action
|
||||
update: Whether to count this check as performing the action. If the action
|
||||
cannot be performed, the user's action count is not incremented at all.
|
||||
n_actions: The number of times the user wants to do this action. If the user
|
||||
cannot do all of the actions, the user's action count is not incremented
|
||||
at all.
|
||||
|
|
|
@ -1871,6 +1871,52 @@ class FederationClient(FederationBase):
|
|||
|
||||
return filtered_statuses, filtered_failures
|
||||
|
||||
async def federation_download_media(
|
||||
self,
|
||||
destination: str,
|
||||
media_id: str,
|
||||
output_stream: BinaryIO,
|
||||
max_size: int,
|
||||
max_timeout_ms: int,
|
||||
download_ratelimiter: Ratelimiter,
|
||||
ip_address: str,
|
||||
) -> Union[
|
||||
Tuple[int, Dict[bytes, List[bytes]], bytes],
|
||||
Tuple[int, Dict[bytes, List[bytes]]],
|
||||
]:
|
||||
try:
|
||||
return await self.transport_layer.federation_download_media(
|
||||
destination,
|
||||
media_id,
|
||||
output_stream=output_stream,
|
||||
max_size=max_size,
|
||||
max_timeout_ms=max_timeout_ms,
|
||||
download_ratelimiter=download_ratelimiter,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
# If an error is received that is due to an unrecognised endpoint,
|
||||
# fallback to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error
|
||||
# and raise.
|
||||
if not is_unknown_endpoint(e):
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
"Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
|
||||
destination,
|
||||
media_id,
|
||||
)
|
||||
|
||||
return await self.transport_layer.download_media_v3(
|
||||
destination,
|
||||
media_id,
|
||||
output_stream=output_stream,
|
||||
max_size=max_size,
|
||||
max_timeout_ms=max_timeout_ms,
|
||||
download_ratelimiter=download_ratelimiter,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
async def download_media(
|
||||
self,
|
||||
destination: str,
|
||||
|
|
|
@ -824,7 +824,6 @@ class TransportLayerClient:
|
|||
ip_address: str,
|
||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
|
||||
|
||||
return await self.client.get_file(
|
||||
destination,
|
||||
path,
|
||||
|
@ -852,7 +851,6 @@ class TransportLayerClient:
|
|||
ip_address: str,
|
||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
|
||||
|
||||
return await self.client.get_file(
|
||||
destination,
|
||||
path,
|
||||
|
@ -873,6 +871,29 @@ class TransportLayerClient:
|
|||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
async def federation_download_media(
|
||||
self,
|
||||
destination: str,
|
||||
media_id: str,
|
||||
output_stream: BinaryIO,
|
||||
max_size: int,
|
||||
max_timeout_ms: int,
|
||||
download_ratelimiter: Ratelimiter,
|
||||
ip_address: str,
|
||||
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
|
||||
path = f"/_matrix/federation/v1/media/download/{media_id}"
|
||||
return await self.client.federation_get_file(
|
||||
destination,
|
||||
path,
|
||||
output_stream=output_stream,
|
||||
max_size=max_size,
|
||||
args={
|
||||
"timeout_ms": str(max_timeout_ms),
|
||||
},
|
||||
download_ratelimiter=download_ratelimiter,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
|
||||
def _create_path(federation_prefix: str, path: str, *args: str) -> str:
|
||||
"""
|
||||
|
|
|
@ -32,8 +32,8 @@ from synapse.federation.transport.server._base import (
|
|||
from synapse.federation.transport.server.federation import (
|
||||
FEDERATION_SERVLET_CLASSES,
|
||||
FederationAccountStatusServlet,
|
||||
FederationMediaDownloadServlet,
|
||||
FederationUnstableClientKeysClaimServlet,
|
||||
FederationUnstableMediaDownloadServlet,
|
||||
)
|
||||
from synapse.http.server import HttpServer, JsonResource
|
||||
from synapse.http.servlet import (
|
||||
|
@ -316,11 +316,8 @@ def register_servlets(
|
|||
):
|
||||
continue
|
||||
|
||||
if servletclass == FederationUnstableMediaDownloadServlet:
|
||||
if (
|
||||
not hs.config.server.enable_media_repo
|
||||
or not hs.config.experimental.msc3916_authenticated_media_enabled
|
||||
):
|
||||
if servletclass == FederationMediaDownloadServlet:
|
||||
if not hs.config.server.enable_media_repo:
|
||||
continue
|
||||
|
||||
servletclass(
|
||||
|
|
|
@ -362,7 +362,7 @@ class BaseFederationServlet:
|
|||
return None
|
||||
if (
|
||||
func.__self__.__class__.__name__ # type: ignore
|
||||
== "FederationUnstableMediaDownloadServlet"
|
||||
== "FederationMediaDownloadServlet"
|
||||
):
|
||||
response = await func(
|
||||
origin, content, request, *args, **kwargs
|
||||
|
@ -374,7 +374,7 @@ class BaseFederationServlet:
|
|||
else:
|
||||
if (
|
||||
func.__self__.__class__.__name__ # type: ignore
|
||||
== "FederationUnstableMediaDownloadServlet"
|
||||
== "FederationMediaDownloadServlet"
|
||||
):
|
||||
response = await func(
|
||||
origin, content, request, *args, **kwargs
|
||||
|
|
|
@ -790,7 +790,7 @@ class FederationAccountStatusServlet(BaseFederationServerServlet):
|
|||
return 200, {"account_statuses": statuses, "failures": failures}
|
||||
|
||||
|
||||
class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
|
||||
class FederationMediaDownloadServlet(BaseFederationServerServlet):
|
||||
"""
|
||||
Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns
|
||||
a multipart/mixed response consisting of a JSON object and the requested media
|
||||
|
@ -798,7 +798,6 @@ class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
|
|||
"""
|
||||
|
||||
PATH = "/media/download/(?P<media_id>[^/]*)"
|
||||
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916"
|
||||
RATELIMIT = True
|
||||
|
||||
def __init__(
|
||||
|
@ -858,5 +857,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
|
|||
FederationV1SendKnockServlet,
|
||||
FederationMakeKnockServlet,
|
||||
FederationAccountStatusServlet,
|
||||
FederationUnstableMediaDownloadServlet,
|
||||
FederationMediaDownloadServlet,
|
||||
)
|
||||
|
|
|
@ -35,6 +35,8 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
import multipart
|
||||
import treq
|
||||
from canonicaljson import encode_canonical_json
|
||||
from netaddr import AddrFormatError, IPAddress, IPSet
|
||||
|
@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
|||
self._maybe_fail()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True)
|
||||
class MultipartResponse:
|
||||
"""
|
||||
A small class to hold parsed values of a multipart response.
|
||||
"""
|
||||
|
||||
json: bytes = b"{}"
|
||||
length: Optional[int] = None
|
||||
content_type: Optional[bytes] = None
|
||||
disposition: Optional[bytes] = None
|
||||
url: Optional[bytes] = None
|
||||
|
||||
|
||||
class _MultipartParserProtocol(protocol.Protocol):
|
||||
"""
|
||||
Protocol to read and parse a MSC3916 multipart/mixed response
|
||||
"""
|
||||
|
||||
transport: Optional[ITCPTransport] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: ByteWriteable,
|
||||
deferred: defer.Deferred,
|
||||
boundary: str,
|
||||
max_length: Optional[int],
|
||||
) -> None:
|
||||
self.stream = stream
|
||||
self.deferred = deferred
|
||||
self.boundary = boundary
|
||||
self.max_length = max_length
|
||||
self.parser = None
|
||||
self.multipart_response = MultipartResponse()
|
||||
self.has_redirect = False
|
||||
self.in_json = False
|
||||
self.json_done = False
|
||||
self.file_length = 0
|
||||
self.total_length = 0
|
||||
self.in_disposition = False
|
||||
self.in_content_type = False
|
||||
|
||||
def dataReceived(self, incoming_data: bytes) -> None:
|
||||
if self.deferred.called:
|
||||
return
|
||||
|
||||
# we don't have a parser yet, instantiate it
|
||||
if not self.parser:
|
||||
|
||||
def on_header_field(data: bytes, start: int, end: int) -> None:
|
||||
if data[start:end] == b"Location":
|
||||
self.has_redirect = True
|
||||
if data[start:end] == b"Content-Disposition":
|
||||
self.in_disposition = True
|
||||
if data[start:end] == b"Content-Type":
|
||||
self.in_content_type = True
|
||||
|
||||
def on_header_value(data: bytes, start: int, end: int) -> None:
|
||||
# the first header should be content-type for application/json
|
||||
if not self.in_json and not self.json_done:
|
||||
assert data[start:end] == b"application/json"
|
||||
self.in_json = True
|
||||
elif self.has_redirect:
|
||||
self.multipart_response.url = data[start:end]
|
||||
elif self.in_content_type:
|
||||
self.multipart_response.content_type = data[start:end]
|
||||
self.in_content_type = False
|
||||
elif self.in_disposition:
|
||||
self.multipart_response.disposition = data[start:end]
|
||||
self.in_disposition = False
|
||||
|
||||
def on_part_data(data: bytes, start: int, end: int) -> None:
|
||||
# we've seen json header but haven't written the json data
|
||||
if self.in_json and not self.json_done:
|
||||
self.multipart_response.json = data[start:end]
|
||||
self.json_done = True
|
||||
# we have a redirect header rather than a file, and have already captured it
|
||||
elif self.has_redirect:
|
||||
return
|
||||
# otherwise we are in the file part
|
||||
else:
|
||||
logger.info("Writing multipart file data to stream")
|
||||
try:
|
||||
self.stream.write(data[start:end])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Exception encountered writing file data to stream: {e}"
|
||||
)
|
||||
self.deferred.errback()
|
||||
self.file_length += end - start
|
||||
|
||||
callbacks = {
|
||||
"on_header_field": on_header_field,
|
||||
"on_header_value": on_header_value,
|
||||
"on_part_data": on_part_data,
|
||||
}
|
||||
self.parser = multipart.MultipartParser(self.boundary, callbacks)
|
||||
|
||||
self.total_length += len(incoming_data)
|
||||
if self.max_length is not None and self.total_length >= self.max_length:
|
||||
self.deferred.errback(BodyExceededMaxSize())
|
||||
# Close the connection (forcefully) since all the data will get
|
||||
# discarded anyway.
|
||||
assert self.transport is not None
|
||||
self.transport.abortConnection()
|
||||
|
||||
try:
|
||||
self.parser.write(incoming_data) # type: ignore[attr-defined]
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception writing to multipart parser: {e}")
|
||||
self.deferred.errback()
|
||||
return
|
||||
|
||||
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||
# If the maximum size was already exceeded, there's nothing to do.
|
||||
if self.deferred.called:
|
||||
return
|
||||
|
||||
if reason.check(ResponseDone):
|
||||
self.multipart_response.length = self.file_length
|
||||
self.deferred.callback(self.multipart_response)
|
||||
else:
|
||||
self.deferred.errback(reason)
|
||||
|
||||
|
||||
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
||||
|
||||
|
@ -1091,6 +1217,32 @@ def read_body_with_max_size(
|
|||
return d
|
||||
|
||||
|
||||
def read_multipart_response(
|
||||
response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
|
||||
) -> "defer.Deferred[MultipartResponse]":
|
||||
"""
|
||||
Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
|
||||
the stream passed in and returning a deferred resolving to a MultipartResponse
|
||||
|
||||
Args:
|
||||
response: The HTTP response to read from.
|
||||
stream: The file-object to write to.
|
||||
boundary: the multipart/mixed boundary string
|
||||
max_length: maximum allowable length of the response
|
||||
"""
|
||||
d: defer.Deferred[MultipartResponse] = defer.Deferred()
|
||||
|
||||
# If the Content-Length header gives a size larger than the maximum allowed
|
||||
# size, do not bother downloading the body.
|
||||
if max_length is not None and response.length != UNKNOWN_LENGTH:
|
||||
if response.length > max_length:
|
||||
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
|
||||
return d
|
||||
|
||||
response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length))
|
||||
return d
|
||||
|
||||
|
||||
def encode_query_args(args: Optional[QueryParams]) -> bytes:
|
||||
"""
|
||||
Encodes a map of query arguments to bytes which can be appended to a URL.
|
||||
|
|
|
@ -75,9 +75,11 @@ from synapse.http.client import (
|
|||
BlocklistingAgentWrapper,
|
||||
BodyExceededMaxSize,
|
||||
ByteWriteable,
|
||||
SimpleHttpClient,
|
||||
_make_scheduler,
|
||||
encode_query_args,
|
||||
read_body_with_max_size,
|
||||
read_multipart_response,
|
||||
)
|
||||
from synapse.http.connectproxyclient import BearerProxyCredentials
|
||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||
|
@ -466,6 +468,13 @@ class MatrixFederationHttpClient:
|
|||
|
||||
self._sleeper = AwakenableSleeper(self.reactor)
|
||||
|
||||
self._simple_http_client = SimpleHttpClient(
|
||||
hs,
|
||||
ip_blocklist=hs.config.server.federation_ip_range_blocklist,
|
||||
ip_allowlist=hs.config.server.federation_ip_range_allowlist,
|
||||
use_proxy=True,
|
||||
)
|
||||
|
||||
def wake_destination(self, destination: str) -> None:
|
||||
"""Called when the remote server may have come back online."""
|
||||
|
||||
|
@ -1553,6 +1562,189 @@ class MatrixFederationHttpClient:
|
|||
)
|
||||
return length, headers
|
||||
|
||||
async def federation_get_file(
|
||||
self,
|
||||
destination: str,
|
||||
path: str,
|
||||
output_stream: BinaryIO,
|
||||
download_ratelimiter: Ratelimiter,
|
||||
ip_address: str,
|
||||
max_size: int,
|
||||
args: Optional[QueryParams] = None,
|
||||
retry_on_dns_fail: bool = True,
|
||||
ignore_backoff: bool = False,
|
||||
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
|
||||
"""GETs a file from a given homeserver over the federation /download endpoint
|
||||
Args:
|
||||
destination: The remote server to send the HTTP request to.
|
||||
path: The HTTP path to GET.
|
||||
output_stream: File to write the response body to.
|
||||
download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
|
||||
requester IP
|
||||
ip_address: IP address of the requester
|
||||
max_size: maximum allowable size in bytes of the file
|
||||
args: Optional dictionary used to create the query string.
|
||||
ignore_backoff: true to ignore the historical backoff data
|
||||
and try the request anyway.
|
||||
|
||||
Returns:
|
||||
Resolves to an (int, dict, bytes) tuple of
|
||||
the file length, a dict of the response headers, and the file json
|
||||
|
||||
Raises:
|
||||
HttpResponseException: If we get an HTTP response code >= 300
|
||||
(except 429).
|
||||
NotRetryingDestination: If we are not yet ready to retry this
|
||||
server.
|
||||
FederationDeniedError: If this destination is not on our
|
||||
federation whitelist
|
||||
RequestSendFailed: If there were problems connecting to the
|
||||
remote, due to e.g. DNS failures, connection timeouts etc.
|
||||
SynapseError: If the requested file exceeds ratelimits or the response from the
|
||||
remote server is not a multipart response
|
||||
AssertionError: if the resolved multipart response's length is None
|
||||
"""
|
||||
request = MatrixFederationRequest(
|
||||
method="GET", destination=destination, path=path, query=args
|
||||
)
|
||||
|
||||
# check for a minimum balance of 1MiB in ratelimiter before initiating request
|
||||
send_req, _ = await download_ratelimiter.can_do_action(
|
||||
requester=None, key=ip_address, n_actions=1048576, update=False
|
||||
)
|
||||
|
||||
if not send_req:
|
||||
msg = "Requested file size exceeds ratelimits"
|
||||
logger.warning(
|
||||
"{%s} [%s] %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
msg,
|
||||
)
|
||||
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
|
||||
|
||||
response = await self._send_request(
|
||||
request,
|
||||
retry_on_dns_fail=retry_on_dns_fail,
|
||||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
headers = dict(response.headers.getAllRawHeaders())
|
||||
|
||||
expected_size = response.length
|
||||
# if we don't get an expected length then use the max length
|
||||
if expected_size == UNKNOWN_LENGTH:
|
||||
expected_size = max_size
|
||||
logger.debug(
|
||||
f"File size unknown, assuming file is max allowable size: {max_size}"
|
||||
)
|
||||
|
||||
read_body, _ = await download_ratelimiter.can_do_action(
|
||||
requester=None,
|
||||
key=ip_address,
|
||||
n_actions=expected_size,
|
||||
)
|
||||
if not read_body:
|
||||
msg = "Requested file size exceeds ratelimits"
|
||||
logger.warning(
|
||||
"{%s} [%s] %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
msg,
|
||||
)
|
||||
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
|
||||
|
||||
# this should be a multipart/mixed response with the boundary string in the header
|
||||
try:
|
||||
raw_content_type = headers.get(b"Content-Type")
|
||||
assert raw_content_type is not None
|
||||
content_type = raw_content_type[0].decode("UTF-8")
|
||||
content_type_parts = content_type.split("boundary=")
|
||||
boundary = content_type_parts[1]
|
||||
except Exception:
|
||||
msg = "Remote response is malformed: expected Content-Type of multipart/mixed with a boundary present."
|
||||
logger.warning(
|
||||
"{%s} [%s] %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
msg,
|
||||
)
|
||||
raise SynapseError(HTTPStatus.BAD_GATEWAY, msg)
|
||||
|
||||
try:
|
||||
# add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >=
|
||||
deferred = read_multipart_response(
|
||||
response, output_stream, boundary, expected_size + 1
|
||||
)
|
||||
deferred.addTimeout(self.default_timeout_seconds, self.reactor)
|
||||
except BodyExceededMaxSize:
|
||||
msg = "Requested file is too large > %r bytes" % (expected_size,)
|
||||
logger.warning(
|
||||
"{%s} [%s] %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
msg,
|
||||
)
|
||||
raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
|
||||
except defer.TimeoutError as e:
|
||||
logger.warning(
|
||||
"{%s} [%s] Timed out reading response - %s %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
request.method,
|
||||
request.uri.decode("ascii"),
|
||||
)
|
||||
raise RequestSendFailed(e, can_retry=True) from e
|
||||
except ResponseFailed as e:
|
||||
logger.warning(
|
||||
"{%s} [%s] Failed to read response - %s %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
request.method,
|
||||
request.uri.decode("ascii"),
|
||||
)
|
||||
raise RequestSendFailed(e, can_retry=True) from e
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"{%s} [%s] Error reading response: %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
multipart_response = await make_deferred_yieldable(deferred)
|
||||
if not multipart_response.url:
|
||||
assert multipart_response.length is not None
|
||||
length = multipart_response.length
|
||||
headers[b"Content-Type"] = [multipart_response.content_type]
|
||||
headers[b"Content-Disposition"] = [multipart_response.disposition]
|
||||
|
||||
# the response contained a redirect url to download the file from
|
||||
else:
|
||||
str_url = multipart_response.url.decode("utf-8")
|
||||
logger.info(
|
||||
"{%s} [%s] File download redirected, now downloading from: %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
str_url,
|
||||
)
|
||||
length, headers, _, _ = await self._simple_http_client.get_file(
|
||||
str_url, output_stream, expected_size
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"{%s} [%s] Completed: %d %s [%d bytes] %s %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
response.code,
|
||||
response.phrase.decode("ascii", errors="replace"),
|
||||
length,
|
||||
request.method,
|
||||
request.uri.decode("ascii"),
|
||||
)
|
||||
return length, headers, multipart_response.json
|
||||
|
||||
|
||||
def _flatten_response_never_received(e: BaseException) -> str:
|
||||
if hasattr(e, "reasons"):
|
||||
|
|
|
@ -221,6 +221,7 @@ def add_file_headers(
|
|||
# select private. don't bother setting Expires as all our
|
||||
# clients are smart enough to be happy with Cache-Control
|
||||
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
|
||||
|
||||
if file_size is not None:
|
||||
request.setHeader(b"Content-Length", b"%d" % (file_size,))
|
||||
|
||||
|
@ -302,12 +303,37 @@ async def respond_with_multipart_responder(
|
|||
)
|
||||
return
|
||||
|
||||
if media_info.media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES:
|
||||
disposition = "inline"
|
||||
else:
|
||||
disposition = "attachment"
|
||||
|
||||
def _quote(x: str) -> str:
|
||||
return urllib.parse.quote(x.encode("utf-8"))
|
||||
|
||||
if media_info.upload_name:
|
||||
if _can_encode_filename_as_token(media_info.upload_name):
|
||||
disposition = "%s; filename=%s" % (
|
||||
disposition,
|
||||
media_info.upload_name,
|
||||
)
|
||||
else:
|
||||
disposition = "%s; filename*=utf-8''%s" % (
|
||||
disposition,
|
||||
_quote(media_info.upload_name),
|
||||
)
|
||||
|
||||
from synapse.media.media_storage import MultipartFileConsumer
|
||||
|
||||
# note that currently the json_object is just {}, this will change when linked media
|
||||
# is implemented
|
||||
multipart_consumer = MultipartFileConsumer(
|
||||
clock, request, media_info.media_type, {}, media_info.media_length
|
||||
clock,
|
||||
request,
|
||||
media_info.media_type,
|
||||
{},
|
||||
disposition,
|
||||
media_info.media_length,
|
||||
)
|
||||
|
||||
logger.debug("Responding to media request with responder %s", responder)
|
||||
|
|
|
@ -480,6 +480,7 @@ class MediaRepository:
|
|||
name: Optional[str],
|
||||
max_timeout_ms: int,
|
||||
ip_address: str,
|
||||
use_federation_endpoint: bool,
|
||||
) -> None:
|
||||
"""Respond to requests for remote media.
|
||||
|
||||
|
@ -492,6 +493,8 @@ class MediaRepository:
|
|||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||
media to be uploaded.
|
||||
ip_address: the IP address of the requester
|
||||
use_federation_endpoint: whether to request the remote media over the new
|
||||
federation `/download` endpoint
|
||||
|
||||
Returns:
|
||||
Resolves once a response has successfully been written to request
|
||||
|
@ -522,6 +525,7 @@ class MediaRepository:
|
|||
max_timeout_ms,
|
||||
self.download_ratelimiter,
|
||||
ip_address,
|
||||
use_federation_endpoint,
|
||||
)
|
||||
|
||||
# We deliberately stream the file outside the lock
|
||||
|
@ -569,6 +573,7 @@ class MediaRepository:
|
|||
max_timeout_ms,
|
||||
self.download_ratelimiter,
|
||||
ip_address,
|
||||
False,
|
||||
)
|
||||
|
||||
# Ensure we actually use the responder so that it releases resources
|
||||
|
@ -585,6 +590,7 @@ class MediaRepository:
|
|||
max_timeout_ms: int,
|
||||
download_ratelimiter: Ratelimiter,
|
||||
ip_address: str,
|
||||
use_federation_endpoint: bool,
|
||||
) -> Tuple[Optional[Responder], RemoteMedia]:
|
||||
"""Looks for media in local cache, if not there then attempt to
|
||||
download from remote server.
|
||||
|
@ -598,6 +604,8 @@ class MediaRepository:
|
|||
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
|
||||
requester IP.
|
||||
ip_address: the IP address of the requester
|
||||
use_federation_endpoint: whether to request the remote media over the new federation
|
||||
/download endpoint
|
||||
|
||||
Returns:
|
||||
A tuple of responder and the media info of the file.
|
||||
|
@ -629,9 +637,23 @@ class MediaRepository:
|
|||
# Failed to find the file anywhere, lets download it.
|
||||
|
||||
try:
|
||||
media_info = await self._download_remote_file(
|
||||
server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
|
||||
)
|
||||
if not use_federation_endpoint:
|
||||
media_info = await self._download_remote_file(
|
||||
server_name,
|
||||
media_id,
|
||||
max_timeout_ms,
|
||||
download_ratelimiter,
|
||||
ip_address,
|
||||
)
|
||||
else:
|
||||
media_info = await self._federation_download_remote_file(
|
||||
server_name,
|
||||
media_id,
|
||||
max_timeout_ms,
|
||||
download_ratelimiter,
|
||||
ip_address,
|
||||
)
|
||||
|
||||
except SynapseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
@ -775,6 +797,129 @@ class MediaRepository:
|
|||
quarantined_by=None,
|
||||
)
|
||||
|
||||
async def _federation_download_remote_file(
|
||||
self,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
max_timeout_ms: int,
|
||||
download_ratelimiter: Ratelimiter,
|
||||
ip_address: str,
|
||||
) -> RemoteMedia:
|
||||
"""Attempt to download the remote file from the given server name.
|
||||
Uses the given file_id as the local id and downloads the file over the federation
|
||||
v1 download endpoint
|
||||
|
||||
Args:
|
||||
server_name: Originating server
|
||||
media_id: The media ID of the content (as defined by the
|
||||
remote server). This is different than the file_id, which is
|
||||
locally generated.
|
||||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||
media to be uploaded.
|
||||
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
|
||||
requester IP
|
||||
ip_address: the IP address of the requester
|
||||
|
||||
Returns:
|
||||
The media info of the file.
|
||||
"""
|
||||
|
||||
file_id = random_string(24)
|
||||
|
||||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||
|
||||
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||
try:
|
||||
res = await self.client.federation_download_media(
|
||||
server_name,
|
||||
media_id,
|
||||
output_stream=f,
|
||||
max_size=self.max_upload_size,
|
||||
max_timeout_ms=max_timeout_ms,
|
||||
download_ratelimiter=download_ratelimiter,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
# if we had to fall back to the _matrix/media endpoint it will only return
|
||||
# the headers and length, check the length of the tuple before unpacking
|
||||
if len(res) == 3:
|
||||
length, headers, json = res
|
||||
else:
|
||||
length, headers = res
|
||||
except RequestSendFailed as e:
|
||||
logger.warning(
|
||||
"Request failed fetching remote media %s/%s: %r",
|
||||
server_name,
|
||||
media_id,
|
||||
e,
|
||||
)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
except HttpResponseException as e:
|
||||
logger.warning(
|
||||
"HTTP error fetching remote media %s/%s: %s",
|
||||
server_name,
|
||||
media_id,
|
||||
e.response,
|
||||
)
|
||||
if e.code == twisted.web.http.NOT_FOUND:
|
||||
raise e.to_synapse_error()
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
except SynapseError:
|
||||
logger.warning(
|
||||
"Failed to fetch remote media %s/%s", server_name, media_id
|
||||
)
|
||||
raise
|
||||
except NotRetryingDestination:
|
||||
logger.warning("Not retrying destination %r", server_name)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to fetch remote media %s/%s", server_name, media_id
|
||||
)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
if b"Content-Type" in headers:
|
||||
media_type = headers[b"Content-Type"][0].decode("ascii")
|
||||
else:
|
||||
media_type = "application/octet-stream"
|
||||
upload_name = get_filename_from_headers(headers)
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
# Multiple remote media download requests can race (when using
|
||||
# multiple media repos), so this may throw a violation constraint
|
||||
# exception. If it does we'll delete the newly downloaded file from
|
||||
# disk (as we're in the ctx manager).
|
||||
#
|
||||
# However: we've already called `finish()` so we may have also
|
||||
# written to the storage providers. This is preferable to the
|
||||
# alternative where we call `finish()` *after* this, where we could
|
||||
# end up having an entry in the DB but fail to write the files to
|
||||
# the storage providers.
|
||||
await self.store.store_cached_remote_media(
|
||||
origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=time_now_ms,
|
||||
upload_name=upload_name,
|
||||
media_length=length,
|
||||
filesystem_id=file_id,
|
||||
)
|
||||
|
||||
logger.debug("Stored remote media in file %r", fname)
|
||||
|
||||
return RemoteMedia(
|
||||
media_origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
media_length=length,
|
||||
upload_name=upload_name,
|
||||
created_ts=time_now_ms,
|
||||
filesystem_id=file_id,
|
||||
last_access_ts=time_now_ms,
|
||||
quarantined_by=None,
|
||||
)
|
||||
|
||||
def _get_thumbnail_requirements(
|
||||
self, media_type: str
|
||||
) -> Tuple[ThumbnailRequirement, ...]:
|
||||
|
|
|
@ -401,13 +401,14 @@ class MultipartFileConsumer:
|
|||
wrapped_consumer: interfaces.IConsumer,
|
||||
file_content_type: str,
|
||||
json_object: JsonDict,
|
||||
content_length: Optional[int] = None,
|
||||
disposition: str,
|
||||
content_length: Optional[int],
|
||||
) -> None:
|
||||
self.clock = clock
|
||||
self.wrapped_consumer = wrapped_consumer
|
||||
self.json_field = json_object
|
||||
self.json_field_written = False
|
||||
self.content_type_written = False
|
||||
self.file_headers_written = False
|
||||
self.file_content_type = file_content_type
|
||||
self.boundary = uuid4().hex.encode("ascii")
|
||||
|
||||
|
@ -420,6 +421,7 @@ class MultipartFileConsumer:
|
|||
self.paused = False
|
||||
|
||||
self.length = content_length
|
||||
self.disposition = disposition
|
||||
|
||||
### IConsumer APIs ###
|
||||
|
||||
|
@ -488,11 +490,13 @@ class MultipartFileConsumer:
|
|||
self.json_field_written = True
|
||||
|
||||
# if we haven't written the content type yet, do so
|
||||
if not self.content_type_written:
|
||||
if not self.file_headers_written:
|
||||
type = self.file_content_type.encode("utf-8")
|
||||
content_type = Header(b"Content-Type", type)
|
||||
self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF)
|
||||
self.content_type_written = True
|
||||
self.wrapped_consumer.write(bytes(content_type) + CRLF)
|
||||
disp_header = Header(b"Content-Disposition", self.disposition)
|
||||
self.wrapped_consumer.write(bytes(disp_header) + CRLF + CRLF)
|
||||
self.file_headers_written = True
|
||||
|
||||
self.wrapped_consumer.write(data)
|
||||
|
||||
|
@ -506,7 +510,6 @@ class MultipartFileConsumer:
|
|||
producing data for good.
|
||||
"""
|
||||
assert self.producer is not None
|
||||
|
||||
self.paused = True
|
||||
self.producer.stopProducing()
|
||||
|
||||
|
@ -518,7 +521,6 @@ class MultipartFileConsumer:
|
|||
the time being, and to stop until C{resumeProducing()} is called.
|
||||
"""
|
||||
assert self.producer is not None
|
||||
|
||||
self.paused = True
|
||||
|
||||
if self.streaming:
|
||||
|
@ -549,7 +551,7 @@ class MultipartFileConsumer:
|
|||
"""
|
||||
if not self.length:
|
||||
return None
|
||||
# calculate length of json field and content-type header
|
||||
# calculate length of json field and content-type, disposition headers
|
||||
json_field = json.dumps(self.json_field)
|
||||
json_bytes = json_field.encode("utf-8")
|
||||
json_length = len(json_bytes)
|
||||
|
@ -558,9 +560,13 @@ class MultipartFileConsumer:
|
|||
content_type = Header(b"Content-Type", type)
|
||||
type_length = len(bytes(content_type))
|
||||
|
||||
# 154 is the length of the elements that aren't variable, ie
|
||||
disp = self.disposition.encode("utf-8")
|
||||
disp_header = Header(b"Content-Disposition", disp)
|
||||
disp_length = len(bytes(disp_header))
|
||||
|
||||
# 156 is the length of the elements that aren't variable, ie
|
||||
# CRLFs and boundary strings, etc
|
||||
self.length += json_length + type_length + 154
|
||||
self.length += json_length + type_length + disp_length + 156
|
||||
|
||||
return self.length
|
||||
|
||||
|
@ -569,7 +575,6 @@ class MultipartFileConsumer:
|
|||
async def _resumeProducingRepeatedly(self) -> None:
|
||||
assert self.producer is not None
|
||||
assert not self.streaming
|
||||
|
||||
producer = cast("interfaces.IPullProducer", self.producer)
|
||||
|
||||
self.paused = False
|
||||
|
|
|
@ -764,6 +764,13 @@ class Notifier:
|
|||
|
||||
async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
|
||||
"""Wait for this worker to catch up with the given stream token."""
|
||||
current_token = self.event_sources.get_current_token()
|
||||
if stream_token.is_before_or_eq(current_token):
|
||||
return True
|
||||
|
||||
# Work around a bug where older Synapse versions gave out tokens "from
|
||||
# the future", i.e. that are ahead of the tokens persisted in the DB.
|
||||
stream_token = await self.event_sources.bound_future_token(stream_token)
|
||||
|
||||
start = self.clock.time_msec()
|
||||
while True:
|
||||
|
|
|
@ -145,6 +145,10 @@ class ClientRestResource(JsonResource):
|
|||
password_policy.register_servlets(hs, client_resource)
|
||||
knock.register_servlets(hs, client_resource)
|
||||
appservice_ping.register_servlets(hs, client_resource)
|
||||
if hs.config.server.enable_media_repo:
|
||||
from synapse.rest.client import media
|
||||
|
||||
media.register_servlets(hs, client_resource)
|
||||
|
||||
# moving to /_synapse/admin
|
||||
if is_main_process:
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from synapse.http.server import (
|
||||
HttpServer,
|
||||
|
@ -194,14 +195,76 @@ class UnstableThumbnailResource(RestServlet):
|
|||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||
|
||||
|
||||
class DownloadResource(RestServlet):
|
||||
PATTERNS = [
|
||||
re.compile(
|
||||
"/_matrix/client/v1/media/download/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)(/(?P<file_name>[^/]*))?$"
|
||||
)
|
||||
]
|
||||
|
||||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
||||
super().__init__()
|
||||
self.media_repo = media_repo
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
file_name: Optional[str] = None,
|
||||
) -> None:
|
||||
# Validate the server name, raising if invalid
|
||||
parse_and_validate_server_name(server_name)
|
||||
|
||||
await self.auth.get_user_by_req(request)
|
||||
|
||||
set_cors_headers(request)
|
||||
set_corp_headers(request)
|
||||
request.setHeader(
|
||||
b"Content-Security-Policy",
|
||||
b"sandbox;"
|
||||
b" default-src 'none';"
|
||||
b" script-src 'none';"
|
||||
b" plugin-types application/pdf;"
|
||||
b" style-src 'unsafe-inline';"
|
||||
b" media-src 'self';"
|
||||
b" object-src 'self';",
|
||||
)
|
||||
# Limited non-standard form of CSP for IE11
|
||||
request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
|
||||
request.setHeader(b"Referrer-Policy", b"no-referrer")
|
||||
max_timeout_ms = parse_integer(
|
||||
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
|
||||
)
|
||||
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
|
||||
|
||||
if self._is_mine_server_name(server_name):
|
||||
await self.media_repo.get_local_media(
|
||||
request, media_id, file_name, max_timeout_ms
|
||||
)
|
||||
else:
|
||||
ip_address = request.getClientAddress().host
|
||||
await self.media_repo.get_remote_media(
|
||||
request,
|
||||
server_name,
|
||||
media_id,
|
||||
file_name,
|
||||
max_timeout_ms,
|
||||
ip_address,
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.experimental.msc3916_authenticated_media_enabled:
|
||||
media_repo = hs.get_media_repository()
|
||||
if hs.config.media.url_preview_enabled:
|
||||
UnstablePreviewURLServlet(
|
||||
hs, media_repo, media_repo.media_storage
|
||||
).register(http_server)
|
||||
UnstableMediaConfigResource(hs).register(http_server)
|
||||
UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
|
||||
media_repo = hs.get_media_repository()
|
||||
if hs.config.media.url_preview_enabled:
|
||||
UnstablePreviewURLServlet(hs, media_repo, media_repo.media_storage).register(
|
||||
http_server
|
||||
)
|
||||
UnstableMediaConfigResource(hs).register(http_server)
|
||||
UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
|
||||
http_server
|
||||
)
|
||||
DownloadResource(hs, media_repo).register(http_server)
|
||||
|
|
|
@ -105,4 +105,5 @@ class DownloadResource(RestServlet):
|
|||
file_name,
|
||||
max_timeout_ms,
|
||||
ip_address,
|
||||
False,
|
||||
)
|
||||
|
|
|
@ -43,10 +43,7 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import JsonDict, JsonMapping
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
@ -71,7 +68,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
|||
self._instance_name in hs.config.worker.writers.account_data
|
||||
)
|
||||
|
||||
self._account_data_id_gen: AbstractStreamIdGenerator
|
||||
self._account_data_id_gen: MultiWriterIdGenerator
|
||||
|
||||
self._account_data_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
|
@ -113,6 +110,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
|||
"""
|
||||
return self._account_data_id_gen.get_current_token()
|
||||
|
||||
def get_account_data_id_generator(self) -> MultiWriterIdGenerator:
|
||||
return self._account_data_id_gen
|
||||
|
||||
@cached()
|
||||
async def get_global_account_data_for_user(
|
||||
self, user_id: str
|
||||
|
|
|
@ -50,10 +50,7 @@ from synapse.storage.database import (
|
|||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
@ -92,7 +89,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
self._instance_name in hs.config.worker.writers.to_device
|
||||
)
|
||||
|
||||
self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
|
||||
self._to_device_msg_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
|
@ -169,6 +166,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
def get_to_device_stream_token(self) -> int:
|
||||
return self._to_device_msg_id_gen.get_current_token()
|
||||
|
||||
def get_to_device_id_generator(self) -> MultiWriterIdGenerator:
|
||||
return self._to_device_msg_id_gen
|
||||
|
||||
async def get_messages_for_user_devices(
|
||||
self,
|
||||
user_ids: Collection[str],
|
||||
|
|
|
@ -243,6 +243,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
def get_device_stream_token(self) -> int:
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
|
||||
def get_device_stream_id_generator(self) -> MultiWriterIdGenerator:
|
||||
return self._device_list_id_gen
|
||||
|
||||
async def count_devices_by_users(
|
||||
self, user_ids: Optional[Collection[str]] = None
|
||||
) -> int:
|
||||
|
|
|
@ -192,8 +192,8 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._stream_id_gen: AbstractStreamIdGenerator
|
||||
self._backfill_id_gen: AbstractStreamIdGenerator
|
||||
self._stream_id_gen: MultiWriterIdGenerator
|
||||
self._backfill_id_gen: MultiWriterIdGenerator
|
||||
|
||||
self._stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
|
|
|
@ -42,10 +42,7 @@ from synapse.storage.database import (
|
|||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.engines._base import IsolationLevel
|
||||
from synapse.storage.types import Connection
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
@ -83,7 +80,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
|||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self._presence_id_gen: AbstractStreamIdGenerator
|
||||
self._presence_id_gen: MultiWriterIdGenerator
|
||||
|
||||
self._can_persist_presence = (
|
||||
self._instance_name in hs.config.worker.writers.presence
|
||||
|
@ -455,6 +452,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
|||
def get_current_presence_token(self) -> int:
|
||||
return self._presence_id_gen.get_current_token()
|
||||
|
||||
def get_presence_stream_id_gen(self) -> MultiWriterIdGenerator:
|
||||
return self._presence_id_gen
|
||||
|
||||
def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
|
||||
"""Fetch non-offline presence from the database so that we can register
|
||||
the appropriate time outs.
|
||||
|
|
|
@ -178,6 +178,9 @@ class PushRulesWorkerStore(
|
|||
"""
|
||||
return self._push_rules_stream_id_gen.get_current_token()
|
||||
|
||||
def get_push_rules_stream_id_gen(self) -> MultiWriterIdGenerator:
|
||||
return self._push_rules_stream_id_gen
|
||||
|
||||
def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||
) -> None:
|
||||
|
|
|
@ -45,10 +45,7 @@ from synapse.storage.database import (
|
|||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.engines._base import IsolationLevel
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
JsonMapping,
|
||||
|
@ -76,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
|
||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||
# class below that is used on the main process.
|
||||
self._receipts_id_gen: AbstractStreamIdGenerator
|
||||
self._receipts_id_gen: MultiWriterIdGenerator
|
||||
|
||||
self._can_write_to_receipts = (
|
||||
self._instance_name in hs.config.worker.writers.receipts
|
||||
|
@ -136,6 +133,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
|
||||
return self._receipts_id_gen.get_current_token_for_writer(instance_name)
|
||||
|
||||
def get_receipts_stream_id_gen(self) -> MultiWriterIdGenerator:
|
||||
return self._receipts_id_gen
|
||||
|
||||
def get_last_unthreaded_receipt_for_user_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
|
|
|
@ -59,11 +59,7 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
IdGenerator,
|
||||
MultiWriterIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
|
||||
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
@ -151,7 +147,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
self.config: HomeServerConfig = hs.config
|
||||
|
||||
self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
|
||||
self._un_partial_stated_rooms_stream_id_gen: MultiWriterIdGenerator
|
||||
|
||||
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
|
@ -1409,6 +1405,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
|||
instance_name
|
||||
)
|
||||
|
||||
def get_un_partial_stated_rooms_id_generator(self) -> MultiWriterIdGenerator:
|
||||
return self._un_partial_stated_rooms_stream_id_gen
|
||||
|
||||
async def get_un_partial_stated_rooms_between(
|
||||
self, last_id: int, current_id: int, room_ids: Collection[str]
|
||||
) -> Set[str]:
|
||||
|
|
|
@ -641,6 +641,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
|
||||
|
||||
def get_events_stream_id_generator(self) -> MultiWriterIdGenerator:
|
||||
return self._stream_id_gen
|
||||
|
||||
async def get_room_events_stream_for_rooms(
|
||||
self,
|
||||
room_ids: Collection[str],
|
||||
|
|
|
@ -812,6 +812,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
pos = self.get_current_token_for_writer(self._instance_name)
|
||||
txn.execute(sql, (self._stream_name, self._instance_name, pos))
|
||||
|
||||
async def get_max_allocated_token(self) -> int:
|
||||
return await self._db.runInteraction(
|
||||
"get_max_allocated_token", self._sequence_gen.get_max_allocated
|
||||
)
|
||||
|
||||
|
||||
@attr.s(frozen=True, auto_attribs=True)
|
||||
class _AsyncCtxManagerWrapper(Generic[T]):
|
||||
|
|
|
@ -88,6 +88,10 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
|
|||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_max_allocated(self, txn: Cursor) -> int:
|
||||
"""Get the maximum ID that we have allocated"""
|
||||
|
||||
|
||||
class PostgresSequenceGenerator(SequenceGenerator):
|
||||
"""An implementation of SequenceGenerator which uses a postgres sequence"""
|
||||
|
@ -190,6 +194,17 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
|||
% {"seq": self._sequence_name, "stream_name": stream_name}
|
||||
)
|
||||
|
||||
def get_max_allocated(self, txn: Cursor) -> int:
|
||||
# We just read from the sequence what the last value we fetched was.
|
||||
txn.execute(f"SELECT last_value, is_called FROM {self._sequence_name}")
|
||||
row = txn.fetchone()
|
||||
assert row is not None
|
||||
|
||||
last_value, is_called = row
|
||||
if not is_called:
|
||||
last_value -= 1
|
||||
return last_value
|
||||
|
||||
|
||||
GetFirstCallbackType = Callable[[Cursor], int]
|
||||
|
||||
|
@ -248,6 +263,15 @@ class LocalSequenceGenerator(SequenceGenerator):
|
|||
# There is nothing to do for in memory sequences
|
||||
pass
|
||||
|
||||
def get_max_allocated(self, txn: Cursor) -> int:
|
||||
with self._lock:
|
||||
if self._current_max_id is None:
|
||||
assert self._callback is not None
|
||||
self._current_max_id = self._callback(txn)
|
||||
self._callback = None
|
||||
|
||||
return self._current_max_id
|
||||
|
||||
|
||||
def build_sequence_generator(
|
||||
db_conn: "LoggingDatabaseConnection",
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#
|
||||
#
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Sequence, Tuple
|
||||
|
||||
import attr
|
||||
|
@ -30,12 +31,20 @@ from synapse.handlers.room import RoomEventSource
|
|||
from synapse.handlers.typing import TypingNotificationEventSource
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.streams import EventSource
|
||||
from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
|
||||
from synapse.types import (
|
||||
AbstractMultiWriterStreamToken,
|
||||
MultiWriterStreamToken,
|
||||
StreamKeyType,
|
||||
StreamToken,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class _EventSourcesInner:
|
||||
room: RoomEventSource
|
||||
|
@ -91,6 +100,77 @@ class EventSources:
|
|||
)
|
||||
return token
|
||||
|
||||
async def bound_future_token(self, token: StreamToken) -> StreamToken:
|
||||
"""Bound a token that is ahead of the current token to the maximum
|
||||
persisted values.
|
||||
|
||||
This ensures that if we wait for the given token we know the stream will
|
||||
eventually advance to that point.
|
||||
|
||||
This works around a bug where older Synapse versions will give out
|
||||
tokens for streams, and then after a restart will give back tokens where
|
||||
the stream has "gone backwards".
|
||||
"""
|
||||
|
||||
current_token = self.get_current_token()
|
||||
|
||||
stream_key_to_id_gen = {
|
||||
StreamKeyType.ROOM: self.store.get_events_stream_id_generator(),
|
||||
StreamKeyType.PRESENCE: self.store.get_presence_stream_id_gen(),
|
||||
StreamKeyType.RECEIPT: self.store.get_receipts_stream_id_gen(),
|
||||
StreamKeyType.ACCOUNT_DATA: self.store.get_account_data_id_generator(),
|
||||
StreamKeyType.PUSH_RULES: self.store.get_push_rules_stream_id_gen(),
|
||||
StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
|
||||
StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
|
||||
StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
|
||||
}
|
||||
|
||||
for _, key in StreamKeyType.__members__.items():
|
||||
if key == StreamKeyType.TYPING:
|
||||
# Typing stream is allowed to "reset", and so comparisons don't
|
||||
# really make sense as is.
|
||||
# TODO: Figure out a better way of tracking resets.
|
||||
continue
|
||||
|
||||
token_value = token.get_field(key)
|
||||
current_value = current_token.get_field(key)
|
||||
|
||||
if isinstance(token_value, AbstractMultiWriterStreamToken):
|
||||
assert type(current_value) is type(token_value)
|
||||
|
||||
if not token_value.is_before_or_eq(current_value): # type: ignore[arg-type]
|
||||
max_token = await stream_key_to_id_gen[
|
||||
key
|
||||
].get_max_allocated_token()
|
||||
|
||||
if max_token < token_value.get_max_stream_pos():
|
||||
logger.error(
|
||||
"Bounding token from the future '%s': token: %s, bound: %s",
|
||||
key,
|
||||
token_value,
|
||||
max_token,
|
||||
)
|
||||
token = token.copy_and_replace(
|
||||
key, token_value.bound_stream_token(max_token)
|
||||
)
|
||||
else:
|
||||
assert isinstance(current_value, int)
|
||||
if current_value < token_value:
|
||||
max_token = await stream_key_to_id_gen[
|
||||
key
|
||||
].get_max_allocated_token()
|
||||
|
||||
if max_token < token_value:
|
||||
logger.error(
|
||||
"Bounding token from the future '%s': token: %s, bound: %s",
|
||||
key,
|
||||
token_value,
|
||||
max_token,
|
||||
)
|
||||
token = token.copy_and_replace(key, max_token)
|
||||
|
||||
return token
|
||||
|
||||
@trace
|
||||
async def get_start_token_for_pagination(self, room_id: str) -> StreamToken:
|
||||
"""Get the start token for a given room to be used to paginate
|
||||
|
|
|
@ -536,6 +536,16 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
|||
|
||||
return True
|
||||
|
||||
def bound_stream_token(self, max_stream: int) -> "Self":
|
||||
"""Bound the stream positions to a maximum value"""
|
||||
|
||||
return type(self)(
|
||||
stream=min(self.stream, max_stream),
|
||||
instance_map=immutabledict(
|
||||
{k: min(s, max_stream) for k, s in self.instance_map.items()}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, order=False)
|
||||
class RoomStreamToken(AbstractMultiWriterStreamToken):
|
||||
|
@ -722,6 +732,14 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||
else:
|
||||
return "s%d" % (self.stream,)
|
||||
|
||||
def bound_stream_token(self, max_stream: int) -> "RoomStreamToken":
|
||||
"""See super class"""
|
||||
|
||||
# This only makes sense for stream tokens.
|
||||
assert self.topological is None
|
||||
|
||||
return super().bound_stream_token(max_stream)
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, order=False)
|
||||
class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
|
||||
|
|
|
@ -36,10 +36,9 @@ from synapse.util import Clock
|
|||
|
||||
from tests import unittest
|
||||
from tests.test_utils import SMALL_PNG
|
||||
from tests.unittest import override_config
|
||||
|
||||
|
||||
class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
|
||||
class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
super().prepare(reactor, clock, hs)
|
||||
|
@ -65,9 +64,6 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
|
|||
)
|
||||
self.media_repo = hs.get_media_repository()
|
||||
|
||||
@override_config(
|
||||
{"experimental_features": {"msc3916_authenticated_media_enabled": True}}
|
||||
)
|
||||
def test_file_download(self) -> None:
|
||||
content = io.BytesIO(b"file_to_stream")
|
||||
content_uri = self.get_success(
|
||||
|
@ -82,7 +78,7 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
|
|||
# test with a text file
|
||||
channel = self.make_signed_federation_request(
|
||||
"GET",
|
||||
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
|
||||
f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
|
||||
)
|
||||
self.pump()
|
||||
self.assertEqual(200, channel.code)
|
||||
|
@ -106,7 +102,8 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
|
|||
|
||||
# check that the text file and expected value exist
|
||||
found_file = any(
|
||||
"\r\nContent-Type: text/plain\r\n\r\nfile_to_stream" in field
|
||||
"\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream"
|
||||
in field
|
||||
for field in stripped
|
||||
)
|
||||
self.assertTrue(found_file)
|
||||
|
@ -124,7 +121,7 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
|
|||
# test with an image file
|
||||
channel = self.make_signed_federation_request(
|
||||
"GET",
|
||||
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
|
||||
f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
|
||||
)
|
||||
self.pump()
|
||||
self.assertEqual(200, channel.code)
|
||||
|
@ -149,25 +146,3 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
|
|||
# check that the png file exists and matches what was uploaded
|
||||
found_file = any(SMALL_PNG in field for field in stripped_bytes)
|
||||
self.assertTrue(found_file)
|
||||
|
||||
@override_config(
|
||||
{"experimental_features": {"msc3916_authenticated_media_enabled": False}}
|
||||
)
|
||||
def test_disable_config(self) -> None:
|
||||
content = io.BytesIO(b"file_to_stream")
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
"text/plain",
|
||||
"test_upload",
|
||||
content,
|
||||
46,
|
||||
UserID.from_string("@user_id:whatever.org"),
|
||||
)
|
||||
)
|
||||
channel = self.make_signed_federation_request(
|
||||
"GET",
|
||||
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
|
||||
)
|
||||
self.pump()
|
||||
self.assertEqual(404, channel.code)
|
||||
self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED")
|
||||
|
|
|
@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, Mock, patch
|
|||
|
||||
from parameterized import parameterized
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules
|
||||
|
@ -35,7 +36,14 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe
|
|||
from synapse.rest import admin
|
||||
from synapse.rest.client import knock, login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID, create_requester
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
MultiWriterStreamToken,
|
||||
RoomStreamToken,
|
||||
StreamKeyType,
|
||||
UserID,
|
||||
create_requester,
|
||||
)
|
||||
from synapse.util import Clock
|
||||
|
||||
import tests.unittest
|
||||
|
@ -959,6 +967,94 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
|||
|
||||
self.fail("No push rules found")
|
||||
|
||||
def test_wait_for_future_sync_token(self) -> None:
|
||||
"""Test that if we receive a token that is ahead of our current token,
|
||||
we'll wait until the stream position advances.
|
||||
|
||||
This can happen if replication streams start lagging, and the client's
|
||||
previous sync request was serviced by a worker ahead of ours.
|
||||
"""
|
||||
user = self.register_user("alice", "password")
|
||||
|
||||
# We simulate a lagging stream by getting a stream ID from the ID gen
|
||||
# and then waiting to mark it as "persisted".
|
||||
presence_id_gen = self.store.get_presence_stream_id_gen()
|
||||
ctx_mgr = presence_id_gen.get_next()
|
||||
stream_id = self.get_success(ctx_mgr.__aenter__())
|
||||
|
||||
# Create the new token based on the stream ID above.
|
||||
current_token = self.hs.get_event_sources().get_current_token()
|
||||
since_token = current_token.copy_and_advance(StreamKeyType.PRESENCE, stream_id)
|
||||
|
||||
sync_d = defer.ensureDeferred(
|
||||
self.sync_handler.wait_for_sync_for_user(
|
||||
create_requester(user),
|
||||
generate_sync_config(user),
|
||||
sync_version=SyncVersion.SYNC_V2,
|
||||
request_key=generate_request_key(),
|
||||
since_token=since_token,
|
||||
timeout=0,
|
||||
)
|
||||
)
|
||||
|
||||
# This should block waiting for the presence stream to update
|
||||
self.pump()
|
||||
self.assertFalse(sync_d.called)
|
||||
|
||||
# Marking the stream ID as persisted should unblock the request.
|
||||
self.get_success(ctx_mgr.__aexit__(None, None, None))
|
||||
|
||||
self.get_success(sync_d, by=1.0)
|
||||
|
||||
@parameterized.expand(
|
||||
[(key,) for key in StreamKeyType.__members__.values()],
|
||||
name_func=lambda func, _, param: f"{func.__name__}_{param.args[0].name}",
|
||||
)
|
||||
def test_wait_for_invalid_future_sync_token(
|
||||
self, stream_key: StreamKeyType
|
||||
) -> None:
|
||||
"""Like the previous test, except we give a token that has a stream
|
||||
position ahead of what is in the DB, i.e. its invalid and we shouldn't
|
||||
wait for the stream to advance (as it may never do so).
|
||||
|
||||
This can happen due to older versions of Synapse giving out stream
|
||||
positions without persisting them in the DB, and so on restart the
|
||||
stream would get reset back to an older position.
|
||||
"""
|
||||
user = self.register_user("alice", "password")
|
||||
|
||||
# Create a token and advance one of the streams.
|
||||
current_token = self.hs.get_event_sources().get_current_token()
|
||||
token_value = current_token.get_field(stream_key)
|
||||
|
||||
# How we advance the streams depends on the type.
|
||||
if isinstance(token_value, int):
|
||||
since_token = current_token.copy_and_advance(stream_key, token_value + 1)
|
||||
elif isinstance(token_value, MultiWriterStreamToken):
|
||||
since_token = current_token.copy_and_advance(
|
||||
stream_key, MultiWriterStreamToken(stream=token_value.stream + 1)
|
||||
)
|
||||
elif isinstance(token_value, RoomStreamToken):
|
||||
since_token = current_token.copy_and_advance(
|
||||
stream_key, RoomStreamToken(stream=token_value.stream + 1)
|
||||
)
|
||||
else:
|
||||
raise Exception("Unreachable")
|
||||
|
||||
sync_d = defer.ensureDeferred(
|
||||
self.sync_handler.wait_for_sync_for_user(
|
||||
create_requester(user),
|
||||
generate_sync_config(user),
|
||||
sync_version=SyncVersion.SYNC_V2,
|
||||
request_key=generate_request_key(),
|
||||
since_token=since_token,
|
||||
timeout=0,
|
||||
)
|
||||
)
|
||||
|
||||
# We should return without waiting for the presence stream to advance.
|
||||
self.get_success(sync_d)
|
||||
|
||||
|
||||
def generate_sync_config(
|
||||
user_id: str,
|
||||
|
|
|
@ -37,18 +37,155 @@ from synapse.http.client import (
|
|||
BlocklistingAgentWrapper,
|
||||
BlocklistingReactorWrapper,
|
||||
BodyExceededMaxSize,
|
||||
MultipartResponse,
|
||||
_DiscardBodyWithMaxSizeProtocol,
|
||||
_MultipartParserProtocol,
|
||||
read_body_with_max_size,
|
||||
read_multipart_response,
|
||||
)
|
||||
|
||||
from tests.server import FakeTransport, get_clock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
class ReadMultipartResponseTests(TestCase):
|
||||
data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
|
||||
data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
|
||||
|
||||
redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
|
||||
|
||||
def _build_multipart_response(
|
||||
self, response_length: Union[int, str], max_length: int
|
||||
) -> Tuple[
|
||||
BytesIO,
|
||||
"Deferred[MultipartResponse]",
|
||||
_MultipartParserProtocol,
|
||||
]:
|
||||
"""Start reading the body, returns the response, result and proto"""
|
||||
response = Mock(length=response_length)
|
||||
result = BytesIO()
|
||||
boundary = "6067d4698f8d40a0a794ea7d7379d53a"
|
||||
deferred = read_multipart_response(response, result, boundary, max_length)
|
||||
|
||||
# Fish the protocol out of the response.
|
||||
protocol = response.deliverBody.call_args[0][0]
|
||||
protocol.transport = Mock()
|
||||
|
||||
return result, deferred, protocol
|
||||
|
||||
def _assert_error(
|
||||
self,
|
||||
deferred: "Deferred[MultipartResponse]",
|
||||
protocol: _MultipartParserProtocol,
|
||||
) -> None:
|
||||
"""Ensure that the expected error is received."""
|
||||
assert isinstance(deferred.result, Failure)
|
||||
self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
|
||||
assert protocol.transport is not None
|
||||
# type-ignore: presumably abortConnection has been replaced with a Mock.
|
||||
protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]
|
||||
|
||||
def _cleanup_error(self, deferred: "Deferred[MultipartResponse]") -> None:
|
||||
"""Ensure that the error in the Deferred is handled gracefully."""
|
||||
called = [False]
|
||||
|
||||
def errback(f: Failure) -> None:
|
||||
called[0] = True
|
||||
|
||||
deferred.addErrback(errback)
|
||||
self.assertTrue(called[0])
|
||||
|
||||
def test_parse_file(self) -> None:
|
||||
"""
|
||||
Check that a multipart response containing a file is properly parsed
|
||||
into the json/file parts, and the json and file are properly captured
|
||||
"""
|
||||
result, deferred, protocol = self._build_multipart_response(249, 250)
|
||||
|
||||
# Start sending data.
|
||||
protocol.dataReceived(self.data1)
|
||||
protocol.dataReceived(self.data2)
|
||||
# Close the connection.
|
||||
protocol.connectionLost(Failure(ResponseDone()))
|
||||
|
||||
multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
|
||||
|
||||
self.assertEqual(multipart_response.json, b"{}")
|
||||
self.assertEqual(result.getvalue(), b"file_to_stream")
|
||||
self.assertEqual(multipart_response.length, len(b"file_to_stream"))
|
||||
self.assertEqual(multipart_response.content_type, b"text/plain")
|
||||
self.assertEqual(
|
||||
multipart_response.disposition, b"inline; filename=test_upload"
|
||||
)
|
||||
|
||||
def test_parse_redirect(self) -> None:
|
||||
"""
|
||||
check that a multipart response containing a redirect is properly parsed and redirect url is
|
||||
returned
|
||||
"""
|
||||
result, deferred, protocol = self._build_multipart_response(249, 250)
|
||||
|
||||
# Start sending data.
|
||||
protocol.dataReceived(self.redirect_data)
|
||||
# Close the connection.
|
||||
protocol.connectionLost(Failure(ResponseDone()))
|
||||
|
||||
multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
|
||||
|
||||
self.assertEqual(multipart_response.json, b"{}")
|
||||
self.assertEqual(result.getvalue(), b"")
|
||||
self.assertEqual(
|
||||
multipart_response.url, b"https://cdn.example.org/ab/c1/2345.txt"
|
||||
)
|
||||
|
||||
def test_too_large(self) -> None:
|
||||
"""A response which is too large raises an exception."""
|
||||
result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
|
||||
|
||||
# Start sending data.
|
||||
protocol.dataReceived(self.data1)
|
||||
|
||||
self.assertEqual(result.getvalue(), b"file_")
|
||||
self._assert_error(deferred, protocol)
|
||||
self._cleanup_error(deferred)
|
||||
|
||||
def test_additional_data(self) -> None:
|
||||
"""A connection can receive data after being closed."""
|
||||
result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
|
||||
|
||||
# Start sending data.
|
||||
protocol.dataReceived(self.data1)
|
||||
self._assert_error(deferred, protocol)
|
||||
|
||||
# More data might have come in.
|
||||
protocol.dataReceived(self.data2)
|
||||
|
||||
self.assertEqual(result.getvalue(), b"file_")
|
||||
self._assert_error(deferred, protocol)
|
||||
self._cleanup_error(deferred)
|
||||
|
||||
def test_content_length(self) -> None:
|
||||
"""The body shouldn't be read (at all) if the Content-Length header is too large."""
|
||||
result, deferred, protocol = self._build_multipart_response(250, 1)
|
||||
|
||||
# Deferred shouldn't be called yet.
|
||||
self.assertFalse(deferred.called)
|
||||
|
||||
# Start sending data.
|
||||
protocol.dataReceived(self.data1)
|
||||
self._assert_error(deferred, protocol)
|
||||
self._cleanup_error(deferred)
|
||||
|
||||
# The data is never consumed.
|
||||
self.assertEqual(result.getvalue(), b"")
|
||||
|
||||
|
||||
class ReadBodyWithMaxSizeTests(TestCase):
|
||||
def _build_response(
|
||||
self, length: Union[int, str] = UNKNOWN_LENGTH
|
||||
) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
|
||||
def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[
|
||||
BytesIO,
|
||||
"Deferred[int]",
|
||||
_DiscardBodyWithMaxSizeProtocol,
|
||||
]:
|
||||
"""Start reading the body, returns the response, result and proto"""
|
||||
response = Mock(length=length)
|
||||
result = BytesIO()
|
||||
|
|
|
@ -129,7 +129,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
|
|||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True, frozen=True)
|
||||
class _TestImage:
|
||||
class TestImage:
|
||||
"""An image for testing thumbnailing with the expected results
|
||||
|
||||
Attributes:
|
||||
|
@ -158,7 +158,7 @@ class _TestImage:
|
|||
is_inline: bool = True
|
||||
|
||||
|
||||
small_png = _TestImage(
|
||||
small_png = TestImage(
|
||||
SMALL_PNG,
|
||||
b"image/png",
|
||||
b".png",
|
||||
|
@ -175,7 +175,7 @@ small_png = _TestImage(
|
|||
),
|
||||
)
|
||||
|
||||
small_png_with_transparency = _TestImage(
|
||||
small_png_with_transparency = TestImage(
|
||||
unhexlify(
|
||||
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
|
||||
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
|
||||
|
@ -188,7 +188,7 @@ small_png_with_transparency = _TestImage(
|
|||
# different versions of Pillow.
|
||||
)
|
||||
|
||||
small_lossless_webp = _TestImage(
|
||||
small_lossless_webp = TestImage(
|
||||
unhexlify(
|
||||
b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
|
||||
),
|
||||
|
@ -196,7 +196,7 @@ small_lossless_webp = _TestImage(
|
|||
b".webp",
|
||||
)
|
||||
|
||||
empty_file = _TestImage(
|
||||
empty_file = TestImage(
|
||||
b"",
|
||||
b"image/gif",
|
||||
b".gif",
|
||||
|
@ -204,7 +204,7 @@ empty_file = _TestImage(
|
|||
unable_to_thumbnail=True,
|
||||
)
|
||||
|
||||
SVG = _TestImage(
|
||||
SVG = TestImage(
|
||||
b"""<?xml version="1.0"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
|
||||
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
|
@ -236,7 +236,7 @@ urls = [
|
|||
@parameterized_class(("test_image", "url"), itertools.product(test_images, urls))
|
||||
class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
servlets = [media.register_servlets]
|
||||
test_image: ClassVar[_TestImage]
|
||||
test_image: ClassVar[TestImage]
|
||||
hijack_auth = True
|
||||
user_id = "@test:user"
|
||||
url: ClassVar[str]
|
||||
|
|
|
@ -28,7 +28,7 @@ from twisted.web.http import HTTPChannel
|
|||
from twisted.web.server import Request
|
||||
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login
|
||||
from synapse.rest.client import login, media
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
|
@ -255,6 +255,238 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
return sum(len(files) for _, _, files in os.walk(path))
|
||||
|
||||
|
||||
class AuthenticatedMediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||
"""Checks running multiple media repos work correctly using autheticated media paths"""
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets_for_client_rest_resource,
|
||||
login.register_servlets,
|
||||
media.register_servlets,
|
||||
]
|
||||
|
||||
file_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user_id = self.register_user("user", "pass")
|
||||
self.access_token = self.login("user", "pass")
|
||||
|
||||
self.reactor.lookups["example.com"] = "1.2.3.4"
|
||||
|
||||
def default_config(self) -> dict:
|
||||
conf = super().default_config()
|
||||
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
|
||||
return conf
|
||||
|
||||
def make_worker_hs(
|
||||
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
|
||||
) -> HomeServer:
|
||||
worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs)
|
||||
# Force the media paths onto the replication resource.
|
||||
worker_hs.get_media_repository_resource().register_servlets(
|
||||
self._hs_to_site[worker_hs].resource, worker_hs
|
||||
)
|
||||
return worker_hs
|
||||
|
||||
def _get_media_req(
|
||||
self, hs: HomeServer, target: str, media_id: str
|
||||
) -> Tuple[FakeChannel, Request]:
|
||||
"""Request some remote media from the given HS by calling the download
|
||||
API.
|
||||
|
||||
This then triggers an outbound request from the HS to the target.
|
||||
|
||||
Returns:
|
||||
The channel for the *client* request and the *outbound* request for
|
||||
the media which the caller should respond to.
|
||||
"""
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
self._hs_to_site[hs],
|
||||
"GET",
|
||||
f"/_matrix/client/v1/media/download/{target}/{media_id}",
|
||||
shorthand=False,
|
||||
access_token=self.access_token,
|
||||
await_result=False,
|
||||
)
|
||||
self.pump()
|
||||
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertGreaterEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
|
||||
|
||||
# build the test server
|
||||
server_factory = Factory.forProtocol(HTTPChannel)
|
||||
# Request.finish expects the factory to have a 'log' method.
|
||||
server_factory.log = _log_request
|
||||
|
||||
server_tls_protocol = wrap_server_factory_for_tls(
|
||||
server_factory, self.reactor, sanlist=[b"DNS:example.com"]
|
||||
).buildProtocol(None)
|
||||
|
||||
# now, tell the client protocol factory to build the client protocol (it will be a
|
||||
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
|
||||
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
|
||||
# a FakeTransport.
|
||||
#
|
||||
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||
# stubbing that out here.
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
client_protocol.makeConnection(
|
||||
FakeTransport(server_tls_protocol, self.reactor, client_protocol)
|
||||
)
|
||||
|
||||
# tell the server tls protocol to send its stuff back to the client, too
|
||||
server_tls_protocol.makeConnection(
|
||||
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
|
||||
)
|
||||
|
||||
# fish the test server back out of the server-side TLS protocol.
|
||||
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
|
||||
|
||||
# give the reactor a pump to get the TLS juices flowing.
|
||||
self.reactor.pump((0.1,))
|
||||
|
||||
self.assertEqual(len(http_server.requests), 1)
|
||||
request = http_server.requests[0]
|
||||
|
||||
self.assertEqual(request.method, b"GET")
|
||||
self.assertEqual(
|
||||
request.path,
|
||||
f"/_matrix/federation/v1/media/download/{media_id}".encode(),
|
||||
)
|
||||
self.assertEqual(
|
||||
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
|
||||
)
|
||||
|
||||
return channel, request
|
||||
|
||||
def test_basic(self) -> None:
|
||||
"""Test basic fetching of remote media from a single worker."""
|
||||
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
|
||||
channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
|
||||
|
||||
request.setResponseCode(200)
|
||||
request.responseHeaders.setRawHeaders(
|
||||
b"Content-Type",
|
||||
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
|
||||
)
|
||||
request.write(self.file_data)
|
||||
request.finish()
|
||||
|
||||
self.pump(0.1)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.result["body"], b"file_to_stream")
|
||||
|
||||
def test_download_simple_file_race(self) -> None:
|
||||
"""Test that fetching remote media from two different processes at the
|
||||
same time works.
|
||||
"""
|
||||
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
hs2 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
|
||||
start_count = self._count_remote_media()
|
||||
|
||||
# Make two requests without responding to the outbound media requests.
|
||||
channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
|
||||
channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
|
||||
|
||||
# Respond to the first outbound media request and check that the client
|
||||
# request is successful
|
||||
request1.setResponseCode(200)
|
||||
request1.responseHeaders.setRawHeaders(
|
||||
b"Content-Type",
|
||||
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
|
||||
)
|
||||
request1.write(self.file_data)
|
||||
request1.finish()
|
||||
|
||||
self.pump(0.1)
|
||||
|
||||
self.assertEqual(channel1.code, 200, channel1.result["body"])
|
||||
self.assertEqual(channel1.result["body"], b"file_to_stream")
|
||||
|
||||
# Now respond to the second with the same content.
|
||||
request2.setResponseCode(200)
|
||||
request2.responseHeaders.setRawHeaders(
|
||||
b"Content-Type",
|
||||
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
|
||||
)
|
||||
request2.write(self.file_data)
|
||||
request2.finish()
|
||||
|
||||
self.pump(0.1)
|
||||
|
||||
self.assertEqual(channel2.code, 200, channel2.result["body"])
|
||||
self.assertEqual(channel2.result["body"], b"file_to_stream")
|
||||
|
||||
# We expect only one new file to have been persisted.
|
||||
self.assertEqual(start_count + 1, self._count_remote_media())
|
||||
|
||||
def test_download_image_race(self) -> None:
|
||||
"""Test that fetching remote *images* from two different processes at
|
||||
the same time works.
|
||||
|
||||
This checks that races generating thumbnails are handled correctly.
|
||||
"""
|
||||
hs1 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
hs2 = self.make_worker_hs("synapse.app.generic_worker")
|
||||
|
||||
start_count = self._count_remote_thumbnails()
|
||||
|
||||
channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
|
||||
channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
|
||||
|
||||
request1.setResponseCode(200)
|
||||
request1.responseHeaders.setRawHeaders(
|
||||
b"Content-Type",
|
||||
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
|
||||
)
|
||||
img_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: image/png\r\nContent-Disposition: inline; filename=test_img\r\n\r\n"
|
||||
request1.write(img_data)
|
||||
request1.write(SMALL_PNG)
|
||||
request1.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
|
||||
request1.finish()
|
||||
|
||||
self.pump(0.1)
|
||||
|
||||
self.assertEqual(channel1.code, 200, channel1.result["body"])
|
||||
self.assertEqual(channel1.result["body"], SMALL_PNG)
|
||||
|
||||
request2.setResponseCode(200)
|
||||
request2.responseHeaders.setRawHeaders(
|
||||
b"Content-Type",
|
||||
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
|
||||
)
|
||||
request2.write(img_data)
|
||||
request2.write(SMALL_PNG)
|
||||
request2.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
|
||||
request2.finish()
|
||||
|
||||
self.pump(0.1)
|
||||
|
||||
self.assertEqual(channel2.code, 200, channel2.result["body"])
|
||||
self.assertEqual(channel2.result["body"], SMALL_PNG)
|
||||
|
||||
# We expect only three new thumbnails to have been persisted.
|
||||
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
|
||||
|
||||
def _count_remote_media(self) -> int:
|
||||
"""Count the number of files in our remote media directory."""
|
||||
path = os.path.join(
|
||||
self.hs.get_media_repository().primary_base_path, "remote_content"
|
||||
)
|
||||
return sum(len(files) for _, _, files in os.walk(path))
|
||||
|
||||
def _count_remote_thumbnails(self) -> int:
|
||||
"""Count the number of files in our remote thumbnails directory."""
|
||||
path = os.path.join(
|
||||
self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
|
||||
)
|
||||
return sum(len(files) for _, _, files in os.walk(path))
|
||||
|
||||
|
||||
def _log_request(request: Request) -> None:
|
||||
"""Implements Factory.log, which is expected by Request.finish"""
|
||||
logger.info("Completed request %s", request)
|
||||
|
|
|
@ -19,31 +19,54 @@
|
|||
#
|
||||
#
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Type
|
||||
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from urllib import parse
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
from parameterized import parameterized_class
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet._resolver import HostResolution
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.internet.interfaces import IAddress, IResolutionReceiver
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.config.oembed import OEmbedEndpointConfig
|
||||
from synapse.http.client import MultipartResponse
|
||||
from synapse.http.types import QueryParams
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.media._base import FileInfo
|
||||
from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, media
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||
from tests.media.test_media_storage import (
|
||||
SVG,
|
||||
TestImage,
|
||||
empty_file,
|
||||
small_lossless_webp,
|
||||
small_png,
|
||||
small_png_with_transparency,
|
||||
)
|
||||
from tests.server import FakeChannel, FakeTransport, ThreadedMemoryReactorClock
|
||||
from tests.test_utils import SMALL_PNG
|
||||
from tests.unittest import override_config
|
||||
|
||||
|
@ -1607,3 +1630,583 @@ class UnstableMediaConfigTest(unittest.HomeserverTestCase):
|
|||
self.assertEqual(
|
||||
channel.json_body["m.upload.size"], self.hs.config.media.max_upload_size
|
||||
)
|
||||
|
||||
|
||||
class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
media.register_servlets,
|
||||
login.register_servlets,
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
self.storage_path = self.mktemp()
|
||||
self.media_store_path = self.mktemp()
|
||||
os.mkdir(self.storage_path)
|
||||
os.mkdir(self.media_store_path)
|
||||
config["media_store_path"] = self.media_store_path
|
||||
|
||||
provider_config = {
|
||||
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
|
||||
"store_local": True,
|
||||
"store_synchronous": False,
|
||||
"store_remote": True,
|
||||
"config": {"directory": self.storage_path},
|
||||
}
|
||||
|
||||
config["media_storage_providers"] = [provider_config]
|
||||
|
||||
return self.setup_test_homeserver(config=config)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.repo = hs.get_media_repository()
|
||||
self.client = hs.get_federation_http_client()
|
||||
self.store = hs.get_datastores().main
|
||||
self.user = self.register_user("user", "pass")
|
||||
self.tok = self.login("user", "pass")
|
||||
|
||||
# mock actually reading file body
|
||||
def read_multipart_response_30MiB(*args: Any, **kwargs: Any) -> Deferred:
|
||||
d: Deferred = defer.Deferred()
|
||||
d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
|
||||
return d
|
||||
|
||||
def read_multipart_response_50MiB(*args: Any, **kwargs: Any) -> Deferred:
|
||||
d: Deferred = defer.Deferred()
|
||||
d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
|
||||
return d
|
||||
|
||||
@patch(
|
||||
"synapse.http.matrixfederationclient.read_multipart_response",
|
||||
read_multipart_response_30MiB,
|
||||
)
|
||||
def test_download_ratelimit_default(self) -> None:
|
||||
"""
|
||||
Test remote media download ratelimiting against default configuration - 500MB bucket
|
||||
and 87kb/second drain rate
|
||||
"""
|
||||
|
||||
# mock out actually sending the request, returns a 30MiB response
|
||||
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
|
||||
resp = MagicMock(spec=IResponse)
|
||||
resp.code = 200
|
||||
resp.length = 31457280
|
||||
resp.headers = Headers(
|
||||
{"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
|
||||
)
|
||||
resp.phrase = b"OK"
|
||||
return resp
|
||||
|
||||
self.client._send_request = _send_request # type: ignore
|
||||
|
||||
# first request should go through
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v1/media/download/remote.org/abc",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel.code == 200
|
||||
|
||||
# next 15 should go through
|
||||
for i in range(15):
|
||||
channel2 = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v1/media/download/remote.org/abc{i}",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel2.code == 200
|
||||
|
||||
# 17th will hit ratelimit
|
||||
channel3 = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v1/media/download/remote.org/abcd",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel3.code == 429
|
||||
|
||||
# however, a request from a different IP will go through
|
||||
channel4 = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v1/media/download/remote.org/abcde",
|
||||
shorthand=False,
|
||||
client_ip="187.233.230.159",
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel4.code == 200
|
||||
|
||||
# at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
|
||||
# 30MiB download is authorized - The last download was blocked at 503,316,480.
|
||||
# The next download will be authorized when bucket hits 492,830,720
|
||||
# (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
|
||||
# needs to drain before another download will be authorized, that will take ~=
|
||||
# 2 minutes (10,485,760/89,088/60)
|
||||
self.reactor.pump([2.0 * 60.0])
|
||||
|
||||
# enough has drained and next request goes through
|
||||
channel5 = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v1/media/download/remote.org/abcdef",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel5.code == 200
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"remote_media_download_per_second": "50M",
|
||||
"remote_media_download_burst_count": "50M",
|
||||
}
|
||||
)
|
||||
@patch(
|
||||
"synapse.http.matrixfederationclient.read_multipart_response",
|
||||
read_multipart_response_50MiB,
|
||||
)
|
||||
def test_download_rate_limit_config(self) -> None:
|
||||
"""
|
||||
Test that download rate limit config options are correctly picked up and applied
|
||||
"""
|
||||
|
||||
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
|
||||
resp = MagicMock(spec=IResponse)
|
||||
resp.code = 200
|
||||
resp.length = 52428800
|
||||
resp.headers = Headers(
|
||||
{"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
|
||||
)
|
||||
resp.phrase = b"OK"
|
||||
return resp
|
||||
|
||||
self.client._send_request = _send_request # type: ignore
|
||||
|
||||
# first request should go through
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v1/media/download/remote.org/abc",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel.code == 200
|
||||
|
||||
# immediate second request should fail
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v1/media/download/remote.org/abcd",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel.code == 429
|
||||
|
||||
# advance half a second
|
||||
self.reactor.pump([0.5])
|
||||
|
||||
# request still fails
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v1/media/download/remote.org/abcde",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel.code == 429
|
||||
|
||||
# advance another half second
|
||||
self.reactor.pump([0.5])
|
||||
|
||||
# enough has drained from bucket and request is successful
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v1/media/download/remote.org/abcdef",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel.code == 200
|
||||
|
||||
@patch(
|
||||
"synapse.http.matrixfederationclient.read_multipart_response",
|
||||
read_multipart_response_30MiB,
|
||||
)
|
||||
def test_download_ratelimit_max_size_sub(self) -> None:
|
||||
"""
|
||||
Test that if no content-length is provided, the default max size is applied instead
|
||||
"""
|
||||
|
||||
# mock out actually sending the request
|
||||
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
|
||||
resp = MagicMock(spec=IResponse)
|
||||
resp.code = 200
|
||||
resp.length = UNKNOWN_LENGTH
|
||||
resp.headers = Headers(
|
||||
{"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
|
||||
)
|
||||
resp.phrase = b"OK"
|
||||
return resp
|
||||
|
||||
self.client._send_request = _send_request # type: ignore
|
||||
|
||||
# ten requests should go through using the max size (500MB/50MB)
|
||||
for i in range(10):
|
||||
channel2 = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v1/media/download/remote.org/abc{i}",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel2.code == 200
|
||||
|
||||
# eleventh will hit ratelimit
|
||||
channel3 = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/v1/media/download/remote.org/abcd",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
assert channel3.code == 429
|
||||
|
||||
def test_file_download(self) -> None:
|
||||
content = io.BytesIO(b"file_to_stream")
|
||||
content_uri = self.get_success(
|
||||
self.repo.create_content(
|
||||
"text/plain",
|
||||
"test_upload",
|
||||
content,
|
||||
46,
|
||||
UserID.from_string("@user_id:whatever.org"),
|
||||
)
|
||||
)
|
||||
# test with a text file
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v1/media/download/test/{content_uri.media_id}",
|
||||
shorthand=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.pump()
|
||||
self.assertEqual(200, channel.code)
|
||||
|
||||
|
||||
test_images = [
|
||||
small_png,
|
||||
small_png_with_transparency,
|
||||
small_lossless_webp,
|
||||
empty_file,
|
||||
SVG,
|
||||
]
|
||||
input_values = [(x,) for x in test_images]
|
||||
|
||||
|
||||
@parameterized_class(("test_image",), input_values)
|
||||
class DownloadTestCase(unittest.HomeserverTestCase):
|
||||
test_image: ClassVar[TestImage]
|
||||
servlets = [
|
||||
media.register_servlets,
|
||||
login.register_servlets,
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.fetches: List[
|
||||
Tuple[
|
||||
"Deferred[Any]",
|
||||
str,
|
||||
str,
|
||||
Optional[QueryParams],
|
||||
]
|
||||
] = []
|
||||
|
||||
def federation_get_file(
|
||||
destination: str,
|
||||
path: str,
|
||||
output_stream: BinaryIO,
|
||||
download_ratelimiter: Ratelimiter,
|
||||
ip_address: Any,
|
||||
max_size: int,
|
||||
args: Optional[QueryParams] = None,
|
||||
retry_on_dns_fail: bool = True,
|
||||
ignore_backoff: bool = False,
|
||||
follow_redirects: bool = False,
|
||||
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]], bytes]]":
|
||||
"""A mock for MatrixFederationHttpClient.federation_get_file."""
|
||||
|
||||
def write_to(
|
||||
r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]
|
||||
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
|
||||
data, response = r
|
||||
output_stream.write(data)
|
||||
return response
|
||||
|
||||
def write_err(f: Failure) -> Failure:
|
||||
f.trap(HttpResponseException)
|
||||
output_stream.write(f.value.response)
|
||||
return f
|
||||
|
||||
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]] = (
|
||||
Deferred()
|
||||
)
|
||||
self.fetches.append((d, destination, path, args))
|
||||
# Note that this callback changes the value held by d.
|
||||
d_after_callback = d.addCallbacks(write_to, write_err)
|
||||
return make_deferred_yieldable(d_after_callback)
|
||||
|
||||
def get_file(
|
||||
destination: str,
|
||||
path: str,
|
||||
output_stream: BinaryIO,
|
||||
download_ratelimiter: Ratelimiter,
|
||||
ip_address: Any,
|
||||
max_size: int,
|
||||
args: Optional[QueryParams] = None,
|
||||
retry_on_dns_fail: bool = True,
|
||||
ignore_backoff: bool = False,
|
||||
follow_redirects: bool = False,
|
||||
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
|
||||
"""A mock for MatrixFederationHttpClient.get_file."""
|
||||
|
||||
def write_to(
|
||||
r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
|
||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||
data, response = r
|
||||
output_stream.write(data)
|
||||
return response
|
||||
|
||||
def write_err(f: Failure) -> Failure:
|
||||
f.trap(HttpResponseException)
|
||||
output_stream.write(f.value.response)
|
||||
return f
|
||||
|
||||
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
|
||||
self.fetches.append((d, destination, path, args))
|
||||
# Note that this callback changes the value held by d.
|
||||
d_after_callback = d.addCallbacks(write_to, write_err)
|
||||
return make_deferred_yieldable(d_after_callback)
|
||||
|
||||
# Mock out the homeserver's MatrixFederationHttpClient
|
||||
client = Mock()
|
||||
client.federation_get_file = federation_get_file
|
||||
client.get_file = get_file
|
||||
|
||||
self.storage_path = self.mktemp()
|
||||
self.media_store_path = self.mktemp()
|
||||
os.mkdir(self.storage_path)
|
||||
os.mkdir(self.media_store_path)
|
||||
|
||||
config = self.default_config()
|
||||
config["media_store_path"] = self.media_store_path
|
||||
config["max_image_pixels"] = 2000000
|
||||
|
||||
provider_config = {
|
||||
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
|
||||
"store_local": True,
|
||||
"store_synchronous": False,
|
||||
"store_remote": True,
|
||||
"config": {"directory": self.storage_path},
|
||||
}
|
||||
config["media_storage_providers"] = [provider_config]
|
||||
config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
|
||||
|
||||
hs = self.setup_test_homeserver(config=config, federation_http_client=client)
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.media_repo = hs.get_media_repository()
|
||||
|
||||
self.remote = "example.com"
|
||||
self.media_id = "12345"
|
||||
|
||||
self.user = self.register_user("user", "pass")
|
||||
self.tok = self.login("user", "pass")
|
||||
|
||||
def _req(
|
||||
self, content_disposition: Optional[bytes], include_content_type: bool = True
|
||||
) -> FakeChannel:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
|
||||
shorthand=False,
|
||||
await_result=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.pump()
|
||||
|
||||
# We've made one fetch, to example.com, using the federation media URL
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
self.assertEqual(self.fetches[0][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.fetches[0][2], "/_matrix/federation/v1/media/download/" + self.media_id
|
||||
)
|
||||
self.assertEqual(
|
||||
self.fetches[0][3],
|
||||
{"timeout_ms": "20000"},
|
||||
)
|
||||
|
||||
headers = {
|
||||
b"Content-Length": [b"%d" % (len(self.test_image.data))],
|
||||
}
|
||||
|
||||
if include_content_type:
|
||||
headers[b"Content-Type"] = [self.test_image.content_type]
|
||||
|
||||
if content_disposition:
|
||||
headers[b"Content-Disposition"] = [content_disposition]
|
||||
|
||||
self.fetches[0][0].callback(
|
||||
(self.test_image.data, (len(self.test_image.data), headers, b"{}"))
|
||||
)
|
||||
|
||||
self.pump()
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
return channel
|
||||
|
||||
def test_handle_missing_content_type(self) -> None:
|
||||
channel = self._req(
|
||||
b"attachment; filename=out" + self.test_image.extension,
|
||||
include_content_type=False,
|
||||
)
|
||||
headers = channel.headers
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(
|
||||
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
|
||||
)
|
||||
|
||||
def test_disposition_filename_ascii(self) -> None:
|
||||
"""
|
||||
If the filename is filename=<ascii> then Synapse will decode it as an
|
||||
ASCII string, and use filename= in the response.
|
||||
"""
|
||||
channel = self._req(b"attachment; filename=out" + self.test_image.extension)
|
||||
|
||||
headers = channel.headers
|
||||
self.assertEqual(
|
||||
headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
|
||||
)
|
||||
self.assertEqual(
|
||||
headers.getRawHeaders(b"Content-Disposition"),
|
||||
[
|
||||
(b"inline" if self.test_image.is_inline else b"attachment")
|
||||
+ b"; filename=out"
|
||||
+ self.test_image.extension
|
||||
],
|
||||
)
|
||||
|
||||
def test_disposition_filenamestar_utf8escaped(self) -> None:
|
||||
"""
|
||||
If the filename is filename=*utf8''<utf8 escaped> then Synapse will
|
||||
correctly decode it as the UTF-8 string, and use filename* in the
|
||||
response.
|
||||
"""
|
||||
filename = parse.quote("\u2603".encode()).encode("ascii")
|
||||
channel = self._req(
|
||||
b"attachment; filename*=utf-8''" + filename + self.test_image.extension
|
||||
)
|
||||
|
||||
headers = channel.headers
|
||||
self.assertEqual(
|
||||
headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
|
||||
)
|
||||
self.assertEqual(
|
||||
headers.getRawHeaders(b"Content-Disposition"),
|
||||
[
|
||||
(b"inline" if self.test_image.is_inline else b"attachment")
|
||||
+ b"; filename*=utf-8''"
|
||||
+ filename
|
||||
+ self.test_image.extension
|
||||
],
|
||||
)
|
||||
|
||||
def test_disposition_none(self) -> None:
|
||||
"""
|
||||
If there is no filename, Content-Disposition should only
|
||||
be a disposition type.
|
||||
"""
|
||||
channel = self._req(None)
|
||||
|
||||
headers = channel.headers
|
||||
self.assertEqual(
|
||||
headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
|
||||
)
|
||||
self.assertEqual(
|
||||
headers.getRawHeaders(b"Content-Disposition"),
|
||||
[b"inline" if self.test_image.is_inline else b"attachment"],
|
||||
)
|
||||
|
||||
def test_x_robots_tag_header(self) -> None:
|
||||
"""
|
||||
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
|
||||
to not index, archive, or follow links in media.
|
||||
"""
|
||||
channel = self._req(b"attachment; filename=out" + self.test_image.extension)
|
||||
|
||||
headers = channel.headers
|
||||
self.assertEqual(
|
||||
headers.getRawHeaders(b"X-Robots-Tag"),
|
||||
[b"noindex, nofollow, noarchive, noimageindex"],
|
||||
)
|
||||
|
||||
def test_cross_origin_resource_policy_header(self) -> None:
|
||||
"""
|
||||
Test that the Cross-Origin-Resource-Policy header is set to "cross-origin"
|
||||
allowing web clients to embed media from the downloads API.
|
||||
"""
|
||||
channel = self._req(b"attachment; filename=out" + self.test_image.extension)
|
||||
|
||||
headers = channel.headers
|
||||
|
||||
self.assertEqual(
|
||||
headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
|
||||
[b"cross-origin"],
|
||||
)
|
||||
|
||||
def test_unknown_federation_endpoint(self) -> None:
|
||||
"""
|
||||
Test that if the downloadd request to remote federation endpoint returns a 404
|
||||
we fall back to the _matrix/media endpoint
|
||||
"""
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
|
||||
shorthand=False,
|
||||
await_result=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.pump()
|
||||
|
||||
# We've made one fetch, to example.com, using the media URL, and asking
|
||||
# the other server not to do a remote fetch
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
self.assertEqual(self.fetches[0][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}"
|
||||
)
|
||||
|
||||
# The result which says the endpoint is unknown.
|
||||
unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
|
||||
self.fetches[0][0].errback(
|
||||
HttpResponseException(404, "NOT FOUND", unknown_endpoint)
|
||||
)
|
||||
|
||||
self.pump()
|
||||
|
||||
# There should now be another request to the _matrix/media/v3/download URL.
|
||||
self.assertEqual(len(self.fetches), 2)
|
||||
self.assertEqual(self.fetches[1][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.fetches[1][2],
|
||||
f"/_matrix/media/v3/download/example.com/{self.media_id}",
|
||||
)
|
||||
|
||||
headers = {
|
||||
b"Content-Length": [b"%d" % (len(self.test_image.data))],
|
||||
}
|
||||
|
||||
self.fetches[1][0].callback(
|
||||
(self.test_image.data, (len(self.test_image.data), headers))
|
||||
)
|
||||
|
||||
self.pump()
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
|
|
@ -1387,10 +1387,12 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
|
|||
# Create a future token that will cause us to wait. Since we never send a new
|
||||
# event to reach that future stream_ordering, the worker will wait until the
|
||||
# full timeout.
|
||||
stream_id_gen = self.store.get_events_stream_id_generator()
|
||||
stream_id = self.get_success(stream_id_gen.get_next().__aenter__())
|
||||
current_token = self.event_sources.get_current_token()
|
||||
future_position_token = current_token.copy_and_replace(
|
||||
StreamKeyType.ROOM,
|
||||
RoomStreamToken(stream=current_token.room_key.stream + 1),
|
||||
RoomStreamToken(stream=stream_id),
|
||||
)
|
||||
|
||||
future_position_token_serialized = self.get_success(
|
||||
|
|
Loading…
Reference in a new issue