Support MSC3916 by adding a federation /download endpoint (#17172)

This commit is contained in:
Shay 2024-06-07 05:54:28 -07:00 committed by GitHub
parent 17d6c28285
commit ab94bce02c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 659 additions and 24 deletions

View file

@ -0,0 +1,2 @@
Support [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md)
by adding a federation /download endpoint (#17172).

View file

@ -19,6 +19,7 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
import inspect
import logging import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type
@ -33,6 +34,7 @@ from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES, FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet, FederationAccountStatusServlet,
FederationUnstableClientKeysClaimServlet, FederationUnstableClientKeysClaimServlet,
FederationUnstableMediaDownloadServlet,
) )
from synapse.http.server import HttpServer, JsonResource from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -315,6 +317,28 @@ def register_servlets(
): ):
continue continue
if servletclass == FederationUnstableMediaDownloadServlet:
if (
not hs.config.server.enable_media_repo
or not hs.config.experimental.msc3916_authenticated_media_enabled
):
continue
# don't load the endpoint if the storage provider is incompatible
media_repo = hs.get_media_repository()
load_download_endpoint = True
for provider in media_repo.media_storage.storage_providers:
signature = inspect.signature(provider.backend.fetch)
if "federation" not in signature.parameters:
logger.warning(
f"Federation media `/download` endpoint will not be enabled as storage provider {provider.backend} is not compatible with this endpoint."
)
load_download_endpoint = False
break
if not load_download_endpoint:
continue
servletclass( servletclass(
hs=hs, hs=hs,
authenticator=authenticator, authenticator=authenticator,

View file

@ -360,13 +360,29 @@ class BaseFederationServlet:
"request" "request"
) )
return None return None
if (
func.__self__.__class__.__name__ # type: ignore
== "FederationUnstableMediaDownloadServlet"
):
response = await func(
origin, content, request, *args, **kwargs
)
else:
response = await func(
origin, content, request.args, *args, **kwargs
)
else:
if (
func.__self__.__class__.__name__ # type: ignore
== "FederationUnstableMediaDownloadServlet"
):
response = await func(
origin, content, request, *args, **kwargs
)
else:
response = await func( response = await func(
origin, content, request.args, *args, **kwargs origin, content, request.args, *args, **kwargs
) )
else:
response = await func(
origin, content, request.args, *args, **kwargs
)
finally: finally:
# if we used the origin's context as the parent, add a new span using # if we used the origin's context as the parent, add a new span using
# the servlet span as a parent, so that we have a link # the servlet span as a parent, so that we have a link

View file

@ -44,10 +44,13 @@ from synapse.federation.transport.server._base import (
) )
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_boolean_from_args, parse_boolean_from_args,
parse_integer,
parse_integer_from_args, parse_integer_from_args,
parse_string_from_args, parse_string_from_args,
parse_strings_from_args, parse_strings_from_args,
) )
from synapse.http.site import SynapseRequest
from synapse.media._base import DEFAULT_MAX_TIMEOUT_MS, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import SYNAPSE_VERSION from synapse.util import SYNAPSE_VERSION
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
@ -787,6 +790,43 @@ class FederationAccountStatusServlet(BaseFederationServerServlet):
return 200, {"account_statuses": statuses, "failures": failures} return 200, {"account_statuses": statuses, "failures": failures}
class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
"""
Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns
a multipart/form-data response consisting of a JSON object and the requested media
item. This endpoint only returns local media.
"""
PATH = "/media/download/(?P<media_id>[^/]*)"
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916"
RATELIMIT = True
def __init__(
self,
hs: "HomeServer",
ratelimiter: FederationRateLimiter,
authenticator: Authenticator,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.media_repo = self.hs.get_media_repository()
async def on_GET(
self,
origin: Optional[str],
content: Literal[None],
request: SynapseRequest,
media_id: str,
) -> None:
max_timeout_ms = parse_integer(
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
await self.media_repo.get_local_media(
request, media_id, None, max_timeout_ms, federation=True
)
FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationSendServlet, FederationSendServlet,
FederationEventServlet, FederationEventServlet,
@ -818,4 +858,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationV1SendKnockServlet, FederationV1SendKnockServlet,
FederationMakeKnockServlet, FederationMakeKnockServlet,
FederationAccountStatusServlet, FederationAccountStatusServlet,
FederationUnstableMediaDownloadServlet,
) )

View file

