fix sharded repo test failing and add shard tests for authenticated endpoint

This commit is contained in:
H. Shay 2024-06-27 14:08:09 -07:00
parent c95b0ef1fe
commit 35b9a6865a
2 changed files with 448 additions and 126 deletions

View file

@ -636,23 +636,45 @@ class MediaRepository:
# Failed to find the file anywhere, lets download it. # Failed to find the file anywhere, lets download it.
try: if not use_federation_endpoint:
media_info = await self._download_remote_file( try:
server_name, media_info = await self._download_remote_file(
media_id, server_name,
max_timeout_ms, media_id,
download_ratelimiter, max_timeout_ms,
ip_address, download_ratelimiter,
use_federation_endpoint, ip_address,
) )
except SynapseError: except SynapseError:
raise raise
except Exception as e: except Exception as e:
# An exception may be because we downloaded media in another # An exception may be because we downloaded media in another
# process, so let's check if we magically have the media. # process, so let's check if we magically have the media.
media_info = await self.store.get_cached_remote_media(server_name, media_id) media_info = await self.store.get_cached_remote_media(
if not media_info: server_name, media_id
raise e )
if not media_info:
raise e
else:
try:
media_info = await self._federation_download_remote_file(
server_name,
media_id,
max_timeout_ms,
download_ratelimiter,
ip_address,
)
except SynapseError:
raise
except Exception as e:
# An exception may be because we downloaded media in another
# process, so let's check if we magically have the media.
media_info = await self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
raise e
file_id = media_info.filesystem_id file_id = media_info.filesystem_id
if not media_info.media_type: if not media_info.media_type:
@ -679,7 +701,6 @@ class MediaRepository:
max_timeout_ms: int, max_timeout_ms: int,
download_ratelimiter: Ratelimiter, download_ratelimiter: Ratelimiter,
ip_address: str, ip_address: str,
use_federation_endpoint: bool,
) -> RemoteMedia: ) -> RemoteMedia:
"""Attempt to download the remote file from the given server name, """Attempt to download the remote file from the given server name,
using the given file_id as the local id. using the given file_id as the local id.
@ -694,8 +715,6 @@ class MediaRepository:
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP requester IP
ip_address: the IP address of the requester ip_address: the IP address of the requester
use_federation_endpoint: whether to request the remote media over the new
federation /download endpoint
Returns: Returns:
The media info of the file. The media info of the file.
@ -706,123 +725,194 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id) file_info = FileInfo(server_name=server_name, file_id=file_id)
async with self.media_storage.store_into_file(file_info) as (f, fname): async with self.media_storage.store_into_file(file_info) as (f, fname):
if not use_federation_endpoint: try:
try: length, headers = await self.client.download_media(
length, headers = await self.client.download_media( server_name,
server_name, media_id,
media_id, output_stream=f,
output_stream=f, max_size=self.max_upload_size,
max_size=self.max_upload_size, max_timeout_ms=max_timeout_ms,
max_timeout_ms=max_timeout_ms, download_ratelimiter=download_ratelimiter,
download_ratelimiter=download_ratelimiter, ip_address=ip_address,
ip_address=ip_address, )
) except RequestSendFailed as e:
except RequestSendFailed as e: logger.warning(
logger.warning( "Request failed fetching remote media %s/%s: %r",
"Request failed fetching remote media %s/%s: %r", server_name,
server_name, media_id,
media_id, e,
e, )
) raise SynapseError(502, "Failed to fetch remote media")
raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e: except HttpResponseException as e:
logger.warning( logger.warning(
"HTTP error fetching remote media %s/%s: %s", "HTTP error fetching remote media %s/%s: %s",
server_name, server_name,
media_id, media_id,
e.response, e.response,
) )
if e.code == twisted.web.http.NOT_FOUND: if e.code == twisted.web.http.NOT_FOUND:
raise e.to_synapse_error() raise e.to_synapse_error()
raise SynapseError(502, "Failed to fetch remote media") raise SynapseError(502, "Failed to fetch remote media")
except SynapseError: except SynapseError:
logger.warning( logger.warning(
"Failed to fetch remote media %s/%s", server_name, media_id "Failed to fetch remote media %s/%s", server_name, media_id
) )
raise raise
except NotRetryingDestination: except NotRetryingDestination:
logger.warning("Not retrying destination %r", server_name) logger.warning("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media") raise SynapseError(502, "Failed to fetch remote media")
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to fetch remote media %s/%s", server_name, media_id "Failed to fetch remote media %s/%s", server_name, media_id
) )
raise SynapseError(502, "Failed to fetch remote media") raise SynapseError(502, "Failed to fetch remote media")
if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii")
else: else:
try: media_type = "application/octet-stream"
length, headers, json = await self.client.federation_download_media( upload_name = get_filename_from_headers(headers)
server_name, time_now_ms = self.clock.time_msec()
media_id,
output_stream=f,
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except RequestSendFailed as e:
logger.warning(
"Request failed fetching remote media %s/%s: %r",
server_name,
media_id,
e,
)
raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e: # Multiple remote media download requests can race (when using
logger.warning( # multiple media repos), so this may throw a violation constraint
"HTTP error fetching remote media %s/%s: %s", # exception. If it does we'll delete the newly downloaded file from
server_name, # disk (as we're in the ctx manager).
media_id, #
e.response, # However: we've already called `finish()` so we may have also
) # written to the storage providers. This is preferable to the
if e.code == twisted.web.http.NOT_FOUND: # alternative where we call `finish()` *after* this, where we could
raise e.to_synapse_error() # end up having an entry in the DB but fail to write the files to
raise SynapseError(502, "Failed to fetch remote media") # the storage providers.
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
except SynapseError: logger.info("Stored remote media in file %r", fname)
logger.warning(
"Failed to fetch remote media %s/%s", server_name, media_id
)
raise
except NotRetryingDestination:
logger.warning("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception(
"Failed to fetch remote media %s/%s", server_name, media_id
)
raise SynapseError(502, "Failed to fetch remote media")
if b"Content-Type" in headers: return RemoteMedia(
media_type = headers[b"Content-Type"][0].decode("ascii") media_origin=server_name,
else:
media_type = "application/octet-stream"
logger.info(f"Headers is {headers}")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
# Multiple remote media download requests can race (when using
# multiple media repos), so this may throw a violation constraint
# exception. If it does we'll delete the newly downloaded file from
# disk (as we're in the ctx manager).
#
# However: we've already called `finish()` so we may have also
# written to the storage providers. This is preferable to the
# alternative where we call `finish()` *after* this, where we could
# end up having an entry in the DB but fail to write the files to
# the storage providers.
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id, media_id=media_id,
media_type=media_type, media_type=media_type,
time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length, media_length=length,
upload_name=upload_name,
created_ts=time_now_ms,
filesystem_id=file_id, filesystem_id=file_id,
last_access_ts=time_now_ms,
quarantined_by=None,
) )
async def _federation_download_remote_file(
self,
server_name: str,
media_id: str,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id and downloading over federation v1 download
endpoint
Args:
server_name: Originating server
media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP
ip_address: the IP address of the requester
Returns:
The media info of the file.
"""
file_id = random_string(24)
file_info = FileInfo(server_name=server_name, file_id=file_id)
async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
length, headers, json = await self.client.federation_download_media(
server_name,
media_id,
output_stream=f,
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except RequestSendFailed as e:
logger.warning(
"Request failed fetching remote media %s/%s: %r",
server_name,
media_id,
e,
)
raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e:
logger.warning(
"HTTP error fetching remote media %s/%s: %s",
server_name,
media_id,
e.response,
)
if e.code == twisted.web.http.NOT_FOUND:
raise e.to_synapse_error()
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
logger.warning(
"Failed to fetch remote media %s/%s", server_name, media_id
)
raise
except NotRetryingDestination:
logger.warning("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception(
"Failed to fetch remote media %s/%s", server_name, media_id
)
raise SynapseError(502, "Failed to fetch remote media")
if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii")
else:
media_type = "application/octet-stream"
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
# Multiple remote media download requests can race (when using
# multiple media repos), so this may throw a violation constraint
# exception. If it does we'll delete the newly downloaded file from
# disk (as we're in the ctx manager).
#
# However: we've already called `finish()` so we may have also
# written to the storage providers. This is preferable to the
# alternative where we call `finish()` *after* this, where we could
# end up having an entry in the DB but fail to write the files to
# the storage providers.
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
logger.info("Stored remote media in file %r", fname) logger.info("Stored remote media in file %r", fname)
return RemoteMedia( return RemoteMedia(

View file

@ -28,7 +28,7 @@ from twisted.web.http import HTTPChannel
from twisted.web.server import Request from twisted.web.server import Request
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login from synapse.rest.client import login, media
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
@ -255,6 +255,238 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path)) return sum(len(files) for _, _, files in os.walk(path))
class AuthenticatedMediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks running multiple media repos work correctly using autheticated media paths"""
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
media.register_servlets,
]
file_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
self.reactor.lookups["example.com"] = "1.2.3.4"
def default_config(self) -> dict:
conf = super().default_config()
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf
def make_worker_hs(
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs)
# Force the media paths onto the replication resource.
worker_hs.get_media_repository_resource().register_servlets(
self._hs_to_site[worker_hs].resource, worker_hs
)
return worker_hs
def _get_media_req(
self, hs: HomeServer, target: str, media_id: str
) -> Tuple[FakeChannel, Request]:
"""Request some remote media from the given HS by calling the download
API.
This then triggers an outbound request from the HS to the target.
Returns:
The channel for the *client* request and the *outbound* request for
the media which the caller should respond to.
"""
channel = make_request(
self.reactor,
self._hs_to_site[hs],
"GET",
f"/_matrix/client/v1/media/download/{target}/{media_id}",
shorthand=False,
access_token=self.access_token,
await_result=False,
)
self.pump()
clients = self.reactor.tcpClients
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
# build the test server
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request
server_tls_protocol = wrap_server_factory_for_tls(
server_factory, self.reactor, sanlist=[b"DNS:example.com"]
).buildProtocol(None)
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
# a FakeTransport.
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(
FakeTransport(server_tls_protocol, self.reactor, client_protocol)
)
# tell the server tls protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection(
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
# fish the test server back out of the server-side TLS protocol.
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b"GET")
self.assertEqual(
request.path,
f"/_matrix/federation/v1/media/download/{media_id}".encode(),
)
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
)
return channel, request
def test_basic(self) -> None:
"""Test basic fetching of remote media from a single worker."""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
request.setResponseCode(200)
request.responseHeaders.setRawHeaders(
b"Content-Type",
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
)
request.write(self.file_data)
request.finish()
self.pump(0.1)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"file_to_stream")
def test_download_simple_file_race(self) -> None:
"""Test that fetching remote media from two different processes at the
same time works.
"""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
hs2 = self.make_worker_hs("synapse.app.generic_worker")
start_count = self._count_remote_media()
# Make two requests without responding to the outbound media requests.
channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
# Respond to the first outbound media request and check that the client
# request is successful
request1.setResponseCode(200)
request1.responseHeaders.setRawHeaders(
b"Content-Type",
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
)
request1.write(self.file_data)
request1.finish()
self.pump(0.1)
self.assertEqual(channel1.code, 200, channel1.result["body"])
self.assertEqual(channel1.result["body"], b"file_to_stream")
# Now respond to the second with the same content.
request2.setResponseCode(200)
request2.responseHeaders.setRawHeaders(
b"Content-Type",
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
)
request2.write(self.file_data)
request2.finish()
self.pump(0.1)
self.assertEqual(channel2.code, 200, channel2.result["body"])
self.assertEqual(channel2.result["body"], b"file_to_stream")
# We expect only one new file to have been persisted.
self.assertEqual(start_count + 1, self._count_remote_media())
def test_download_image_race(self) -> None:
"""Test that fetching remote *images* from two different processes at
the same time works.
This checks that races generating thumbnails are handled correctly.
"""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
hs2 = self.make_worker_hs("synapse.app.generic_worker")
start_count = self._count_remote_thumbnails()
channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
request1.setResponseCode(200)
request1.responseHeaders.setRawHeaders(
b"Content-Type",
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
)
img_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: image/png\r\nContent-Disposition: inline; filename=test_img\r\n\r\n"
request1.write(img_data)
request1.write(SMALL_PNG)
request1.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
request1.finish()
self.pump(0.1)
self.assertEqual(channel1.code, 200, channel1.result["body"])
self.assertEqual(channel1.result["body"], SMALL_PNG)
request2.setResponseCode(200)
request2.responseHeaders.setRawHeaders(
b"Content-Type",
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
)
request2.write(img_data)
request2.write(SMALL_PNG)
request2.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
request2.finish()
self.pump(0.1)
self.assertEqual(channel2.code, 200, channel2.result["body"])
self.assertEqual(channel2.result["body"], SMALL_PNG)
# We expect only three new thumbnails to have been persisted.
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
def _count_remote_media(self) -> int:
"""Count the number of files in our remote media directory."""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_content"
)
return sum(len(files) for _, _, files in os.walk(path))
def _count_remote_thumbnails(self) -> int:
"""Count the number of files in our remote thumbnails directory."""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
)
return sum(len(files) for _, _, files in os.walk(path))
def _log_request(request: Request) -> None: def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish""" """Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request) logger.info("Completed request %s", request)