From 196335c52e601eaa1d9ff20a18d5497c757d9277 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Fri, 28 Jun 2024 19:04:16 -0700 Subject: [PATCH] fallback to matrix/media/v3/download if federation endpoint 404s + test to verify this behavior --- synapse/federation/federation_client.py | 30 +++++++- synapse/media/media_repository.py | 8 ++- tests/rest/client/test_media.py | 94 +++++++++++++++++++++++-- 3 files changed, 122 insertions(+), 10 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 4fd3968ef4..7d80ff6998 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1880,8 +1880,34 @@ class FederationClient(FederationBase): max_timeout_ms: int, download_ratelimiter: Ratelimiter, ip_address: str, - ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]: - return await self.transport_layer.federation_download_media( + ) -> Union[ + Tuple[int, Dict[bytes, List[bytes]], bytes], + Tuple[int, Dict[bytes, List[bytes]]], + ]: + try: + return await self.transport_layer.federation_download_media( + destination, + media_id, + output_stream=output_stream, + max_size=max_size, + max_timeout_ms=max_timeout_ms, + download_ratelimiter=download_ratelimiter, + ip_address=ip_address, + ) + except HttpResponseException as e: + # If an error is received that is due to an unrecognised endpoint, + # fall back to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error + # and raise. 
+ if not is_unknown_endpoint(e): + raise + + logger.debug( + "Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path", + destination, + media_id, + ) + + return await self.transport_layer.download_media_v3( destination, media_id, output_stream=output_stream, diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 2c6f9b08cd..542642b900 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -830,7 +830,7 @@ class MediaRepository: async with self.media_storage.store_into_file(file_info) as (f, fname): try: - length, headers, json = await self.client.federation_download_media( + res = await self.client.federation_download_media( server_name, media_id, output_stream=f, @@ -839,6 +839,12 @@ class MediaRepository: download_ratelimiter=download_ratelimiter, ip_address=ip_address, ) + # if we had to fall back to the _matrix/media endpoint it will only return + # the headers and length, check the length of the tuple before unpacking + if len(res) == 3: + length, headers, json = res + else: + length, headers = res except RequestSendFailed as e: logger.warning( "Request failed fetching remote media %s/%s: %r", diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 5aac214dfe..139f329795 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -23,13 +23,11 @@ import io import json import os import re -from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type +from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type from unittest.mock import MagicMock, Mock, patch from urllib import parse from urllib.parse import quote, urlencode -from parameterized import parameterized_class - from twisted.internet import defer from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address @@ -60,7 
+58,6 @@ from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest from tests.media.test_media_storage import ( SVG, - TestImage, empty_file, small_lossless_webp, small_png, @@ -1898,9 +1895,10 @@ test_images = [ input_values = [(x,) for x in test_images] -@parameterized_class(("test_image",), input_values) +# @parameterized_class(("test_image",), input_values) class DownloadTestCase(unittest.HomeserverTestCase): - test_image: ClassVar[TestImage] + # test_image: ClassVar[TestImage] + test_image = SVG servlets = [ media.register_servlets, login.register_servlets, @@ -1910,7 +1908,7 @@ class DownloadTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.fetches: List[ Tuple[ - "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]]", + "Deferred[Any]", str, str, Optional[QueryParams], @@ -1951,9 +1949,42 @@ class DownloadTestCase(unittest.HomeserverTestCase): d_after_callback = d.addCallbacks(write_to, write_err) return make_deferred_yieldable(d_after_callback) + def get_file( + destination: str, + path: str, + output_stream: BinaryIO, + download_ratelimiter: Ratelimiter, + ip_address: Any, + max_size: int, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, + ignore_backoff: bool = False, + follow_redirects: bool = False, + ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": + """A mock for MatrixFederationHttpClient.get_file.""" + + def write_to( + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] + ) -> Tuple[int, Dict[bytes, List[bytes]]]: + data, response = r + output_stream.write(data) + return response + + def write_err(f: Failure) -> Failure: + f.trap(HttpResponseException) + output_stream.write(f.value.response) + return f + + d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() + self.fetches.append((d, destination, path, args)) + # Note that this callback changes the value held by d. 
+ d_after_callback = d.addCallbacks(write_to, write_err) + return make_deferred_yieldable(d_after_callback) + # Mock out the homeserver's MatrixFederationHttpClient client = Mock() client.federation_get_file = federation_get_file + client.get_file = get_file self.storage_path = self.mktemp() self.media_store_path = self.mktemp() @@ -2128,3 +2159,52 @@ class DownloadTestCase(unittest.HomeserverTestCase): headers.getRawHeaders(b"Cross-Origin-Resource-Policy"), [b"cross-origin"], ) + + def test_unknown_federation_endpoint(self) -> None: + """ + Test that if the download request to the remote federation endpoint returns a 404 + we fall back to the _matrix/media endpoint + """ + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}", + shorthand=False, + await_result=False, + access_token=self.tok, + ) + self.pump() + + # We've made one fetch, to example.com, using the media URL, and asking + # the other server not to do a remote fetch + self.assertEqual(len(self.fetches), 1) + self.assertEqual(self.fetches[0][1], "example.com") + self.assertEqual( + self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}" + ) + + # The result which says the endpoint is unknown. + unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}' + self.fetches[0][0].errback( + HttpResponseException(404, "NOT FOUND", unknown_endpoint) + ) + + self.pump() + + # There should now be another request to the _matrix/media/v3/download URL. + self.assertEqual(len(self.fetches), 2) + self.assertEqual(self.fetches[1][1], "example.com") + self.assertEqual( + self.fetches[1][2], + f"/_matrix/media/v3/download/example.com/{self.media_id}", + ) + + headers = { + b"Content-Length": [b"%d" % (len(self.test_image.data))], + } + + self.fetches[1][0].callback( + (self.test_image.data, (len(self.test_image.data), headers)) + ) + + self.pump() + self.assertEqual(channel.code, 200)