@ -25,7 +25,16 @@ import os
import urllib import urllib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from types import TracebackType from types import TracebackType
from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type from typing import (
TYPE_CHECKING,
Awaitable,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
)
import attr import attr
@ -39,6 +48,11 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
if TYPE_CHECKING:
from synapse.media.media_storage import MultipartResponder
from synapse.storage.databases.main.media_repository import LocalMedia
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# list all text content types that will have the charset default to UTF-8 when # list all text content types that will have the charset default to UTF-8 when
@ -260,6 +274,53 @@ def _can_encode_filename_as_token(x: str) -> bool:
return True return True
async def respond_with_multipart_responder(
request: SynapseRequest,
responder: "Optional[MultipartResponder]",
media_info: "LocalMedia",
) -> None:
"""
Responds via a Multipart responder for the federation media `/download` requests
Args:
request: the federation request to respond to
responder: the Multipart responder which will send the response
media_info: metadata about the media item
"""
if not responder:
respond_404(request)
return
# If we have a responder we *must* use it as a context manager.
with responder:
if request._disconnected:
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return
logger.debug("Responding to media request with responder %s", responder)
if media_info.media_length is not None:
request.setHeader(b"Content-Length", b"%d" % (media_info.media_length,))
request.setHeader(
b"Content-Type", b"multipart/mixed; boundary=%s" % responder.boundary
)
try:
await responder.write_to_consumer(request)
except Exception as e:
# The majority of the time this will be due to the client having gone
# away. Unfortunately, Twisted simply throws a generic exception at us
# in that case.
logger.warning("Failed to write to consumer: %s %s", type(e), e)
# Unregister the producer, if it has one, so Twisted doesn't complain
if request.producer:
request.unregisterProducer()
finish_request(request)
async def respond_with_responder( async def respond_with_responder(
request: SynapseRequest, request: SynapseRequest,
responder: "Optional[Responder]", responder: "Optional[Responder]",

View file

@ -54,10 +54,11 @@ from synapse.media._base import (
ThumbnailInfo, ThumbnailInfo,
get_filename_from_headers, get_filename_from_headers,
respond_404, respond_404,
respond_with_multipart_responder,
respond_with_responder, respond_with_responder,
) )
from synapse.media.filepath import MediaFilePaths from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage from synapse.media.media_storage import MediaStorage, MultipartResponder
from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer from synapse.media.url_previewer import UrlPreviewer
@ -429,6 +430,7 @@ class MediaRepository:
media_id: str, media_id: str,
name: Optional[str], name: Optional[str],
max_timeout_ms: int, max_timeout_ms: int,
federation: bool = False,
) -> None: ) -> None:
"""Responds to requests for local media, if exists, or returns 404. """Responds to requests for local media, if exists, or returns 404.
@ -440,6 +442,7 @@ class MediaRepository:
the filename in the Content-Disposition header of the response. the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded. media to be uploaded.
federation: whether the local media being fetched is for a federation request
Returns: Returns:
Resolves once a response has successfully been written to request Resolves once a response has successfully been written to request
@ -459,10 +462,17 @@ class MediaRepository:
file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
responder = await self.media_storage.fetch_media(file_info) responder = await self.media_storage.fetch_media(
await respond_with_responder( file_info, media_info, federation
request, responder, media_type, media_length, upload_name
) )
if federation:
# this really should be a Multipart responder but just in case
assert isinstance(responder, MultipartResponder)
await respond_with_multipart_responder(request, responder, media_info)
else:
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
async def get_remote_media( async def get_remote_media(
self, self,

View file

@ -19,9 +19,12 @@
# #
# #
import contextlib import contextlib
import json
import logging import logging
import os import os
import shutil import shutil
from contextlib import closing
from io import BytesIO
from types import TracebackType from types import TracebackType
from typing import ( from typing import (
IO, IO,
@ -30,14 +33,19 @@ from typing import (
AsyncIterator, AsyncIterator,
BinaryIO, BinaryIO,
Callable, Callable,
List,
Optional, Optional,
Sequence, Sequence,
Tuple, Tuple,
Type, Type,
Union,
) )
from uuid import uuid4
import attr import attr
from zope.interface import implementer
from twisted.internet import defer, interfaces
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
@ -48,15 +56,19 @@ from synapse.logging.opentracing import start_active_span, trace, trace_with_opn
from synapse.util import Clock from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.file_consumer import BackgroundFileConsumer
from ..storage.databases.main.media_repository import LocalMedia
from ..types import JsonDict
from ._base import FileInfo, Responder from ._base import FileInfo, Responder
from .filepath import MediaFilePaths from .filepath import MediaFilePaths
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.media.storage_provider import StorageProvider from synapse.media.storage_provider import StorageProviderWrapper
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CRLF = b"\r\n"
class MediaStorage: class MediaStorage:
"""Responsible for storing/fetching files from local sources. """Responsible for storing/fetching files from local sources.
@ -73,7 +85,7 @@ class MediaStorage:
hs: "HomeServer", hs: "HomeServer",
local_media_directory: str, local_media_directory: str,
filepaths: MediaFilePaths, filepaths: MediaFilePaths,
storage_providers: Sequence["StorageProvider"], storage_providers: Sequence["StorageProviderWrapper"],
): ):
self.hs = hs self.hs = hs
self.reactor = hs.get_reactor() self.reactor = hs.get_reactor()
@ -169,15 +181,23 @@ class MediaStorage:
raise e from None raise e from None
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: async def fetch_media(
self,
file_info: FileInfo,
media_info: Optional[LocalMedia] = None,
federation: bool = False,
) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache """Attempts to fetch media described by file_info from the local cache
and configured storage providers. and configured storage providers.
Args: Args:
file_info file_info: Metadata about the media file
media_info: Metadata about the media item
federation: Whether this file is being fetched for a federation request
Returns: Returns:
Returns a Responder if the file was found, otherwise None. If the file was found returns a Responder (a Multipart Responder if the requested
file is for the federation /download endpoint), otherwise None.
""" """
paths = [self._file_info_to_path(file_info)] paths = [self._file_info_to_path(file_info)]
@ -197,12 +217,19 @@ class MediaStorage:
local_path = os.path.join(self.local_media_directory, path) local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path): if os.path.exists(local_path):
logger.debug("responding with local file %s", local_path) logger.debug("responding with local file %s", local_path)
return FileResponder(open(local_path, "rb")) if federation:
assert media_info is not None
boundary = uuid4().hex.encode("ascii")
return MultipartResponder(
open(local_path, "rb"), media_info, boundary
)
else:
return FileResponder(open(local_path, "rb"))
logger.debug("local file %s did not exist", local_path) logger.debug("local file %s did not exist", local_path)
for provider in self.storage_providers: for provider in self.storage_providers:
for path in paths: for path in paths:
res: Any = await provider.fetch(path, file_info) res: Any = await provider.fetch(path, file_info, media_info, federation)
if res: if res:
logger.debug("Streaming %s from %s", path, provider) logger.debug("Streaming %s from %s", path, provider)
return res return res
@ -316,7 +343,7 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request. """Wraps an open file that can be sent to a request.
Args: Args:
open_file: A file like object to be streamed ot the client, open_file: A file like object to be streamed to the client,
is closed when finished streaming. is closed when finished streaming.
""" """
@ -337,6 +364,38 @@ class FileResponder(Responder):
self.open_file.close() self.open_file.close()
class MultipartResponder(Responder):
"""Wraps an open file, formats the response according to MSC3916 and sends it to a
federation request.
Args:
open_file: A file like object to be streamed to the client,
is closed when finished streaming.
media_info: metadata about the media item
boundary: bytes to use for the multipart response boundary
"""
def __init__(self, open_file: IO, media_info: LocalMedia, boundary: bytes) -> None:
self.open_file = open_file
self.media_info = media_info
self.boundary = boundary
def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable(
MultipartFileSender().beginFileTransfer(
self.open_file, consumer, self.media_info.media_type, {}, self.boundary
)
)
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.open_file.close()
class SpamMediaException(NotFoundError): class SpamMediaException(NotFoundError):
"""The media was blocked by a spam checker, so we simply 404 the request (in """The media was blocked by a spam checker, so we simply 404 the request (in
the same way as if it was quarantined). the same way as if it was quarantined).
@ -370,3 +429,151 @@ class ReadableFileWrapper:
# We yield to the reactor by sleeping for 0 seconds. # We yield to the reactor by sleeping for 0 seconds.
await self.clock.sleep(0) await self.clock.sleep(0)
@implementer(interfaces.IProducer)
class MultipartFileSender:
"""
A producer that sends the contents of a file to a federation request in the format
outlined in MSC3916 - a multipart/format-data response where the first field is a
JSON object and the second is the requested file.
This is a slight re-writing of twisted.protocols.basic.FileSender to achieve the format
outlined above.
"""
CHUNK_SIZE = 2**14
lastSent = ""
deferred: Optional[defer.Deferred] = None
def beginFileTransfer(
self,
file: IO,
consumer: IConsumer,
file_content_type: str,
json_object: JsonDict,
boundary: bytes,
) -> Deferred:
"""
Begin transferring a file
Args:
file: The file object to read data from
consumer: The synapse request to write the data to
file_content_type: The content-type of the file
json_object: The JSON object to write to the first field of the response
boundary: bytes to be used as the multipart/form-data boundary
Returns: A deferred whose callback will be invoked when the file has
been completely written to the consumer. The last byte written to the
consumer is passed to the callback.
"""
self.file: Optional[IO] = file
self.consumer = consumer
self.json_field = json_object
self.json_field_written = False
self.content_type_written = False
self.file_content_type = file_content_type
self.boundary = boundary
self.deferred: Deferred = defer.Deferred()
self.consumer.registerProducer(self, False)
# while it's not entirely clear why this assignment is necessary, it mirrors
# the behavior in FileSender.beginFileTransfer and thus is preserved here
deferred = self.deferred
return deferred
def resumeProducing(self) -> None:
# write the first field, which will always be a json field
if not self.json_field_written:
self.consumer.write(CRLF + b"--" + self.boundary + CRLF)
content_type = Header(b"Content-Type", b"application/json")
self.consumer.write(bytes(content_type) + CRLF)
json_field = json.dumps(self.json_field)
json_bytes = json_field.encode("utf-8")
self.consumer.write(json_bytes)
self.consumer.write(CRLF + b"--" + self.boundary + CRLF)
self.json_field_written = True
chunk: Any = ""
if self.file:
# if we haven't written the content type yet, do so
if not self.content_type_written:
type = self.file_content_type.encode("utf-8")
content_type = Header(b"Content-Type", type)
self.consumer.write(bytes(content_type) + CRLF)
self.content_type_written = True
chunk = self.file.read(self.CHUNK_SIZE)
if not chunk:
# we've reached the end of the file
self.consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
self.file = None
self.consumer.unregisterProducer()
if self.deferred:
self.deferred.callback(self.lastSent)
self.deferred = None
return
self.consumer.write(chunk)
self.lastSent = chunk[-1:]
def pauseProducing(self) -> None:
pass
def stopProducing(self) -> None:
if self.deferred:
self.deferred.errback(Exception("Consumer asked us to stop producing"))
self.deferred = None
class Header:
"""
`Header` This class is a tiny wrapper that produces
request headers. We can't use standard python header
class because it encodes unicode fields using =? bla bla ?=
encoding, which is correct, but no one in HTTP world expects
that, everyone wants utf-8 raw bytes. (stolen from treq.multipart)
"""
def __init__(
self,
name: bytes,
value: Any,
params: Optional[List[Tuple[Any, Any]]] = None,
):
self.name = name
self.value = value
self.params = params or []
def add_param(self, name: Any, value: Any) -> None:
self.params.append((name, value))
def __bytes__(self) -> bytes:
with closing(BytesIO()) as h:
h.write(self.name + b": " + escape(self.value).encode("us-ascii"))
if self.params:
for name, val in self.params:
h.write(b"; ")
h.write(escape(name).encode("us-ascii"))
h.write(b"=")
h.write(b'"' + escape(val).encode("utf-8") + b'"')
h.seek(0)
return h.read()
def escape(value: Union[str, bytes]) -> str:
"""
This function prevents header values from corrupting the request,
a newline in the file name parameter makes form-data request unreadable
for a majority of parsers. (stolen from treq.multipart)
"""
if isinstance(value, bytes):
value = value.decode("utf-8")
return value.replace("\r", "").replace("\n", "").replace('"', '\\"')

