fetch remote downloads for _matrix/client/v1/media/download over federation download endpoint

This commit is contained in:
H. Shay 2024-06-26 20:42:50 -07:00
parent c3b25c056d
commit fca35c89ae
5 changed files with 386 additions and 14 deletions

View file

@ -1871,6 +1871,26 @@ class FederationClient(FederationBase):
return filtered_statuses, filtered_failures return filtered_statuses, filtered_failures
async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
return await self.transport_layer.federation_download_media(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
async def download_media( async def download_media(
self, self,
destination: str, destination: str,

View file

@ -824,7 +824,6 @@ class TransportLayerClient:
ip_address: str, ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]: ) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}" path = f"/_matrix/media/r0/download/{destination}/{media_id}"
return await self.client.get_file( return await self.client.get_file(
destination, destination,
path, path,
@ -852,7 +851,6 @@ class TransportLayerClient:
ip_address: str, ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]: ) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}" path = f"/_matrix/media/v3/download/{destination}/{media_id}"
return await self.client.get_file( return await self.client.get_file(
destination, destination,
path, path,
@ -873,6 +871,29 @@ class TransportLayerClient:
ip_address=ip_address, ip_address=ip_address,
) )
async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
path = f"/_matrix/federation/v1/media/download/{media_id}"
return await self.client.federation_get_file(
destination,
path,
output_stream=output_stream,
max_size=max_size,
args={
"timeout_ms": str(max_timeout_ms),
},
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
def _create_path(federation_prefix: str, path: str, *args: str) -> str: def _create_path(federation_prefix: str, path: str, *args: str) -> str:
""" """

View file

@ -75,9 +75,11 @@ from synapse.http.client import (
BlocklistingAgentWrapper, BlocklistingAgentWrapper,
BodyExceededMaxSize, BodyExceededMaxSize,
ByteWriteable, ByteWriteable,
SimpleHttpClient,
_make_scheduler, _make_scheduler,
encode_query_args, encode_query_args,
read_body_with_max_size, read_body_with_max_size,
read_multipart_response,
) )
from synapse.http.connectproxyclient import BearerProxyCredentials from synapse.http.connectproxyclient import BearerProxyCredentials
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
@ -466,6 +468,12 @@ class MatrixFederationHttpClient:
self._sleeper = AwakenableSleeper(self.reactor) self._sleeper = AwakenableSleeper(self.reactor)
self._simple_http_client = SimpleHttpClient(
hs,
ip_blocklist=hs.config.server.federation_ip_range_blocklist,
use_proxy=True,
)
def wake_destination(self, destination: str) -> None: def wake_destination(self, destination: str) -> None:
"""Called when the remote server may have come back online.""" """Called when the remote server may have come back online."""
@ -1553,6 +1561,176 @@ class MatrixFederationHttpClient:
) )
return length, headers return length, headers
async def federation_get_file(
self,
destination: str,
path: str,
output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: str,
max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
ignore_backoff: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
"""GETs a file from a given homeserver over the federation /download endpoint
Args:
destination: The remote server to send the HTTP request to.
path: The HTTP path to GET.
output_stream: File to write the response body to.
download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
requester IP
ip_address: IP address of the requester
max_size: maximum allowable size in bytes of the file
args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
Resolves to an (int, dict, bytes) tuple of
the file length, a dict of the response headers, and the file json
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
SynapseError: If the requested file exceeds ratelimits
"""
request = MatrixFederationRequest(
method="GET", destination=destination, path=path, query=args
)
# check for a minimum balance of 1MiB in ratelimiter before initiating request
send_req, _ = await download_ratelimiter.can_do_action(
requester=None, key=ip_address, n_actions=1048576, update=False
)
if not send_req:
msg = "Requested file size exceeds ratelimits"
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
response = await self._send_request(
request,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
headers = dict(response.headers.getAllRawHeaders())
expected_size = response.length
# if we don't get an expected length then use the max length
if expected_size == UNKNOWN_LENGTH:
expected_size = max_size
logger.debug(
f"File size unknown, assuming file is max allowable size: {max_size}"
)
read_body, _ = await download_ratelimiter.can_do_action(
requester=None,
key=ip_address,
n_actions=expected_size,
)
if not read_body:
msg = "Requested file size exceeds ratelimits"
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
# this should be a multipart/mixed response with the boundary string in the header
raw_content_type = headers.get(b"Content-Type")
assert raw_content_type is not None
content_type = raw_content_type[0].decode("UTF-8")
boundary = content_type.split("boundary=")[1]
try:
# add a byte of headroom to max size as function errs at >=
deferred = read_multipart_response(
response, output_stream, boundary, expected_size + 1
)
deferred.addTimeout(self.default_timeout_seconds, self.reactor)
except BodyExceededMaxSize:
msg = "Requested file is too large > %r bytes" % (expected_size,)
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
except defer.TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response - %s %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=True) from e
except ResponseFailed as e:
logger.warning(
"{%s} [%s] Failed to read response - %s %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=True) from e
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
request.txn_id,
request.destination,
e,
)
raise
multipart_response = await make_deferred_yieldable(deferred)
if not multipart_response.url:
assert multipart_response.length is not None
length = multipart_response.length
headers[b"Content-Type"] = [multipart_response.content_type]
headers[b"Content-Disposition"] = [multipart_response.disposition]
# the response contained a redirect url to download the file from
else:
str_url = multipart_response.url.decode("utf-8")
logger.info(
"{%s} [%s] File download redirected, now downloading from: %s",
request.txn_id,
request.destination,
str_url,
)
length, headers, _, _ = await self._simple_http_client.get_file(
str_url, output_stream, expected_size
)
logger.info(
"{%s} [%s] Completed: %d %s [%d bytes] %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
length,
request.method,
request.uri.decode("ascii"),
)
return length, headers, multipart_response.json
def _flatten_response_never_received(e: BaseException) -> str: def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"): if hasattr(e, "reasons"):

View file

@ -480,6 +480,7 @@ class MediaRepository:
name: Optional[str], name: Optional[str],
max_timeout_ms: int, max_timeout_ms: int,
ip_address: str, ip_address: str,
use_federation_endpoint: bool,
) -> None: ) -> None:
"""Respond to requests for remote media. """Respond to requests for remote media.
@ -492,6 +493,8 @@ class MediaRepository:
max_timeout_ms: the maximum number of milliseconds to wait for the max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded. media to be uploaded.
ip_address: the IP address of the requester ip_address: the IP address of the requester
use_federation_endpoint: whether to request the remote media over the new
federation `/download` endpoint
Returns: Returns:
Resolves once a response has successfully been written to request Resolves once a response has successfully been written to request
@ -522,6 +525,7 @@ class MediaRepository:
max_timeout_ms, max_timeout_ms,
self.download_ratelimiter, self.download_ratelimiter,
ip_address, ip_address,
use_federation_endpoint,
) )
# We deliberately stream the file outside the lock # We deliberately stream the file outside the lock
@ -569,6 +573,7 @@ class MediaRepository:
max_timeout_ms, max_timeout_ms,
self.download_ratelimiter, self.download_ratelimiter,
ip_address, ip_address,
False,
) )
# Ensure we actually use the responder so that it releases resources # Ensure we actually use the responder so that it releases resources
@ -585,6 +590,7 @@ class MediaRepository:
max_timeout_ms: int, max_timeout_ms: int,
download_ratelimiter: Ratelimiter, download_ratelimiter: Ratelimiter,
ip_address: str, ip_address: str,
use_federation_endpoint: bool,
) -> Tuple[Optional[Responder], RemoteMedia]: ) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to """Looks for media in local cache, if not there then attempt to
download from remote server. download from remote server.
@ -598,6 +604,8 @@ class MediaRepository:
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP. requester IP.
ip_address: the IP address of the requester ip_address: the IP address of the requester
use_federation_endpoint: whether to request the remote media over the new federation
/download endpoint
Returns: Returns:
A tuple of responder and the media info of the file. A tuple of responder and the media info of the file.
@ -628,18 +636,45 @@ class MediaRepository:
# Failed to find the file anywhere, lets download it. # Failed to find the file anywhere, lets download it.
try: if not use_federation_endpoint:
media_info = await self._download_remote_file( try:
server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address media_info = await self._download_remote_file(
) server_name,
except SynapseError: media_id,
raise max_timeout_ms,
except Exception as e: download_ratelimiter,
# An exception may be because we downloaded media in another ip_address,
# process, so let's check if we magically have the media. )
media_info = await self.store.get_cached_remote_media(server_name, media_id) except SynapseError:
if not media_info: raise
raise e except Exception as e:
# An exception may be because we downloaded media in another
# process, so let's check if we magically have the media.
media_info = await self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
raise e
else:
try:
media_info = await self._federation_download_remote_file(
server_name,
media_id,
max_timeout_ms,
download_ratelimiter,
ip_address,
)
except SynapseError:
raise
except Exception as e:
# An exception may be because we downloaded media in another
# process, so let's check if we magically have the media.
media_info = await self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
raise e
file_id = media_info.filesystem_id file_id = media_info.filesystem_id
if not media_info.media_type: if not media_info.media_type:
@ -775,6 +810,123 @@ class MediaRepository:
quarantined_by=None, quarantined_by=None,
) )
async def _federation_download_remote_file(
self,
server_name: str,
media_id: str,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id and downloading over federation v1 download
endpoint
Args:
server_name: Originating server
media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP
ip_address: the IP address of the requester
Returns:
The media info of the file.
"""
file_id = random_string(24)
file_info = FileInfo(server_name=server_name, file_id=file_id)
async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
length, headers, json = await self.client.federation_download_media(
server_name,
media_id,
output_stream=f,
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except RequestSendFailed as e:
logger.warning(
"Request failed fetching remote media %s/%s: %r",
server_name,
media_id,
e,
)
raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e:
logger.warning(
"HTTP error fetching remote media %s/%s: %s",
server_name,
media_id,
e.response,
)
if e.code == twisted.web.http.NOT_FOUND:
raise e.to_synapse_error()
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
logger.warning(
"Failed to fetch remote media %s/%s", server_name, media_id
)
raise
except NotRetryingDestination:
logger.warning("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception(
"Failed to fetch remote media %s/%s", server_name, media_id
)
raise SynapseError(502, "Failed to fetch remote media")
if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii")
else:
media_type = "application/octet-stream"
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
# Multiple remote media download requests can race (when using
# multiple media repos), so this may throw a violation constraint
# exception. If it does we'll delete the newly downloaded file from
# disk (as we're in the ctx manager).
#
# However: we've already called `finish()` so we may have also
# written to the storage providers. This is preferable to the
# alternative where we call `finish()` *after* this, where we could
# end up having an entry in the DB but fail to write the files to
# the storage providers.
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
logger.info("Stored remote media in file %r", fname)
return RemoteMedia(
media_origin=server_name,
media_id=media_id,
media_type=media_type,
media_length=length,
upload_name=upload_name,
created_ts=time_now_ms,
filesystem_id=file_id,
last_access_ts=time_now_ms,
quarantined_by=None,
)
def _get_thumbnail_requirements( def _get_thumbnail_requirements(
self, media_type: str self, media_type: str
) -> Tuple[ThumbnailRequirement, ...]: ) -> Tuple[ThumbnailRequirement, ...]:

View file

@ -105,4 +105,5 @@ class DownloadResource(RestServlet):
file_name, file_name,
max_timeout_ms, max_timeout_ms,
ip_address, ip_address,
False,
) )