fallback to matrix/media/v3/download if federation endpoint 404s + test to verify this behavior

This commit is contained in:
H. Shay 2024-06-28 19:04:16 -07:00
parent a86d448b88
commit 196335c52e
3 changed files with 122 additions and 10 deletions

View file

@ -1880,8 +1880,34 @@ class FederationClient(FederationBase):
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
return await self.transport_layer.federation_download_media(
) -> Union[
Tuple[int, Dict[bytes, List[bytes]], bytes],
Tuple[int, Dict[bytes, List[bytes]]],
]:
try:
return await self.transport_layer.federation_download_media(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fallback to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error
# and raise.
if not is_unknown_endpoint(e):
raise
logger.debug(
"Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
destination,
media_id,
)
return await self.transport_layer.download_media_v3(
destination,
media_id,
output_stream=output_stream,

View file

@ -830,7 +830,7 @@ class MediaRepository:
async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
length, headers, json = await self.client.federation_download_media(
res = await self.client.federation_download_media(
server_name,
media_id,
output_stream=f,
@ -839,6 +839,12 @@ class MediaRepository:
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
# if we had to fall back to the _matrix/media endpoint it will only return
# the headers and length, check the length of the tuple before unpacking
if len(res) == 3:
length, headers, json = res
else:
length, headers = res
except RequestSendFailed as e:
logger.warning(
"Request failed fetching remote media %s/%s: %r",

View file

@ -23,13 +23,11 @@ import io
import json
import os
import re
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type
from unittest.mock import MagicMock, Mock, patch
from urllib import parse
from urllib.parse import quote, urlencode
from parameterized import parameterized_class
from twisted.internet import defer
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
@ -60,7 +58,6 @@ from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
from tests.media.test_media_storage import (
SVG,
TestImage,
empty_file,
small_lossless_webp,
small_png,
@ -1898,9 +1895,10 @@ test_images = [
input_values = [(x,) for x in test_images]
@parameterized_class(("test_image",), input_values)
# @parameterized_class(("test_image",), input_values)
class DownloadTestCase(unittest.HomeserverTestCase):
test_image: ClassVar[TestImage]
# test_image: ClassVar[TestImage]
test_image = SVG
servlets = [
media.register_servlets,
login.register_servlets,
@ -1910,7 +1908,7 @@ class DownloadTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fetches: List[
Tuple[
"Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]]",
"Deferred[Any]",
str,
str,
Optional[QueryParams],
@ -1951,9 +1949,42 @@ class DownloadTestCase(unittest.HomeserverTestCase):
d_after_callback = d.addCallbacks(write_to, write_err)
return make_deferred_yieldable(d_after_callback)
def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: Any,
max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
"""A mock for MatrixFederationHttpClient.get_file."""
def write_to(
r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
) -> Tuple[int, Dict[bytes, List[bytes]]]:
data, response = r
output_stream.write(data)
return response
def write_err(f: Failure) -> Failure:
f.trap(HttpResponseException)
output_stream.write(f.value.response)
return f
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args))
# Note that this callback changes the value held by d.
d_after_callback = d.addCallbacks(write_to, write_err)
return make_deferred_yieldable(d_after_callback)
# Mock out the homeserver's MatrixFederationHttpClient
client = Mock()
client.federation_get_file = federation_get_file
client.get_file = get_file
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
@ -2128,3 +2159,52 @@ class DownloadTestCase(unittest.HomeserverTestCase):
headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
[b"cross-origin"],
)
def test_unknown_federation_endpoint(self) -> None:
"""
Test that if the downloadd request to remote federation endpoint returns a 404
we fall back to the _matrix/media endpoint
"""
channel = self.make_request(
"GET",
f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
shorthand=False,
await_result=False,
access_token=self.tok,
)
self.pump()
# We've made one fetch, to example.com, using the media URL, and asking
# the other server not to do a remote fetch
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}"
)
# The result which says the endpoint is unknown.
unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
self.fetches[0][0].errback(
HttpResponseException(404, "NOT FOUND", unknown_endpoint)
)
self.pump()
# There should now be another request to the _matrix/media/v3/download URL.
self.assertEqual(len(self.fetches), 2)
self.assertEqual(self.fetches[1][1], "example.com")
self.assertEqual(
self.fetches[1][2],
f"/_matrix/media/v3/download/example.com/{self.media_id}",
)
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
}
self.fetches[1][0].callback(
(self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
self.assertEqual(channel.code, 200)