View file

@ -24,14 +24,16 @@ import logging
import os import os
import shutil import shutil
from typing import TYPE_CHECKING, Callable, Optional from typing import TYPE_CHECKING, Callable, Optional
from uuid import uuid4
from synapse.config._base import Config from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.context import defer_to_thread, run_in_background
from synapse.logging.opentracing import start_active_span, trace_with_opname from synapse.logging.opentracing import start_active_span, trace_with_opname
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import maybe_awaitable
from ..storage.databases.main.media_repository import LocalMedia
from ._base import FileInfo, Responder from ._base import FileInfo, Responder
from .media_storage import FileResponder from .media_storage import FileResponder, MultipartResponder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,13 +57,21 @@ class StorageProvider(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: async def fetch(
self,
path: str,
file_info: FileInfo,
media_info: Optional[LocalMedia] = None,
federation: bool = False,
) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it """Attempt to fetch the file described by file_info and stream it
into writer. into writer.
Args: Args:
path: Relative path of file in local cache path: Relative path of file in local cache
file_info: The metadata of the file. file_info: The metadata of the file.
media_info: metadata of the media item
federation: Whether the requested media is for a federation request
Returns: Returns:
Returns a Responder if the provider has the file, otherwise returns None. Returns a Responder if the provider has the file, otherwise returns None.
@ -124,7 +134,13 @@ class StorageProviderWrapper(StorageProvider):
run_in_background(store) run_in_background(store)
@trace_with_opname("StorageProviderWrapper.fetch") @trace_with_opname("StorageProviderWrapper.fetch")
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: async def fetch(
self,
path: str,
file_info: FileInfo,
media_info: Optional[LocalMedia] = None,
federation: bool = False,
) -> Optional[Responder]:
if file_info.url_cache: if file_info.url_cache:
# Files in the URL preview cache definitely aren't stored here, # Files in the URL preview cache definitely aren't stored here,
# so avoid any potentially slow I/O or network access. # so avoid any potentially slow I/O or network access.
@ -132,7 +148,9 @@ class StorageProviderWrapper(StorageProvider):
# store_file is supposed to return an Awaitable, but guard # store_file is supposed to return an Awaitable, but guard
# against improper implementations. # against improper implementations.
return await maybe_awaitable(self.backend.fetch(path, file_info)) return await maybe_awaitable(
self.backend.fetch(path, file_info, media_info, federation)
)
class FileStorageProviderBackend(StorageProvider): class FileStorageProviderBackend(StorageProvider):
@ -172,11 +190,23 @@ class FileStorageProviderBackend(StorageProvider):
) )
@trace_with_opname("FileStorageProviderBackend.fetch") @trace_with_opname("FileStorageProviderBackend.fetch")
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: async def fetch(
self,
path: str,
file_info: FileInfo,
media_info: Optional[LocalMedia] = None,
federation: bool = False,
) -> Optional[Responder]:
"""See StorageProvider.fetch""" """See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path) backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname): if os.path.isfile(backup_fname):
if federation:
assert media_info is not None
boundary = uuid4().hex.encode("ascii")
return MultipartResponder(
open(backup_fname, "rb"), media_info, boundary
)
return FileResponder(open(backup_fname, "rb")) return FileResponder(open(backup_fname, "rb"))
return None return None

