mirror of
https://github.com/element-hq/synapse
synced 2024-07-04 18:03:32 +00:00
fallback to matrix/media/v3/download if federation endpoint 404s + test to verify this behavior
This commit is contained in:
parent
a86d448b88
commit
196335c52e
|
@ -1880,8 +1880,34 @@ class FederationClient(FederationBase):
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
download_ratelimiter: Ratelimiter,
|
download_ratelimiter: Ratelimiter,
|
||||||
ip_address: str,
|
ip_address: str,
|
||||||
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
|
) -> Union[
|
||||||
return await self.transport_layer.federation_download_media(
|
Tuple[int, Dict[bytes, List[bytes]], bytes],
|
||||||
|
Tuple[int, Dict[bytes, List[bytes]]],
|
||||||
|
]:
|
||||||
|
try:
|
||||||
|
return await self.transport_layer.federation_download_media(
|
||||||
|
destination,
|
||||||
|
media_id,
|
||||||
|
output_stream=output_stream,
|
||||||
|
max_size=max_size,
|
||||||
|
max_timeout_ms=max_timeout_ms,
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
|
)
|
||||||
|
except HttpResponseException as e:
|
||||||
|
# If an error is received that is due to an unrecognised endpoint,
|
||||||
|
# fallback to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error
|
||||||
|
# and raise.
|
||||||
|
if not is_unknown_endpoint(e):
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
|
||||||
|
destination,
|
||||||
|
media_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self.transport_layer.download_media_v3(
|
||||||
destination,
|
destination,
|
||||||
media_id,
|
media_id,
|
||||||
output_stream=output_stream,
|
output_stream=output_stream,
|
||||||
|
|
|
@ -830,7 +830,7 @@ class MediaRepository:
|
||||||
|
|
||||||
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
async with self.media_storage.store_into_file(file_info) as (f, fname):
|
||||||
try:
|
try:
|
||||||
length, headers, json = await self.client.federation_download_media(
|
res = await self.client.federation_download_media(
|
||||||
server_name,
|
server_name,
|
||||||
media_id,
|
media_id,
|
||||||
output_stream=f,
|
output_stream=f,
|
||||||
|
@ -839,6 +839,12 @@ class MediaRepository:
|
||||||
download_ratelimiter=download_ratelimiter,
|
download_ratelimiter=download_ratelimiter,
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
|
# if we had to fall back to the _matrix/media endpoint it will only return
|
||||||
|
# the headers and length, check the length of the tuple before unpacking
|
||||||
|
if len(res) == 3:
|
||||||
|
length, headers, json = res
|
||||||
|
else:
|
||||||
|
length, headers = res
|
||||||
except RequestSendFailed as e:
|
except RequestSendFailed as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Request failed fetching remote media %s/%s: %r",
|
"Request failed fetching remote media %s/%s: %r",
|
||||||
|
|
|
@ -23,13 +23,11 @@ import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
|
from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
from urllib.parse import quote, urlencode
|
from urllib.parse import quote, urlencode
|
||||||
|
|
||||||
from parameterized import parameterized_class
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet._resolver import HostResolution
|
from twisted.internet._resolver import HostResolution
|
||||||
from twisted.internet.address import IPv4Address, IPv6Address
|
from twisted.internet.address import IPv4Address, IPv6Address
|
||||||
|
@ -60,7 +58,6 @@ from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.media.test_media_storage import (
|
from tests.media.test_media_storage import (
|
||||||
SVG,
|
SVG,
|
||||||
TestImage,
|
|
||||||
empty_file,
|
empty_file,
|
||||||
small_lossless_webp,
|
small_lossless_webp,
|
||||||
small_png,
|
small_png,
|
||||||
|
@ -1898,9 +1895,10 @@ test_images = [
|
||||||
input_values = [(x,) for x in test_images]
|
input_values = [(x,) for x in test_images]
|
||||||
|
|
||||||
|
|
||||||
@parameterized_class(("test_image",), input_values)
|
# @parameterized_class(("test_image",), input_values)
|
||||||
class DownloadTestCase(unittest.HomeserverTestCase):
|
class DownloadTestCase(unittest.HomeserverTestCase):
|
||||||
test_image: ClassVar[TestImage]
|
# test_image: ClassVar[TestImage]
|
||||||
|
test_image = SVG
|
||||||
servlets = [
|
servlets = [
|
||||||
media.register_servlets,
|
media.register_servlets,
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
|
@ -1910,7 +1908,7 @@ class DownloadTestCase(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
self.fetches: List[
|
self.fetches: List[
|
||||||
Tuple[
|
Tuple[
|
||||||
"Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]]",
|
"Deferred[Any]",
|
||||||
str,
|
str,
|
||||||
str,
|
str,
|
||||||
Optional[QueryParams],
|
Optional[QueryParams],
|
||||||
|
@ -1951,9 +1949,42 @@ class DownloadTestCase(unittest.HomeserverTestCase):
|
||||||
d_after_callback = d.addCallbacks(write_to, write_err)
|
d_after_callback = d.addCallbacks(write_to, write_err)
|
||||||
return make_deferred_yieldable(d_after_callback)
|
return make_deferred_yieldable(d_after_callback)
|
||||||
|
|
||||||
|
def get_file(
|
||||||
|
destination: str,
|
||||||
|
path: str,
|
||||||
|
output_stream: BinaryIO,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: Any,
|
||||||
|
max_size: int,
|
||||||
|
args: Optional[QueryParams] = None,
|
||||||
|
retry_on_dns_fail: bool = True,
|
||||||
|
ignore_backoff: bool = False,
|
||||||
|
follow_redirects: bool = False,
|
||||||
|
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
|
||||||
|
"""A mock for MatrixFederationHttpClient.get_file."""
|
||||||
|
|
||||||
|
def write_to(
|
||||||
|
r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
|
||||||
|
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||||
|
data, response = r
|
||||||
|
output_stream.write(data)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def write_err(f: Failure) -> Failure:
|
||||||
|
f.trap(HttpResponseException)
|
||||||
|
output_stream.write(f.value.response)
|
||||||
|
return f
|
||||||
|
|
||||||
|
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
|
||||||
|
self.fetches.append((d, destination, path, args))
|
||||||
|
# Note that this callback changes the value held by d.
|
||||||
|
d_after_callback = d.addCallbacks(write_to, write_err)
|
||||||
|
return make_deferred_yieldable(d_after_callback)
|
||||||
|
|
||||||
# Mock out the homeserver's MatrixFederationHttpClient
|
# Mock out the homeserver's MatrixFederationHttpClient
|
||||||
client = Mock()
|
client = Mock()
|
||||||
client.federation_get_file = federation_get_file
|
client.federation_get_file = federation_get_file
|
||||||
|
client.get_file = get_file
|
||||||
|
|
||||||
self.storage_path = self.mktemp()
|
self.storage_path = self.mktemp()
|
||||||
self.media_store_path = self.mktemp()
|
self.media_store_path = self.mktemp()
|
||||||
|
@ -2128,3 +2159,52 @@ class DownloadTestCase(unittest.HomeserverTestCase):
|
||||||
headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
|
headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
|
||||||
[b"cross-origin"],
|
[b"cross-origin"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_unknown_federation_endpoint(self) -> None:
|
||||||
|
"""
|
||||||
|
Test that if the downloadd request to remote federation endpoint returns a 404
|
||||||
|
we fall back to the _matrix/media endpoint
|
||||||
|
"""
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
|
||||||
|
shorthand=False,
|
||||||
|
await_result=False,
|
||||||
|
access_token=self.tok,
|
||||||
|
)
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# We've made one fetch, to example.com, using the media URL, and asking
|
||||||
|
# the other server not to do a remote fetch
|
||||||
|
self.assertEqual(len(self.fetches), 1)
|
||||||
|
self.assertEqual(self.fetches[0][1], "example.com")
|
||||||
|
self.assertEqual(
|
||||||
|
self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The result which says the endpoint is unknown.
|
||||||
|
unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
|
||||||
|
self.fetches[0][0].errback(
|
||||||
|
HttpResponseException(404, "NOT FOUND", unknown_endpoint)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# There should now be another request to the _matrix/media/v3/download URL.
|
||||||
|
self.assertEqual(len(self.fetches), 2)
|
||||||
|
self.assertEqual(self.fetches[1][1], "example.com")
|
||||||
|
self.assertEqual(
|
||||||
|
self.fetches[1][2],
|
||||||
|
f"/_matrix/media/v3/download/example.com/{self.media_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
b"Content-Length": [b"%d" % (len(self.test_image.data))],
|
||||||
|
}
|
||||||
|
|
||||||
|
self.fetches[1][0].callback(
|
||||||
|
(self.test_image.data, (len(self.test_image.data), headers))
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pump()
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
Loading…
Reference in a new issue