mirror of
https://github.com/element-hq/synapse
synced 2024-07-04 16:53:31 +00:00
fallback to matrix/media/v3/download if federation endpoint 404s + test to verify this behavior
This commit is contained in:
parent
a86d448b88
commit
196335c52e
|
@ -1880,7 +1880,11 @@ class FederationClient(FederationBase):
|
|||
max_timeout_ms: int,
|
||||
download_ratelimiter: Ratelimiter,
|
||||
ip_address: str,
|
||||
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
|
||||
) -> Union[
|
||||
Tuple[int, Dict[bytes, List[bytes]], bytes],
|
||||
Tuple[int, Dict[bytes, List[bytes]]],
|
||||
]:
|
||||
try:
|
||||
return await self.transport_layer.federation_download_media(
|
||||
destination,
|
||||
media_id,
|
||||
|
@ -1890,6 +1894,28 @@ class FederationClient(FederationBase):
|
|||
download_ratelimiter=download_ratelimiter,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
# If an error is received that is due to an unrecognised endpoint,
|
||||
# fallback to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error
|
||||
# and raise.
|
||||
if not is_unknown_endpoint(e):
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
"Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
|
||||
destination,
|
||||
media_id,
|
||||
)
|
||||
|
||||
return await self.transport_layer.download_media_v3(
|
||||
destination,
|
||||
media_id,
|
||||
output_stream=output_stream,
|
||||
max_size=max_size,
|
||||
max_timeout_ms=max_timeout_ms,
|
||||
download_ratelimiter=download_ratelimiter,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
async def download_media(
|
||||
self,
|
||||
|
|
|
@ -830,7 +830,7 @@ class MediaRepository:
|
|||
|
||||
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||
try:
|
||||
length, headers, json = await self.client.federation_download_media(
|
||||
res = await self.client.federation_download_media(
|
||||
server_name,
|
||||
media_id,
|
||||
output_stream=f,
|
||||
|
@ -839,6 +839,12 @@ class MediaRepository:
|
|||
download_ratelimiter=download_ratelimiter,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
# if we had to fall back to the _matrix/media endpoint it will only return
|
||||
# the headers and length, check the length of the tuple before unpacking
|
||||
if len(res) == 3:
|
||||
length, headers, json = res
|
||||
else:
|
||||
length, headers = res
|
||||
except RequestSendFailed as e:
|
||||
logger.warning(
|
||||
"Request failed fetching remote media %s/%s: %r",
|
||||
|
|
|
@ -23,13 +23,11 @@ import io
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
|
||||
from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from urllib import parse
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
from parameterized import parameterized_class
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet._resolver import HostResolution
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
|
@ -60,7 +58,6 @@ from synapse.util.stringutils import parse_and_validate_mxc_uri
|
|||
from tests import unittest
|
||||
from tests.media.test_media_storage import (
|
||||
SVG,
|
||||
TestImage,
|
||||
empty_file,
|
||||
small_lossless_webp,
|
||||
small_png,
|
||||
|
@ -1898,9 +1895,10 @@ test_images = [
|
|||
input_values = [(x,) for x in test_images]
|
||||
|
||||
|
||||
@parameterized_class(("test_image",), input_values)
|
||||
# @parameterized_class(("test_image",), input_values)
|
||||
class DownloadTestCase(unittest.HomeserverTestCase):
|
||||
test_image: ClassVar[TestImage]
|
||||
# test_image: ClassVar[TestImage]
|
||||
test_image = SVG
|
||||
servlets = [
|
||||
media.register_servlets,
|
||||
login.register_servlets,
|
||||
|
@ -1910,7 +1908,7 @@ class DownloadTestCase(unittest.HomeserverTestCase):
|
|||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.fetches: List[
|
||||
Tuple[
|
||||
"Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]]",
|
||||
"Deferred[Any]",
|
||||
str,
|
||||
str,
|
||||
Optional[QueryParams],
|
||||
|
@ -1951,9 +1949,42 @@ class DownloadTestCase(unittest.HomeserverTestCase):
|
|||
d_after_callback = d.addCallbacks(write_to, write_err)
|
||||
return make_deferred_yieldable(d_after_callback)
|
||||
|
||||
def get_file(
|
||||
destination: str,
|
||||
path: str,
|
||||
output_stream: BinaryIO,
|
||||
download_ratelimiter: Ratelimiter,
|
||||
ip_address: Any,
|
||||
max_size: int,
|
||||
args: Optional[QueryParams] = None,
|
||||
retry_on_dns_fail: bool = True,
|
||||
ignore_backoff: bool = False,
|
||||
follow_redirects: bool = False,
|
||||
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
|
||||
"""A mock for MatrixFederationHttpClient.get_file."""
|
||||
|
||||
def write_to(
|
||||
r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
|
||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||
data, response = r
|
||||
output_stream.write(data)
|
||||
return response
|
||||
|
||||
def write_err(f: Failure) -> Failure:
|
||||
f.trap(HttpResponseException)
|
||||
output_stream.write(f.value.response)
|
||||
return f
|
||||
|
||||
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
|
||||
self.fetches.append((d, destination, path, args))
|
||||
# Note that this callback changes the value held by d.
|
||||
d_after_callback = d.addCallbacks(write_to, write_err)
|
||||
return make_deferred_yieldable(d_after_callback)
|
||||
|
||||
# Mock out the homeserver's MatrixFederationHttpClient
|
||||
client = Mock()
|
||||
client.federation_get_file = federation_get_file
|
||||
client.get_file = get_file
|
||||
|
||||
self.storage_path = self.mktemp()
|
||||
self.media_store_path = self.mktemp()
|
||||
|
@ -2128,3 +2159,52 @@ class DownloadTestCase(unittest.HomeserverTestCase):
|
|||
headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
|
||||
[b"cross-origin"],
|
||||
)
|
||||
|
||||
def test_unknown_federation_endpoint(self) -> None:
|
||||
"""
|
||||
Test that if the downloadd request to remote federation endpoint returns a 404
|
||||
we fall back to the _matrix/media endpoint
|
||||
"""
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
|
||||
shorthand=False,
|
||||
await_result=False,
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.pump()
|
||||
|
||||
# We've made one fetch, to example.com, using the media URL, and asking
|
||||
# the other server not to do a remote fetch
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
self.assertEqual(self.fetches[0][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}"
|
||||
)
|
||||
|
||||
# The result which says the endpoint is unknown.
|
||||
unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
|
||||
self.fetches[0][0].errback(
|
||||
HttpResponseException(404, "NOT FOUND", unknown_endpoint)
|
||||
)
|
||||
|
||||
self.pump()
|
||||
|
||||
# There should now be another request to the _matrix/media/v3/download URL.
|
||||
self.assertEqual(len(self.fetches), 2)
|
||||
self.assertEqual(self.fetches[1][1], "example.com")
|
||||
self.assertEqual(
|
||||
self.fetches[1][2],
|
||||
f"/_matrix/media/v3/download/example.com/{self.media_id}",
|
||||
)
|
||||
|
||||
headers = {
|
||||
b"Content-Length": [b"%d" % (len(self.test_image.data))],
|
||||
}
|
||||
|
||||
self.fetches[1][0].callback(
|
||||
(self.test_image.data, (len(self.test_image.data), headers))
|
||||
)
|
||||
|
||||
self.pump()
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
|
Loading…
Reference in a new issue