View file

@ -0,0 +1,234 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import io
import os
import shutil
import tempfile
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.media._base import FileInfo, Responder
from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage
from synapse.media.storage_provider import (
FileStorageProviderBackend,
StorageProviderWrapper,
)
from synapse.server import HomeServer
from synapse.storage.databases.main.media_repository import LocalMedia
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
from tests.test_utils import SMALL_PNG
from tests.unittest import override_config
class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
self.addCleanup(shutil.rmtree, self.test_dir)
self.primary_base_path = os.path.join(self.test_dir, "primary")
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
hs.config.media.media_store_path = self.primary_base_path
storage_providers = [
StorageProviderWrapper(
FileStorageProviderBackend(hs, self.secondary_base_path),
store_local=True,
store_remote=False,
store_synchronous=True,
)
]
self.filepaths = MediaFilePaths(self.primary_base_path)
self.media_storage = MediaStorage(
hs, self.primary_base_path, self.filepaths, storage_providers
)
self.media_repo = hs.get_media_repository()
@override_config(
{"experimental_features": {"msc3916_authenticated_media_enabled": True}}
)
def test_file_download(self) -> None:
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
self.media_repo.create_content(
"text/plain",
"test_upload",
content,
46,
UserID.from_string("@user_id:whatever.org"),
)
)
# test with a text file
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
)
self.pump()
self.assertEqual(200, channel.code)
content_type = channel.headers.getRawHeaders("content-type")
assert content_type is not None
assert "multipart/mixed" in content_type[0]
assert "boundary" in content_type[0]
# extract boundary
boundary = content_type[0].split("boundary=")[1]
# split on boundary and check that json field and expected value exist
stripped = channel.text_body.split("\r\n" + "--" + boundary)
# TODO: the json object expected will change once MSC3911 is implemented, currently
# {} is returned for all requests as a placeholder (per MSC3196)
found_json = any(
"\r\nContent-Type: application/json\r\n{}" in field for field in stripped
)
self.assertTrue(found_json)
# check that text file and expected value exist
found_file = any(
"\r\nContent-Type: text/plain\r\nfile_to_stream" in field
for field in stripped
)
self.assertTrue(found_file)
content = io.BytesIO(SMALL_PNG)
content_uri = self.get_success(
self.media_repo.create_content(
"image/png",
"test_png_upload",
content,
67,
UserID.from_string("@user_id:whatever.org"),
)
)
# test with an image file
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
)
self.pump()
self.assertEqual(200, channel.code)
content_type = channel.headers.getRawHeaders("content-type")
assert content_type is not None
assert "multipart/mixed" in content_type[0]
assert "boundary" in content_type[0]
# extract boundary
boundary = content_type[0].split("boundary=")[1]
# split on boundary and check that json field and expected value exist
body = channel.result.get("body")
assert body is not None
stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8"))
found_json = any(
b"\r\nContent-Type: application/json\r\n{}" in field
for field in stripped_bytes
)
self.assertTrue(found_json)
# check that png file exists and matches what was uploaded
found_file = any(SMALL_PNG in field for field in stripped_bytes)
self.assertTrue(found_file)
@override_config(
{"experimental_features": {"msc3916_authenticated_media_enabled": False}}
)
def test_disable_config(self) -> None:
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
self.media_repo.create_content(
"text/plain",
"test_upload",
content,
46,
UserID.from_string("@user_id:whatever.org"),
)
)
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
)
self.pump()
self.assertEqual(404, channel.code)
self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED")
class FakeFileStorageProviderBackend:
"""
Fake storage provider stub with incompatible `fetch` signature for testing
"""
def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
self.cache_directory = hs.config.media.media_store_path
self.base_directory = config
def __str__(self) -> str:
return "FakeFileStorageProviderBackend[%s]" % (self.base_directory,)
async def fetch(
self, path: str, file_info: FileInfo, media_info: Optional[LocalMedia] = None
) -> Optional[Responder]:
pass
TEST_DIR = tempfile.mkdtemp(prefix="synapse-tests-")
class FederationUnstableMediaEndpointCompatibilityTest(
unittest.FederatingHomeserverTestCase
):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.test_dir = TEST_DIR
self.addCleanup(shutil.rmtree, self.test_dir)
self.media_repo = hs.get_media_repository()
def default_config(self) -> JsonDict:
config = super().default_config()
primary_base_path = os.path.join(TEST_DIR, "primary")
config["media_storage_providers"] = [
{
"module": "tests.federation.test_federation_media.FakeFileStorageProviderBackend",
"store_local": "True",
"store_remote": "False",
"store_synchronous": "False",
"config": {"directory": primary_base_path},
}
]
return config
@override_config(
{"experimental_features": {"msc3916_authenticated_media_enabled": True}}
)
def test_incompatible_storage_provider_fails_to_load_endpoint(self) -> None:
channel = self.make_signed_federation_request(
"GET",
"/_matrix/federation/unstable/org.matrix.msc3916/media/download/xyz",
)
self.pump()
self.assertEqual(404, channel.code)
self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED")

View file

@ -49,7 +49,10 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.media._base import FileInfo, ThumbnailInfo from synapse.media._base import FileInfo, ThumbnailInfo
from synapse.media.filepath import MediaFilePaths from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage, ReadableFileWrapper from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
from synapse.media.storage_provider import FileStorageProviderBackend from synapse.media.storage_provider import (
FileStorageProviderBackend,
StorageProviderWrapper,
)
from synapse.media.thumbnailer import ThumbnailProvider from synapse.media.thumbnailer import ThumbnailProvider
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
@ -78,7 +81,14 @@ class MediaStorageTests(unittest.HomeserverTestCase):
hs.config.media.media_store_path = self.primary_base_path hs.config.media.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)] storage_providers = [
StorageProviderWrapper(
FileStorageProviderBackend(hs, self.secondary_base_path),
store_local=True,
store_remote=False,
store_synchronous=True,
)
]
self.filepaths = MediaFilePaths(self.primary_base_path) self.filepaths = MediaFilePaths(self.primary_base_path)
self.media_storage = MediaStorage( self.media_storage = MediaStorage(