From 5b5280e3e5e37c1cf2ed758db30f221c438cc33f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Jul 2024 12:38:29 +0100 Subject: [PATCH 1/7] Fix building debian packages for sid (#17389) Sid now defaults to python3.12, and our pinned version of cffi (1.5.1) does not have wheels for 3.12. This installing cffi to fail as we did not have the correct libs installed to build from source. --- changelog.d/17389.misc | 1 + docker/Dockerfile-dhvirtualenv | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/17389.misc diff --git a/changelog.d/17389.misc b/changelog.d/17389.misc new file mode 100644 index 0000000000..7022ed93d9 --- /dev/null +++ b/changelog.d/17389.misc @@ -0,0 +1 @@ +Fix building debian package for debian sid. diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv index b7679924c2..f000144567 100644 --- a/docker/Dockerfile-dhvirtualenv +++ b/docker/Dockerfile-dhvirtualenv @@ -73,6 +73,8 @@ RUN apt-get update -qq -o Acquire::Languages=none \ curl \ debhelper \ devscripts \ + # Required for building cffi from source. + libffi-dev \ libsystemd-dev \ lsb-release \ pkg-config \ From 9c8f1a6d412c8178eadaf64346c6e386328ba1ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Jul 2024 12:39:32 +0100 Subject: [PATCH 2/7] Fix building debian packages on non-clean checkouts (#17390) If we leave the `.so` in place it causes the tests to fail, as it gets picked up (instead of the newly built .so) and so fails with mismatched GLIBC errors. --- changelog.d/17390.misc | 1 + docker/build_debian.sh | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/17390.misc diff --git a/changelog.d/17390.misc b/changelog.d/17390.misc new file mode 100644 index 0000000000..6a4e344c5c --- /dev/null +++ b/changelog.d/17390.misc @@ -0,0 +1 @@ +Fix building debian packages on non-clean checkouts. diff --git a/docker/build_debian.sh b/docker/build_debian.sh index 9eae38af91..00e0856c7d 100644 --- a/docker/build_debian.sh +++ b/docker/build_debian.sh @@ -11,6 +11,9 @@ DIST=$(cut -d ':' -f2 <<< "${distro:?}") cp -aT /synapse/source /synapse/build cd /synapse/build +# Delete any existing `.so` files to ensure a clean build. +rm -f /synapse/build/synapse/*.so + # if this is a prerelease, set the Section accordingly. # # When the package is later added to the package repo, reprepro will use the From b3b793786c82383edec6c7d3226d98dbafe3b098 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Jul 2024 12:39:49 +0100 Subject: [PATCH 3/7] Fix sync waiting for an invalid token from the "future" (#17386) Fixes https://github.com/element-hq/synapse/issues/17274, hopefully. Basically, old versions of Synapse could advance streams without persisting anything in the DB (fixed in #17229). On restart those updates would get lost, and so the position of the stream would revert to an older position. If this happened across an upgrade to a later Synapse version which included #17215, then sync could get blocked indefinitely (until the stream advanced to the position in the token). We fix this by bounding the stream positions we'll wait for to the maximum position of the underlying stream ID generator. --- changelog.d/17386.bugfix | 1 + synapse/notifier.py | 7 ++ .../storage/databases/main/account_data.py | 10 +-- synapse/storage/databases/main/deviceinbox.py | 10 +-- synapse/storage/databases/main/devices.py | 3 + .../storage/databases/main/events_worker.py | 4 +- synapse/storage/databases/main/presence.py | 10 +-- synapse/storage/databases/main/push_rule.py | 3 + synapse/storage/databases/main/receipts.py | 10 +-- synapse/storage/databases/main/room.py | 11 ++- synapse/storage/databases/main/stream.py | 3 + synapse/storage/util/id_generators.py | 5 ++ synapse/storage/util/sequence.py | 24 ++++++ synapse/streams/events.py | 64 +++++++++++++++- synapse/types/__init__.py | 18 +++++ tests/handlers/test_sync.py | 73 ++++++++++++++++++- tests/rest/client/test_sync.py | 4 +- 17 files changed, 229 insertions(+), 31 deletions(-) create mode 100644 changelog.d/17386.bugfix diff --git a/changelog.d/17386.bugfix b/changelog.d/17386.bugfix new file mode 100644 index 0000000000..9686b5c276 --- /dev/null +++ b/changelog.d/17386.bugfix @@ -0,0 +1 @@ +Fix bug where `/sync` requests could get blocked indefinitely after an upgrade from Synapse versions before v1.109.0. diff --git a/synapse/notifier.py b/synapse/notifier.py index c87eb748c0..c3ecf86ec4 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -764,6 +764,13 @@ class Notifier: async def wait_for_stream_token(self, stream_token: StreamToken) -> bool: """Wait for this worker to catch up with the given stream token.""" + current_token = self.event_sources.get_current_token() + if stream_token.is_before_or_eq(current_token): + return True + + # Work around a bug where older Synapse versions gave out tokens "from + # the future", i.e. that are ahead of the tokens persisted in the DB. + stream_token = await self.event_sources.bound_future_token(stream_token) start = self.clock.time_msec() while True: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 9611a84932..966393869b 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -43,10 +43,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -71,7 +68,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._instance_name in hs.config.worker.writers.account_data ) - self._account_data_id_gen: AbstractStreamIdGenerator + self._account_data_id_gen: MultiWriterIdGenerator self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, @@ -113,6 +110,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) """ return self._account_data_id_gen.get_current_token() + def get_account_data_id_generator(self) -> MultiWriterIdGenerator: + return self._account_data_id_gen + @cached() async def get_global_account_data_for_user( self, user_id: str diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 5a752b9b8c..042d595ea0 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -50,10 +50,7 @@ from synapse.storage.database import ( LoggingTransaction, make_in_list_sql_clause, ) -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache @@ -92,7 +89,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): self._instance_name in hs.config.worker.writers.to_device ) - self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator( + self._to_device_msg_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator( db_conn=db_conn, db=database, notifier=hs.get_replication_notifier(), @@ -169,6 +166,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self) -> int: return self._to_device_msg_id_gen.get_current_token() + def get_to_device_id_generator(self) -> MultiWriterIdGenerator: + return self._to_device_msg_id_gen + async def get_messages_for_user_devices( self, user_ids: Collection[str], diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 59a035dd62..53024bddc3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -243,6 +243,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() + def get_device_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._device_list_id_gen + async def count_devices_by_users( self, user_ids: Optional[Collection[str]] = None ) -> int: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e264d36f02..198e65cfa5 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -192,8 +192,8 @@ class EventsWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - self._stream_id_gen: AbstractStreamIdGenerator - self._backfill_id_gen: AbstractStreamIdGenerator + self._stream_id_gen: MultiWriterIdGenerator + self._backfill_id_gen: MultiWriterIdGenerator self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 923e764491..065c885603 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -42,10 +42,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines._base import IsolationLevel from synapse.storage.types import Connection -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter @@ -83,7 +80,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() - self._presence_id_gen: AbstractStreamIdGenerator + self._presence_id_gen: MultiWriterIdGenerator self._can_persist_presence = ( self._instance_name in hs.config.worker.writers.presence @@ -455,6 +452,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) def get_current_presence_token(self) -> int: return self._presence_id_gen.get_current_token() + def get_presence_stream_id_gen(self) -> MultiWriterIdGenerator: + return self._presence_id_gen + def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]: """Fetch non-offline presence from the database so that we can register the appropriate time outs. diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 2a39dc9f90..bbdde17711 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -178,6 +178,9 @@ class PushRulesWorkerStore( """ return self._push_rules_stream_id_gen.get_current_token() + def get_push_rules_stream_id_gen(self) -> MultiWriterIdGenerator: + return self._push_rules_stream_id_gen + def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 8432560a89..3bde0ae0d4 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -45,10 +45,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines._base import IsolationLevel -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import ( JsonDict, JsonMapping, @@ -76,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._receipts_id_gen: AbstractStreamIdGenerator + self._receipts_id_gen: MultiWriterIdGenerator self._can_write_to_receipts = ( self._instance_name in hs.config.worker.writers.receipts @@ -136,6 +133,9 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_receipt_stream_id_for_instance(self, instance_name: str) -> int: return self._receipts_id_gen.get_current_token_for_writer(instance_name) + def get_receipts_stream_id_gen(self) -> MultiWriterIdGenerator: + return self._receipts_id_gen + def get_last_unthreaded_receipt_for_user_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index d5627b1d6e..80a4bf95f2 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -59,11 +59,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - IdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -151,7 +147,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): self.config: HomeServerConfig = hs.config - self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator + self._un_partial_stated_rooms_stream_id_gen: MultiWriterIdGenerator self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, @@ -1409,6 +1405,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): instance_name ) + def get_un_partial_stated_rooms_id_generator(self) -> MultiWriterIdGenerator: + return self._un_partial_stated_rooms_stream_id_gen + async def get_un_partial_stated_rooms_between( self, last_id: int, current_id: int, room_ids: Collection[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index ff0d723684..b7eb3116ae 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -577,6 +577,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions)) + def get_events_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._stream_id_gen + async def get_room_events_stream_for_rooms( self, room_ids: Collection[str], diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 48f88a6f8a..e8588f33cf 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -812,6 +812,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): pos = self.get_current_token_for_writer(self._instance_name) txn.execute(sql, (self._stream_name, self._instance_name, pos)) + async def get_max_allocated_token(self) -> int: + return await self._db.runInteraction( + "get_max_allocated_token", self._sequence_gen.get_max_allocated + ) + @attr.s(frozen=True, auto_attribs=True) class _AsyncCtxManagerWrapper(Generic[T]): diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index c4c0602b28..cac3eba1a5 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -88,6 +88,10 @@ class SequenceGenerator(metaclass=abc.ABCMeta): """ ... + @abc.abstractmethod + def get_max_allocated(self, txn: Cursor) -> int: + """Get the maximum ID that we have allocated""" + class PostgresSequenceGenerator(SequenceGenerator): """An implementation of SequenceGenerator which uses a postgres sequence""" @@ -190,6 +194,17 @@ class PostgresSequenceGenerator(SequenceGenerator): % {"seq": self._sequence_name, "stream_name": stream_name} ) + def get_max_allocated(self, txn: Cursor) -> int: + # We just read from the sequence what the last value we fetched was. + txn.execute(f"SELECT last_value, is_called FROM {self._sequence_name}") + row = txn.fetchone() + assert row is not None + + last_value, is_called = row + if not is_called: + last_value -= 1 + return last_value + GetFirstCallbackType = Callable[[Cursor], int] @@ -248,6 +263,15 @@ class LocalSequenceGenerator(SequenceGenerator): # There is nothing to do for in memory sequences pass + def get_max_allocated(self, txn: Cursor) -> int: + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + return self._current_max_id + def build_sequence_generator( db_conn: "LoggingDatabaseConnection", diff --git a/synapse/streams/events.py b/synapse/streams/events.py index dd7401ac8e..93d5ae1a55 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -30,7 +30,12 @@ from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.logging.opentracing import trace from synapse.streams import EventSource -from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken +from synapse.types import ( + AbstractMultiWriterStreamToken, + MultiWriterStreamToken, + StreamKeyType, + StreamToken, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -91,6 +96,63 @@ class EventSources: ) return token + async def bound_future_token(self, token: StreamToken) -> StreamToken: + """Bound a token that is ahead of the current token to the maximum + persisted values. + + This ensures that if we wait for the given token we know the stream will + eventually advance to that point. + + This works around a bug where older Synapse versions will give out + tokens for streams, and then after a restart will give back tokens where + the stream has "gone backwards". + """ + + current_token = self.get_current_token() + + stream_key_to_id_gen = { + StreamKeyType.ROOM: self.store.get_events_stream_id_generator(), + StreamKeyType.PRESENCE: self.store.get_presence_stream_id_gen(), + StreamKeyType.RECEIPT: self.store.get_receipts_stream_id_gen(), + StreamKeyType.ACCOUNT_DATA: self.store.get_account_data_id_generator(), + StreamKeyType.PUSH_RULES: self.store.get_push_rules_stream_id_gen(), + StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(), + StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(), + StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(), + } + + for _, key in StreamKeyType.__members__.items(): + if key == StreamKeyType.TYPING: + # Typing stream is allowed to "reset", and so comparisons don't + # really make sense as is. + # TODO: Figure out a better way of tracking resets. + continue + + token_value = token.get_field(key) + current_value = current_token.get_field(key) + + if isinstance(token_value, AbstractMultiWriterStreamToken): + assert type(current_value) is type(token_value) + + if not token_value.is_before_or_eq(current_value): # type: ignore[arg-type] + max_token = await stream_key_to_id_gen[ + key + ].get_max_allocated_token() + + token = token.copy_and_replace( + key, token.room_key.bound_stream_token(max_token) + ) + else: + assert isinstance(current_value, int) + if current_value < token_value: + max_token = await stream_key_to_id_gen[ + key + ].get_max_allocated_token() + + token = token.copy_and_replace(key, min(token_value, max_token)) + + return token + @trace async def get_start_token_for_pagination(self, room_id: str) -> StreamToken: """Get the start token for a given room to be used to paginate diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 151658df53..8ab9f90238 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -536,6 +536,16 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): return True + def bound_stream_token(self, max_stream: int) -> "Self": + """Bound the stream positions to a maximum value""" + + return type(self)( + stream=min(self.stream, max_stream), + instance_map=immutabledict( + {k: min(s, max_stream) for k, s in self.instance_map.items()} + ), + ) + @attr.s(frozen=True, slots=True, order=False) class RoomStreamToken(AbstractMultiWriterStreamToken): @@ -722,6 +732,14 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): else: return "s%d" % (self.stream,) + def bound_stream_token(self, max_stream: int) -> "RoomStreamToken": + """See super class""" + + # This only makes sense for stream tokens. + assert self.topological is None + + return super().bound_stream_token(max_stream) + @attr.s(frozen=True, slots=True, order=False) class MultiWriterStreamToken(AbstractMultiWriterStreamToken): diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 02371ce724..5319928c28 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, Mock, patch from parameterized import parameterized +from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules @@ -35,7 +36,7 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, StreamKeyType, UserID, create_requester from synapse.util import Clock import tests.unittest @@ -959,6 +960,76 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.fail("No push rules found") + def test_wait_for_future_sync_token(self) -> None: + """Test that if we receive a token that is ahead of our current token, + we'll wait until the stream position advances. + + This can happen if replication streams start lagging, and the client's + previous sync request was serviced by a worker ahead of ours. + """ + user = self.register_user("alice", "password") + + # We simulate a lagging stream by getting a stream ID from the ID gen + # and then waiting to mark it as "persisted". + presence_id_gen = self.store.get_presence_stream_id_gen() + ctx_mgr = presence_id_gen.get_next() + stream_id = self.get_success(ctx_mgr.__aenter__()) + + # Create the new token based on the stream ID above. + current_token = self.hs.get_event_sources().get_current_token() + since_token = current_token.copy_and_advance(StreamKeyType.PRESENCE, stream_id) + + sync_d = defer.ensureDeferred( + self.sync_handler.wait_for_sync_for_user( + create_requester(user), + generate_sync_config(user), + sync_version=SyncVersion.SYNC_V2, + request_key=generate_request_key(), + since_token=since_token, + timeout=0, + ) + ) + + # This should block waiting for the presence stream to update + self.pump() + self.assertFalse(sync_d.called) + + # Marking the stream ID as persisted should unblock the request. + self.get_success(ctx_mgr.__aexit__(None, None, None)) + + self.get_success(sync_d, by=1.0) + + def test_wait_for_invalid_future_sync_token(self) -> None: + """Like the previous test, except we give a token that has a stream + position ahead of what is in the DB, i.e. its invalid and we shouldn't + wait for the stream to advance (as it may never do so). + + This can happen due to older versions of Synapse giving out stream + positions without persisting them in the DB, and so on restart the + stream would get reset back to an older position. + """ + user = self.register_user("alice", "password") + + # Create a token and arbitrarily advance one of the streams. + current_token = self.hs.get_event_sources().get_current_token() + since_token = current_token.copy_and_advance( + StreamKeyType.PRESENCE, current_token.presence_key + 1 + ) + + sync_d = defer.ensureDeferred( + self.sync_handler.wait_for_sync_for_user( + create_requester(user), + generate_sync_config(user), + sync_version=SyncVersion.SYNC_V2, + request_key=generate_request_key(), + since_token=since_token, + timeout=0, + ) + ) + + # We should return without waiting for the presence stream to advance. + self.get_success(sync_d) + def generate_sync_config( user_id: str, diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index bfb26139d3..12c11f342c 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -1386,10 +1386,12 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # Create a future token that will cause us to wait. Since we never send a new # event to reach that future stream_ordering, the worker will wait until the # full timeout. + stream_id_gen = self.store.get_events_stream_id_generator() + stream_id = self.get_success(stream_id_gen.get_next().__aenter__()) current_token = self.event_sources.get_current_token() future_position_token = current_token.copy_and_replace( StreamKeyType.ROOM, - RoomStreamToken(stream=current_token.room_key.stream + 1), + RoomStreamToken(stream=stream_id), ) future_position_token_serialized = self.get_success( From 1ce59d7ba002a869ee94fbe375898cc79c6eb4d1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Jul 2024 12:39:49 +0100 Subject: [PATCH 4/7] Fix sync waiting for an invalid token from the "future" (#17386) Fixes https://github.com/element-hq/synapse/issues/17274, hopefully. Basically, old versions of Synapse could advance streams without persisting anything in the DB (fixed in #17229). On restart those updates would get lost, and so the position of the stream would revert to an older position. If this happened across an upgrade to a later Synapse version which included #17215, then sync could get blocked indefinitely (until the stream advanced to the position in the token). We fix this by bounding the stream positions we'll wait for to the maximum position of the underlying stream ID generator. --- changelog.d/17386.bugfix | 1 + synapse/notifier.py | 7 ++ .../storage/databases/main/account_data.py | 10 +-- synapse/storage/databases/main/deviceinbox.py | 10 +-- synapse/storage/databases/main/devices.py | 3 + .../storage/databases/main/events_worker.py | 4 +- synapse/storage/databases/main/presence.py | 10 +-- synapse/storage/databases/main/push_rule.py | 3 + synapse/storage/databases/main/receipts.py | 10 +-- synapse/storage/databases/main/room.py | 11 ++- synapse/storage/databases/main/stream.py | 3 + synapse/storage/util/id_generators.py | 5 ++ synapse/storage/util/sequence.py | 24 ++++++ synapse/streams/events.py | 64 +++++++++++++++- synapse/types/__init__.py | 18 +++++ tests/handlers/test_sync.py | 73 ++++++++++++++++++- tests/rest/client/test_sync.py | 4 +- 17 files changed, 229 insertions(+), 31 deletions(-) create mode 100644 changelog.d/17386.bugfix diff --git a/changelog.d/17386.bugfix b/changelog.d/17386.bugfix new file mode 100644 index 0000000000..9686b5c276 --- /dev/null +++ b/changelog.d/17386.bugfix @@ -0,0 +1 @@ +Fix bug where `/sync` requests could get blocked indefinitely after an upgrade from Synapse versions before v1.109.0. diff --git a/synapse/notifier.py b/synapse/notifier.py index c87eb748c0..c3ecf86ec4 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -764,6 +764,13 @@ class Notifier: async def wait_for_stream_token(self, stream_token: StreamToken) -> bool: """Wait for this worker to catch up with the given stream token.""" + current_token = self.event_sources.get_current_token() + if stream_token.is_before_or_eq(current_token): + return True + + # Work around a bug where older Synapse versions gave out tokens "from + # the future", i.e. that are ahead of the tokens persisted in the DB. + stream_token = await self.event_sources.bound_future_token(stream_token) start = self.clock.time_msec() while True: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 9611a84932..966393869b 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -43,10 +43,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -71,7 +68,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._instance_name in hs.config.worker.writers.account_data ) - self._account_data_id_gen: AbstractStreamIdGenerator + self._account_data_id_gen: MultiWriterIdGenerator self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, @@ -113,6 +110,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) """ return self._account_data_id_gen.get_current_token() + def get_account_data_id_generator(self) -> MultiWriterIdGenerator: + return self._account_data_id_gen + @cached() async def get_global_account_data_for_user( self, user_id: str diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 07333efff8..304ac42411 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -50,10 +50,7 @@ from synapse.storage.database import ( LoggingTransaction, make_in_list_sql_clause, ) -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache @@ -92,7 +89,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): self._instance_name in hs.config.worker.writers.to_device ) - self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator( + self._to_device_msg_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator( db_conn=db_conn, db=database, notifier=hs.get_replication_notifier(), @@ -169,6 +166,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self) -> int: return self._to_device_msg_id_gen.get_current_token() + def get_to_device_id_generator(self) -> MultiWriterIdGenerator: + return self._to_device_msg_id_gen + async def get_messages_for_user_devices( self, user_ids: Collection[str], diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 59a035dd62..53024bddc3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -243,6 +243,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() + def get_device_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._device_list_id_gen + async def count_devices_by_users( self, user_ids: Optional[Collection[str]] = None ) -> int: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e264d36f02..198e65cfa5 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -192,8 +192,8 @@ class EventsWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - self._stream_id_gen: AbstractStreamIdGenerator - self._backfill_id_gen: AbstractStreamIdGenerator + self._stream_id_gen: MultiWriterIdGenerator + self._backfill_id_gen: MultiWriterIdGenerator self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 923e764491..065c885603 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -42,10 +42,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines._base import IsolationLevel from synapse.storage.types import Connection -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter @@ -83,7 +80,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() - self._presence_id_gen: AbstractStreamIdGenerator + self._presence_id_gen: MultiWriterIdGenerator self._can_persist_presence = ( self._instance_name in hs.config.worker.writers.presence @@ -455,6 +452,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) def get_current_presence_token(self) -> int: return self._presence_id_gen.get_current_token() + def get_presence_stream_id_gen(self) -> MultiWriterIdGenerator: + return self._presence_id_gen + def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]: """Fetch non-offline presence from the database so that we can register the appropriate time outs. diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 2a39dc9f90..bbdde17711 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -178,6 +178,9 @@ class PushRulesWorkerStore( """ return self._push_rules_stream_id_gen.get_current_token() + def get_push_rules_stream_id_gen(self) -> MultiWriterIdGenerator: + return self._push_rules_stream_id_gen + def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 8432560a89..3bde0ae0d4 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -45,10 +45,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines._base import IsolationLevel -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import ( JsonDict, JsonMapping, @@ -76,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._receipts_id_gen: AbstractStreamIdGenerator + self._receipts_id_gen: MultiWriterIdGenerator self._can_write_to_receipts = ( self._instance_name in hs.config.worker.writers.receipts @@ -136,6 +133,9 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_receipt_stream_id_for_instance(self, instance_name: str) -> int: return self._receipts_id_gen.get_current_token_for_writer(instance_name) + def get_receipts_stream_id_gen(self) -> MultiWriterIdGenerator: + return self._receipts_id_gen + def get_last_unthreaded_receipt_for_user_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index d5627b1d6e..80a4bf95f2 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -59,11 +59,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - IdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -151,7 +147,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): self.config: HomeServerConfig = hs.config - self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator + self._un_partial_stated_rooms_stream_id_gen: MultiWriterIdGenerator self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, @@ -1409,6 +1405,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): instance_name ) + def get_un_partial_stated_rooms_id_generator(self) -> MultiWriterIdGenerator: + return self._un_partial_stated_rooms_stream_id_gen + async def get_un_partial_stated_rooms_between( self, last_id: int, current_id: int, room_ids: Collection[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index ff0d723684..b7eb3116ae 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -577,6 +577,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions)) + def get_events_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._stream_id_gen + async def get_room_events_stream_for_rooms( self, room_ids: Collection[str], diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 48f88a6f8a..e8588f33cf 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -812,6 +812,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): pos = self.get_current_token_for_writer(self._instance_name) txn.execute(sql, (self._stream_name, self._instance_name, pos)) + async def get_max_allocated_token(self) -> int: + return await self._db.runInteraction( + "get_max_allocated_token", self._sequence_gen.get_max_allocated + ) + @attr.s(frozen=True, auto_attribs=True) class _AsyncCtxManagerWrapper(Generic[T]): diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index c4c0602b28..cac3eba1a5 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -88,6 +88,10 @@ class SequenceGenerator(metaclass=abc.ABCMeta): """ ... + @abc.abstractmethod + def get_max_allocated(self, txn: Cursor) -> int: + """Get the maximum ID that we have allocated""" + class PostgresSequenceGenerator(SequenceGenerator): """An implementation of SequenceGenerator which uses a postgres sequence""" @@ -190,6 +194,17 @@ class PostgresSequenceGenerator(SequenceGenerator): % {"seq": self._sequence_name, "stream_name": stream_name} ) + def get_max_allocated(self, txn: Cursor) -> int: + # We just read from the sequence what the last value we fetched was. + txn.execute(f"SELECT last_value, is_called FROM {self._sequence_name}") + row = txn.fetchone() + assert row is not None + + last_value, is_called = row + if not is_called: + last_value -= 1 + return last_value + GetFirstCallbackType = Callable[[Cursor], int] @@ -248,6 +263,15 @@ class LocalSequenceGenerator(SequenceGenerator): # There is nothing to do for in memory sequences pass + def get_max_allocated(self, txn: Cursor) -> int: + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + return self._current_max_id + def build_sequence_generator( db_conn: "LoggingDatabaseConnection", diff --git a/synapse/streams/events.py b/synapse/streams/events.py index dd7401ac8e..93d5ae1a55 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -30,7 +30,12 @@ from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.logging.opentracing import trace from synapse.streams import EventSource -from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken +from synapse.types import ( + AbstractMultiWriterStreamToken, + MultiWriterStreamToken, + StreamKeyType, + StreamToken, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -91,6 +96,63 @@ class EventSources: ) return token + async def bound_future_token(self, token: StreamToken) -> StreamToken: + """Bound a token that is ahead of the current token to the maximum + persisted values. + + This ensures that if we wait for the given token we know the stream will + eventually advance to that point. + + This works around a bug where older Synapse versions will give out + tokens for streams, and then after a restart will give back tokens where + the stream has "gone backwards". + """ + + current_token = self.get_current_token() + + stream_key_to_id_gen = { + StreamKeyType.ROOM: self.store.get_events_stream_id_generator(), + StreamKeyType.PRESENCE: self.store.get_presence_stream_id_gen(), + StreamKeyType.RECEIPT: self.store.get_receipts_stream_id_gen(), + StreamKeyType.ACCOUNT_DATA: self.store.get_account_data_id_generator(), + StreamKeyType.PUSH_RULES: self.store.get_push_rules_stream_id_gen(), + StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(), + StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(), + StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(), + } + + for _, key in StreamKeyType.__members__.items(): + if key == StreamKeyType.TYPING: + # Typing stream is allowed to "reset", and so comparisons don't + # really make sense as is. + # TODO: Figure out a better way of tracking resets. + continue + + token_value = token.get_field(key) + current_value = current_token.get_field(key) + + if isinstance(token_value, AbstractMultiWriterStreamToken): + assert type(current_value) is type(token_value) + + if not token_value.is_before_or_eq(current_value): # type: ignore[arg-type] + max_token = await stream_key_to_id_gen[ + key + ].get_max_allocated_token() + + token = token.copy_and_replace( + key, token.room_key.bound_stream_token(max_token) + ) + else: + assert isinstance(current_value, int) + if current_value < token_value: + max_token = await stream_key_to_id_gen[ + key + ].get_max_allocated_token() + + token = token.copy_and_replace(key, min(token_value, max_token)) + + return token + @trace async def get_start_token_for_pagination(self, room_id: str) -> StreamToken: """Get the start token for a given room to be used to paginate diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 151658df53..8ab9f90238 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -536,6 +536,16 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): return True + def bound_stream_token(self, max_stream: int) -> "Self": + """Bound the stream positions to a maximum value""" + + return type(self)( + stream=min(self.stream, max_stream), + instance_map=immutabledict( + {k: min(s, max_stream) for k, s in self.instance_map.items()} + ), + ) + @attr.s(frozen=True, slots=True, order=False) class RoomStreamToken(AbstractMultiWriterStreamToken): @@ -722,6 +732,14 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): else: return "s%d" % (self.stream,) + def bound_stream_token(self, max_stream: int) -> "RoomStreamToken": + """See super class""" + + # This only makes sense for stream tokens. + assert self.topological is None + + return super().bound_stream_token(max_stream) + @attr.s(frozen=True, slots=True, order=False) class MultiWriterStreamToken(AbstractMultiWriterStreamToken): diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 02371ce724..5319928c28 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, Mock, patch from parameterized import parameterized +from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules @@ -35,7 +36,7 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, StreamKeyType, UserID, create_requester from synapse.util import Clock import tests.unittest @@ -959,6 +960,76 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.fail("No push rules found") + def test_wait_for_future_sync_token(self) -> None: + """Test that if we receive a token that is ahead of our current token, + we'll wait until the stream position advances. + + This can happen if replication streams start lagging, and the client's + previous sync request was serviced by a worker ahead of ours. + """ + user = self.register_user("alice", "password") + + # We simulate a lagging stream by getting a stream ID from the ID gen + # and then waiting to mark it as "persisted". + presence_id_gen = self.store.get_presence_stream_id_gen() + ctx_mgr = presence_id_gen.get_next() + stream_id = self.get_success(ctx_mgr.__aenter__()) + + # Create the new token based on the stream ID above. + current_token = self.hs.get_event_sources().get_current_token() + since_token = current_token.copy_and_advance(StreamKeyType.PRESENCE, stream_id) + + sync_d = defer.ensureDeferred( + self.sync_handler.wait_for_sync_for_user( + create_requester(user), + generate_sync_config(user), + sync_version=SyncVersion.SYNC_V2, + request_key=generate_request_key(), + since_token=since_token, + timeout=0, + ) + ) + + # This should block waiting for the presence stream to update + self.pump() + self.assertFalse(sync_d.called) + + # Marking the stream ID as persisted should unblock the request. + self.get_success(ctx_mgr.__aexit__(None, None, None)) + + self.get_success(sync_d, by=1.0) + + def test_wait_for_invalid_future_sync_token(self) -> None: + """Like the previous test, except we give a token that has a stream + position ahead of what is in the DB, i.e. its invalid and we shouldn't + wait for the stream to advance (as it may never do so). + + This can happen due to older versions of Synapse giving out stream + positions without persisting them in the DB, and so on restart the + stream would get reset back to an older position. + """ + user = self.register_user("alice", "password") + + # Create a token and arbitrarily advance one of the streams. + current_token = self.hs.get_event_sources().get_current_token() + since_token = current_token.copy_and_advance( + StreamKeyType.PRESENCE, current_token.presence_key + 1 + ) + + sync_d = defer.ensureDeferred( + self.sync_handler.wait_for_sync_for_user( + create_requester(user), + generate_sync_config(user), + sync_version=SyncVersion.SYNC_V2, + request_key=generate_request_key(), + since_token=since_token, + timeout=0, + ) + ) + + # We should return without waiting for the presence stream to advance. + self.get_success(sync_d) + def generate_sync_config( user_id: str, diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index bfb26139d3..12c11f342c 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -1386,10 +1386,12 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # Create a future token that will cause us to wait. Since we never send a new # event to reach that future stream_ordering, the worker will wait until the # full timeout. + stream_id_gen = self.store.get_events_stream_id_generator() + stream_id = self.get_success(stream_id_gen.get_next().__aenter__()) current_token = self.event_sources.get_current_token() future_position_token = current_token.copy_and_replace( StreamKeyType.ROOM, - RoomStreamToken(stream=current_token.room_key.stream + 1), + RoomStreamToken(stream=stream_id), ) future_position_token_serialized = self.get_success( From b905ae27caac4bb27262d9d7ac6e834de5694f10 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Jul 2024 14:06:36 +0100 Subject: [PATCH 5/7] Fix regression when bounding future tokens (#17391) Fix bug added in #17386, where we accidentally used `room_key` for the receipts stream. See first commit. Reviewable commit-by-commit --- changelog.d/17391.bugfix | 1 + synapse/streams/events.py | 26 ++++++++++++++++++++++---- tests/handlers/test_sync.py | 37 +++++++++++++++++++++++++++++++------ 3 files changed, 54 insertions(+), 10 deletions(-) create mode 100644 changelog.d/17391.bugfix diff --git a/changelog.d/17391.bugfix b/changelog.d/17391.bugfix new file mode 100644 index 0000000000..9686b5c276 --- /dev/null +++ b/changelog.d/17391.bugfix @@ -0,0 +1 @@ +Fix bug where `/sync` requests could get blocked indefinitely after an upgrade from Synapse versions before v1.109.0. diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 93d5ae1a55..856f646795 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -19,6 +19,7 @@ # # +import logging from typing import TYPE_CHECKING, Sequence, Tuple import attr @@ -41,6 +42,9 @@ if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + + @attr.s(frozen=True, slots=True, auto_attribs=True) class _EventSourcesInner: room: RoomEventSource @@ -139,9 +143,16 @@ class EventSources: key ].get_max_allocated_token() - token = token.copy_and_replace( - key, token.room_key.bound_stream_token(max_token) - ) + if max_token < token_value.get_max_stream_pos(): + logger.error( + "Bounding token from the future '%s': token: %s, bound: %s", + key, + token_value, + max_token, + ) + token = token.copy_and_replace( + key, token_value.bound_stream_token(max_token) + ) else: assert isinstance(current_value, int) if current_value < token_value: @@ -149,7 +160,14 @@ class EventSources: key ].get_max_allocated_token() - token = token.copy_and_replace(key, min(token_value, max_token)) + if max_token < token_value: + logger.error( + "Bounding token from the future '%s': token: %s, bound: %s", + key, + token_value, + max_token, + ) + token = token.copy_and_replace(key, max_token) return token diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 5319928c28..674dd4fb54 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -36,7 +36,14 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer -from synapse.types import JsonDict, StreamKeyType, UserID, create_requester +from synapse.types import ( + JsonDict, + MultiWriterStreamToken, + RoomStreamToken, + StreamKeyType, + UserID, + create_requester, +) from synapse.util import Clock import tests.unittest @@ -999,7 +1006,13 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.get_success(sync_d, by=1.0) - def test_wait_for_invalid_future_sync_token(self) -> None: + @parameterized.expand( + [(key,) for key in StreamKeyType.__members__.values()], + name_func=lambda func, _, param: f"{func.__name__}_{param.args[0].name}", + ) + def test_wait_for_invalid_future_sync_token( + self, stream_key: StreamKeyType + ) -> None: """Like the previous test, except we give a token that has a stream position ahead of what is in the DB, i.e. its invalid and we shouldn't wait for the stream to advance (as it may never do so). @@ -1010,11 +1023,23 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): """ user = self.register_user("alice", "password") - # Create a token and arbitrarily advance one of the streams. + # Create a token and advance one of the streams. current_token = self.hs.get_event_sources().get_current_token() - since_token = current_token.copy_and_advance( - StreamKeyType.PRESENCE, current_token.presence_key + 1 - ) + token_value = current_token.get_field(stream_key) + + # How we advance the streams depends on the type. + if isinstance(token_value, int): + since_token = current_token.copy_and_advance(stream_key, token_value + 1) + elif isinstance(token_value, MultiWriterStreamToken): + since_token = current_token.copy_and_advance( + stream_key, MultiWriterStreamToken(stream=token_value.stream + 1) + ) + elif isinstance(token_value, RoomStreamToken): + since_token = current_token.copy_and_advance( + stream_key, RoomStreamToken(stream=token_value.stream + 1) + ) + else: + raise Exception("Unreachable") sync_d = defer.ensureDeferred( self.sync_handler.wait_for_sync_for_user( From 8f890447b0f8b6cbe369b162670185e8c746b2f2 Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 2 Jul 2024 06:07:04 -0700 Subject: [PATCH 6/7] Support MSC3916 by adding `_matrix/client/v1/media/download` endpoint (#17365) --- changelog.d/17365.feature | 1 + docker/configure_workers_and_start.py | 3 +- docs/upgrade.md | 13 + docs/workers.md | 1 + mypy.ini | 3 + poetry.lock | 18 +- pyproject.toml | 2 + synapse/api/ratelimiting.py | 3 +- synapse/federation/federation_client.py | 46 ++ synapse/federation/transport/client.py | 25 +- .../federation/transport/server/__init__.py | 9 +- synapse/federation/transport/server/_base.py | 4 +- .../federation/transport/server/federation.py | 5 +- synapse/http/client.py | 152 +++++ synapse/http/matrixfederationclient.py | 192 ++++++ synapse/media/_base.py | 28 +- synapse/media/media_repository.py | 151 ++++- synapse/media/media_storage.py | 27 +- synapse/rest/__init__.py | 4 + synapse/rest/client/media.py | 79 ++- synapse/rest/media/download_resource.py | 1 + tests/federation/test_federation_media.py | 35 +- tests/http/test_client.py | 143 +++- tests/media/test_media_storage.py | 14 +- tests/replication/test_multi_media_repo.py | 234 ++++++- tests/rest/client/test_media.py | 609 +++++++++++++++++- 26 files changed, 1718 insertions(+), 84 deletions(-) create mode 100644 changelog.d/17365.feature diff --git a/changelog.d/17365.feature b/changelog.d/17365.feature new file mode 100644 index 0000000000..f90dc84e38 --- /dev/null +++ b/changelog.d/17365.feature @@ -0,0 +1 @@ +Support [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md) by adding _matrix/client/v1/media/download endpoint. \ No newline at end of file diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 063f3727f9..b6690f3404 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -117,7 +117,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { }, "media_repository": { "app": "synapse.app.generic_worker", - "listener_resources": ["media"], + "listener_resources": ["media", "client"], "endpoint_patterns": [ "^/_matrix/media/", "^/_synapse/admin/v1/purge_media_cache$", @@ -125,6 +125,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_synapse/admin/v1/user/.*/media.*$", "^/_synapse/admin/v1/media/.*$", "^/_synapse/admin/v1/quarantine_media/.*$", + "^/_matrix/client/v1/media/.*$", ], # The first configured media worker will run the media background jobs "shared_extra_conf": { diff --git a/docs/upgrade.md b/docs/upgrade.md index 99be4122bb..cf53f56b06 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -117,6 +117,19 @@ each upgrade are complete before moving on to the next upgrade, to avoid stacking them up. You can monitor the currently running background updates with [the Admin API](usage/administration/admin_api/background_updates.html#status). +# Upgrading to v1.111.0 + +## New worker endpoints for authenticated client media + +[Media repository workers](./workers.md#synapseappmedia_repository) handling +Media APIs can now handle the following endpoint pattern: + +``` +^/_matrix/client/v1/media/.*$ +``` + +Please update your reverse proxy configuration. + # Upgrading to v1.106.0 ## Minimum supported Rust version diff --git a/docs/workers.md b/docs/workers.md index 1f6bfd9e7f..22fde488a9 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -739,6 +739,7 @@ An example for a federation sender instance: Handles the media repository. It can handle all endpoints starting with: /_matrix/media/ + /_matrix/client/v1/media/ ... and the following regular expressions matching media-specific administration APIs: diff --git a/mypy.ini b/mypy.ini index 1a2b9ea410..3fca15c01b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -96,3 +96,6 @@ ignore_missing_imports = True # https://github.com/twisted/treq/pull/366 [mypy-treq.*] ignore_missing_imports = True + +[mypy-multipart.*] +ignore_missing_imports = True diff --git a/poetry.lock b/poetry.lock index 99c3b62c7d..8142406e3f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -2039,6 +2039,20 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-multipart" +version = "0.0.9" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"}, + {file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"}, +] + +[package.extras] +dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"] + [[package]] name = "pytz" version = "2022.7.1" @@ -3187,4 +3201,4 @@ user-search = ["pyicu"] [metadata] lock-version = "2.0" python-versions = "^3.8.0" -content-hash = "107c8fb5c67360340854fbdba3c085fc5f9c7be24bcb592596a914eea621faea" +content-hash = "e8d5806e10eb69bc06900fde18ea3df38f38490ab6baa73fe4a563dfb6abacba" diff --git a/pyproject.toml b/pyproject.toml index bbf9c78420..0555e67613 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,6 +224,8 @@ pydantic = ">=1.7.4, <3" # needed. setuptools_rust = ">=1.3" +# This is used for parsing multipart responses +python-multipart = ">=0.0.9" # Optional Dependencies # --------------------- diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index a99a9e09fc..26b8711851 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -130,7 +130,8 @@ class Ratelimiter: Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. Overrides the value set during instantiation if set. - update: Whether to count this check as performing the action + update: Whether to count this check as performing the action. If the action + cannot be performed, the user's action count is not incremented at all. n_actions: The number of times the user wants to do this action. If the user cannot do all of the actions, the user's action count is not incremented at all. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f0f5a37a57..7d80ff6998 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1871,6 +1871,52 @@ class FederationClient(FederationBase): return filtered_statuses, filtered_failures + async def federation_download_media( + self, + destination: str, + media_id: str, + output_stream: BinaryIO, + max_size: int, + max_timeout_ms: int, + download_ratelimiter: Ratelimiter, + ip_address: str, + ) -> Union[ + Tuple[int, Dict[bytes, List[bytes]], bytes], + Tuple[int, Dict[bytes, List[bytes]]], + ]: + try: + return await self.transport_layer.federation_download_media( + destination, + media_id, + output_stream=output_stream, + max_size=max_size, + max_timeout_ms=max_timeout_ms, + download_ratelimiter=download_ratelimiter, + ip_address=ip_address, + ) + except HttpResponseException as e: + # If an error is received that is due to an unrecognised endpoint, + # fallback to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error + # and raise. + if not is_unknown_endpoint(e): + raise + + logger.debug( + "Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path", + destination, + media_id, + ) + + return await self.transport_layer.download_media_v3( + destination, + media_id, + output_stream=output_stream, + max_size=max_size, + max_timeout_ms=max_timeout_ms, + download_ratelimiter=download_ratelimiter, + ip_address=ip_address, + ) + async def download_media( self, destination: str, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index af1336fe5f..206e91ed14 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -824,7 +824,6 @@ class TransportLayerClient: ip_address: str, ) -> Tuple[int, Dict[bytes, List[bytes]]]: path = f"/_matrix/media/r0/download/{destination}/{media_id}" - return await self.client.get_file( destination, path, @@ -852,7 +851,6 @@ class TransportLayerClient: ip_address: str, ) -> Tuple[int, Dict[bytes, List[bytes]]]: path = f"/_matrix/media/v3/download/{destination}/{media_id}" - return await self.client.get_file( destination, path, @@ -873,6 +871,29 @@ class TransportLayerClient: ip_address=ip_address, ) + async def federation_download_media( + self, + destination: str, + media_id: str, + output_stream: BinaryIO, + max_size: int, + max_timeout_ms: int, + download_ratelimiter: Ratelimiter, + ip_address: str, + ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]: + path = f"/_matrix/federation/v1/media/download/{media_id}" + return await self.client.federation_get_file( + destination, + path, + output_stream=output_stream, + max_size=max_size, + args={ + "timeout_ms": str(max_timeout_ms), + }, + download_ratelimiter=download_ratelimiter, + ip_address=ip_address, + ) + def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index edaf0196d6..c44e5daa47 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -32,8 +32,8 @@ from synapse.federation.transport.server._base import ( from synapse.federation.transport.server.federation import ( FEDERATION_SERVLET_CLASSES, FederationAccountStatusServlet, + FederationMediaDownloadServlet, FederationUnstableClientKeysClaimServlet, - FederationUnstableMediaDownloadServlet, ) from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( @@ -316,11 +316,8 @@ def register_servlets( ): continue - if servletclass == FederationUnstableMediaDownloadServlet: - if ( - not hs.config.server.enable_media_repo - or not hs.config.experimental.msc3916_authenticated_media_enabled - ): + if servletclass == FederationMediaDownloadServlet: + if not hs.config.server.enable_media_repo: continue servletclass( diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 4e2717b565..e124481474 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -362,7 +362,7 @@ class BaseFederationServlet: return None if ( func.__self__.__class__.__name__ # type: ignore - == "FederationUnstableMediaDownloadServlet" + == "FederationMediaDownloadServlet" ): response = await func( origin, content, request, *args, **kwargs @@ -374,7 +374,7 @@ class BaseFederationServlet: else: if ( func.__self__.__class__.__name__ # type: ignore - == "FederationUnstableMediaDownloadServlet" + == "FederationMediaDownloadServlet" ): response = await func( origin, content, request, *args, **kwargs diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 67bb907050..ec957768d4 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -790,7 +790,7 @@ class FederationAccountStatusServlet(BaseFederationServerServlet): return 200, {"account_statuses": statuses, "failures": failures} -class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet): +class FederationMediaDownloadServlet(BaseFederationServerServlet): """ Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns a multipart/mixed response consisting of a JSON object and the requested media @@ -798,7 +798,6 @@ class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet): """ PATH = "/media/download/(?P[^/]*)" - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916" RATELIMIT = True def __init__( @@ -858,5 +857,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationV1SendKnockServlet, FederationMakeKnockServlet, FederationAccountStatusServlet, - FederationUnstableMediaDownloadServlet, + FederationMediaDownloadServlet, ) diff --git a/synapse/http/client.py b/synapse/http/client.py index 4718517c97..56ad28eabf 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -35,6 +35,8 @@ from typing import ( Union, ) +import attr +import multipart import treq from canonicaljson import encode_canonical_json from netaddr import AddrFormatError, IPAddress, IPSet @@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): self._maybe_fail() +@attr.s(auto_attribs=True, slots=True) +class MultipartResponse: + """ + A small class to hold parsed values of a multipart response. + """ + + json: bytes = b"{}" + length: Optional[int] = None + content_type: Optional[bytes] = None + disposition: Optional[bytes] = None + url: Optional[bytes] = None + + +class _MultipartParserProtocol(protocol.Protocol): + """ + Protocol to read and parse a MSC3916 multipart/mixed response + """ + + transport: Optional[ITCPTransport] = None + + def __init__( + self, + stream: ByteWriteable, + deferred: defer.Deferred, + boundary: str, + max_length: Optional[int], + ) -> None: + self.stream = stream + self.deferred = deferred + self.boundary = boundary + self.max_length = max_length + self.parser = None + self.multipart_response = MultipartResponse() + self.has_redirect = False + self.in_json = False + self.json_done = False + self.file_length = 0 + self.total_length = 0 + self.in_disposition = False + self.in_content_type = False + + def dataReceived(self, incoming_data: bytes) -> None: + if self.deferred.called: + return + + # we don't have a parser yet, instantiate it + if not self.parser: + + def on_header_field(data: bytes, start: int, end: int) -> None: + if data[start:end] == b"Location": + self.has_redirect = True + if data[start:end] == b"Content-Disposition": + self.in_disposition = True + if data[start:end] == b"Content-Type": + self.in_content_type = True + + def on_header_value(data: bytes, start: int, end: int) -> None: + # the first header should be content-type for application/json + if not self.in_json and not self.json_done: + assert data[start:end] == b"application/json" + self.in_json = True + elif self.has_redirect: + self.multipart_response.url = data[start:end] + elif self.in_content_type: + self.multipart_response.content_type = data[start:end] + self.in_content_type = False + elif self.in_disposition: + self.multipart_response.disposition = data[start:end] + self.in_disposition = False + + def on_part_data(data: bytes, start: int, end: int) -> None: + # we've seen json header but haven't written the json data + if self.in_json and not self.json_done: + self.multipart_response.json = data[start:end] + self.json_done = True + # we have a redirect header rather than a file, and have already captured it + elif self.has_redirect: + return + # otherwise we are in the file part + else: + logger.info("Writing multipart file data to stream") + try: + self.stream.write(data[start:end]) + except Exception as e: + logger.warning( + f"Exception encountered writing file data to stream: {e}" + ) + self.deferred.errback() + self.file_length += end - start + + callbacks = { + "on_header_field": on_header_field, + "on_header_value": on_header_value, + "on_part_data": on_part_data, + } + self.parser = multipart.MultipartParser(self.boundary, callbacks) + + self.total_length += len(incoming_data) + if self.max_length is not None and self.total_length >= self.max_length: + self.deferred.errback(BodyExceededMaxSize()) + # Close the connection (forcefully) since all the data will get + # discarded anyway. + assert self.transport is not None + self.transport.abortConnection() + + try: + self.parser.write(incoming_data) # type: ignore[attr-defined] + except Exception as e: + logger.warning(f"Exception writing to multipart parser: {e}") + self.deferred.errback() + return + + def connectionLost(self, reason: Failure = connectionDone) -> None: + # If the maximum size was already exceeded, there's nothing to do. + if self.deferred.called: + return + + if reason.check(ResponseDone): + self.multipart_response.length = self.file_length + self.deferred.callback(self.multipart_response) + else: + self.deferred.errback(reason) + + class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" @@ -1091,6 +1217,32 @@ def read_body_with_max_size( return d +def read_multipart_response( + response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int] +) -> "defer.Deferred[MultipartResponse]": + """ + Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into + the stream passed in and returning a deferred resolving to a MultipartResponse + + Args: + response: The HTTP response to read from. + stream: The file-object to write to. + boundary: the multipart/mixed boundary string + max_length: maximum allowable length of the response + """ + d: defer.Deferred[MultipartResponse] = defer.Deferred() + + # If the Content-Length header gives a size larger than the maximum allowed + # size, do not bother downloading the body. + if max_length is not None and response.length != UNKNOWN_LENGTH: + if response.length > max_length: + response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d)) + return d + + response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length)) + return d + + def encode_query_args(args: Optional[QueryParams]) -> bytes: """ Encodes a map of query arguments to bytes which can be appended to a URL. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 104b803b0f..749b01dd0e 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -75,9 +75,11 @@ from synapse.http.client import ( BlocklistingAgentWrapper, BodyExceededMaxSize, ByteWriteable, + SimpleHttpClient, _make_scheduler, encode_query_args, read_body_with_max_size, + read_multipart_response, ) from synapse.http.connectproxyclient import BearerProxyCredentials from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent @@ -466,6 +468,13 @@ class MatrixFederationHttpClient: self._sleeper = AwakenableSleeper(self.reactor) + self._simple_http_client = SimpleHttpClient( + hs, + ip_blocklist=hs.config.server.federation_ip_range_blocklist, + ip_allowlist=hs.config.server.federation_ip_range_allowlist, + use_proxy=True, + ) + def wake_destination(self, destination: str) -> None: """Called when the remote server may have come back online.""" @@ -1553,6 +1562,189 @@ class MatrixFederationHttpClient: ) return length, headers + async def federation_get_file( + self, + destination: str, + path: str, + output_stream: BinaryIO, + download_ratelimiter: Ratelimiter, + ip_address: str, + max_size: int, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, + ignore_backoff: bool = False, + ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]: + """GETs a file from a given homeserver over the federation /download endpoint + Args: + destination: The remote server to send the HTTP request to. + path: The HTTP path to GET. + output_stream: File to write the response body to. + download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to + requester IP + ip_address: IP address of the requester + max_size: maximum allowable size in bytes of the file + args: Optional dictionary used to create the query string. + ignore_backoff: true to ignore the historical backoff data + and try the request anyway. + + Returns: + Resolves to an (int, dict, bytes) tuple of + the file length, a dict of the response headers, and the file json + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. + SynapseError: If the requested file exceeds ratelimits or the response from the + remote server is not a multipart response + AssertionError: if the resolved multipart response's length is None + """ + request = MatrixFederationRequest( + method="GET", destination=destination, path=path, query=args + ) + + # check for a minimum balance of 1MiB in ratelimiter before initiating request + send_req, _ = await download_ratelimiter.can_do_action( + requester=None, key=ip_address, n_actions=1048576, update=False + ) + + if not send_req: + msg = "Requested file size exceeds ratelimits" + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED) + + response = await self._send_request( + request, + retry_on_dns_fail=retry_on_dns_fail, + ignore_backoff=ignore_backoff, + ) + + headers = dict(response.headers.getAllRawHeaders()) + + expected_size = response.length + # if we don't get an expected length then use the max length + if expected_size == UNKNOWN_LENGTH: + expected_size = max_size + logger.debug( + f"File size unknown, assuming file is max allowable size: {max_size}" + ) + + read_body, _ = await download_ratelimiter.can_do_action( + requester=None, + key=ip_address, + n_actions=expected_size, + ) + if not read_body: + msg = "Requested file size exceeds ratelimits" + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED) + + # this should be a multipart/mixed response with the boundary string in the header + try: + raw_content_type = headers.get(b"Content-Type") + assert raw_content_type is not None + content_type = raw_content_type[0].decode("UTF-8") + content_type_parts = content_type.split("boundary=") + boundary = content_type_parts[1] + except Exception: + msg = "Remote response is malformed: expected Content-Type of multipart/mixed with a boundary present." + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg) + + try: + # add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >= + deferred = read_multipart_response( + response, output_stream, boundary, expected_size + 1 + ) + deferred.addTimeout(self.default_timeout_seconds, self.reactor) + except BodyExceededMaxSize: + msg = "Requested file is too large > %r bytes" % (expected_size,) + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE) + except defer.TimeoutError as e: + logger.warning( + "{%s} [%s] Timed out reading response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=True) from e + except ResponseFailed as e: + logger.warning( + "{%s} [%s] Failed to read response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=True) from e + except Exception as e: + logger.warning( + "{%s} [%s] Error reading response: %s", + request.txn_id, + request.destination, + e, + ) + raise + + multipart_response = await make_deferred_yieldable(deferred) + if not multipart_response.url: + assert multipart_response.length is not None + length = multipart_response.length + headers[b"Content-Type"] = [multipart_response.content_type] + headers[b"Content-Disposition"] = [multipart_response.disposition] + + # the response contained a redirect url to download the file from + else: + str_url = multipart_response.url.decode("utf-8") + logger.info( + "{%s} [%s] File download redirected, now downloading from: %s", + request.txn_id, + request.destination, + str_url, + ) + length, headers, _, _ = await self._simple_http_client.get_file( + str_url, output_stream, expected_size + ) + + logger.info( + "{%s} [%s] Completed: %d %s [%d bytes] %s %s", + request.txn_id, + request.destination, + response.code, + response.phrase.decode("ascii", errors="replace"), + length, + request.method, + request.uri.decode("ascii"), + ) + return length, headers, multipart_response.json + def _flatten_response_never_received(e: BaseException) -> str: if hasattr(e, "reasons"): diff --git a/synapse/media/_base.py b/synapse/media/_base.py index 7ad0b7c3cf..1b268ce4d4 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py @@ -221,6 +221,7 @@ def add_file_headers( # select private. don't bother setting Expires as all our # clients are smart enough to be happy with Cache-Control request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") + if file_size is not None: request.setHeader(b"Content-Length", b"%d" % (file_size,)) @@ -302,12 +303,37 @@ async def respond_with_multipart_responder( ) return + if media_info.media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES: + disposition = "inline" + else: + disposition = "attachment" + + def _quote(x: str) -> str: + return urllib.parse.quote(x.encode("utf-8")) + + if media_info.upload_name: + if _can_encode_filename_as_token(media_info.upload_name): + disposition = "%s; filename=%s" % ( + disposition, + media_info.upload_name, + ) + else: + disposition = "%s; filename*=utf-8''%s" % ( + disposition, + _quote(media_info.upload_name), + ) + from synapse.media.media_storage import MultipartFileConsumer # note that currently the json_object is just {}, this will change when linked media # is implemented multipart_consumer = MultipartFileConsumer( - clock, request, media_info.media_type, {}, media_info.media_length + clock, + request, + media_info.media_type, + {}, + disposition, + media_info.media_length, ) logger.debug("Responding to media request with responder %s", responder) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 1436329fad..542642b900 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -480,6 +480,7 @@ class MediaRepository: name: Optional[str], max_timeout_ms: int, ip_address: str, + use_federation_endpoint: bool, ) -> None: """Respond to requests for remote media. @@ -492,6 +493,8 @@ class MediaRepository: max_timeout_ms: the maximum number of milliseconds to wait for the media to be uploaded. ip_address: the IP address of the requester + use_federation_endpoint: whether to request the remote media over the new + federation `/download` endpoint Returns: Resolves once a response has successfully been written to request @@ -522,6 +525,7 @@ class MediaRepository: max_timeout_ms, self.download_ratelimiter, ip_address, + use_federation_endpoint, ) # We deliberately stream the file outside the lock @@ -569,6 +573,7 @@ class MediaRepository: max_timeout_ms, self.download_ratelimiter, ip_address, + False, ) # Ensure we actually use the responder so that it releases resources @@ -585,6 +590,7 @@ class MediaRepository: max_timeout_ms: int, download_ratelimiter: Ratelimiter, ip_address: str, + use_federation_endpoint: bool, ) -> Tuple[Optional[Responder], RemoteMedia]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -598,6 +604,8 @@ class MediaRepository: download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to requester IP. ip_address: the IP address of the requester + use_federation_endpoint: whether to request the remote media over the new federation + /download endpoint Returns: A tuple of responder and the media info of the file. @@ -629,9 +637,23 @@ class MediaRepository: # Failed to find the file anywhere, lets download it. try: - media_info = await self._download_remote_file( - server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address - ) + if not use_federation_endpoint: + media_info = await self._download_remote_file( + server_name, + media_id, + max_timeout_ms, + download_ratelimiter, + ip_address, + ) + else: + media_info = await self._federation_download_remote_file( + server_name, + media_id, + max_timeout_ms, + download_ratelimiter, + ip_address, + ) + except SynapseError: raise except Exception as e: @@ -775,6 +797,129 @@ class MediaRepository: quarantined_by=None, ) + async def _federation_download_remote_file( + self, + server_name: str, + media_id: str, + max_timeout_ms: int, + download_ratelimiter: Ratelimiter, + ip_address: str, + ) -> RemoteMedia: + """Attempt to download the remote file from the given server name. + Uses the given file_id as the local id and downloads the file over the federation + v1 download endpoint + + Args: + server_name: Originating server + media_id: The media ID of the content (as defined by the + remote server). This is different than the file_id, which is + locally generated. + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. + download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to + requester IP + ip_address: the IP address of the requester + + Returns: + The media info of the file. + """ + + file_id = random_string(24) + + file_info = FileInfo(server_name=server_name, file_id=file_id) + + async with self.media_storage.store_into_file(file_info) as (f, fname): + try: + res = await self.client.federation_download_media( + server_name, + media_id, + output_stream=f, + max_size=self.max_upload_size, + max_timeout_ms=max_timeout_ms, + download_ratelimiter=download_ratelimiter, + ip_address=ip_address, + ) + # if we had to fall back to the _matrix/media endpoint it will only return + # the headers and length, check the length of the tuple before unpacking + if len(res) == 3: + length, headers, json = res + else: + length, headers = res + except RequestSendFailed as e: + logger.warning( + "Request failed fetching remote media %s/%s: %r", + server_name, + media_id, + e, + ) + raise SynapseError(502, "Failed to fetch remote media") + + except HttpResponseException as e: + logger.warning( + "HTTP error fetching remote media %s/%s: %s", + server_name, + media_id, + e.response, + ) + if e.code == twisted.web.http.NOT_FOUND: + raise e.to_synapse_error() + raise SynapseError(502, "Failed to fetch remote media") + + except SynapseError: + logger.warning( + "Failed to fetch remote media %s/%s", server_name, media_id + ) + raise + except NotRetryingDestination: + logger.warning("Not retrying destination %r", server_name) + raise SynapseError(502, "Failed to fetch remote media") + except Exception: + logger.exception( + "Failed to fetch remote media %s/%s", server_name, media_id + ) + raise SynapseError(502, "Failed to fetch remote media") + + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" + upload_name = get_filename_from_headers(headers) + time_now_ms = self.clock.time_msec() + + # Multiple remote media download requests can race (when using + # multiple media repos), so this may throw a violation constraint + # exception. If it does we'll delete the newly downloaded file from + # disk (as we're in the ctx manager). + # + # However: we've already called `finish()` so we may have also + # written to the storage providers. This is preferable to the + # alternative where we call `finish()` *after* this, where we could + # end up having an entry in the DB but fail to write the files to + # the storage providers. + await self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=time_now_ms, + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) + + logger.debug("Stored remote media in file %r", fname) + + return RemoteMedia( + media_origin=server_name, + media_id=media_id, + media_type=media_type, + media_length=length, + upload_name=upload_name, + created_ts=time_now_ms, + filesystem_id=file_id, + last_access_ts=time_now_ms, + quarantined_by=None, + ) + def _get_thumbnail_requirements( self, media_type: str ) -> Tuple[ThumbnailRequirement, ...]: diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index 1be2c9b5f5..2a106bb0eb 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -401,13 +401,14 @@ class MultipartFileConsumer: wrapped_consumer: interfaces.IConsumer, file_content_type: str, json_object: JsonDict, - content_length: Optional[int] = None, + disposition: str, + content_length: Optional[int], ) -> None: self.clock = clock self.wrapped_consumer = wrapped_consumer self.json_field = json_object self.json_field_written = False - self.content_type_written = False + self.file_headers_written = False self.file_content_type = file_content_type self.boundary = uuid4().hex.encode("ascii") @@ -420,6 +421,7 @@ class MultipartFileConsumer: self.paused = False self.length = content_length + self.disposition = disposition ### IConsumer APIs ### @@ -488,11 +490,13 @@ class MultipartFileConsumer: self.json_field_written = True # if we haven't written the content type yet, do so - if not self.content_type_written: + if not self.file_headers_written: type = self.file_content_type.encode("utf-8") content_type = Header(b"Content-Type", type) - self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF) - self.content_type_written = True + self.wrapped_consumer.write(bytes(content_type) + CRLF) + disp_header = Header(b"Content-Disposition", self.disposition) + self.wrapped_consumer.write(bytes(disp_header) + CRLF + CRLF) + self.file_headers_written = True self.wrapped_consumer.write(data) @@ -506,7 +510,6 @@ class MultipartFileConsumer: producing data for good. """ assert self.producer is not None - self.paused = True self.producer.stopProducing() @@ -518,7 +521,6 @@ class MultipartFileConsumer: the time being, and to stop until C{resumeProducing()} is called. """ assert self.producer is not None - self.paused = True if self.streaming: @@ -549,7 +551,7 @@ class MultipartFileConsumer: """ if not self.length: return None - # calculate length of json field and content-type header + # calculate length of json field and content-type, disposition headers json_field = json.dumps(self.json_field) json_bytes = json_field.encode("utf-8") json_length = len(json_bytes) @@ -558,9 +560,13 @@ class MultipartFileConsumer: content_type = Header(b"Content-Type", type) type_length = len(bytes(content_type)) - # 154 is the length of the elements that aren't variable, ie + disp = self.disposition.encode("utf-8") + disp_header = Header(b"Content-Disposition", disp) + disp_length = len(bytes(disp_header)) + + # 156 is the length of the elements that aren't variable, ie # CRLFs and boundary strings, etc - self.length += json_length + type_length + 154 + self.length += json_length + type_length + disp_length + 156 return self.length @@ -569,7 +575,6 @@ class MultipartFileConsumer: async def _resumeProducingRepeatedly(self) -> None: assert self.producer is not None assert not self.streaming - producer = cast("interfaces.IPullProducer", self.producer) self.paused = False diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 0024ccf708..c94d454a28 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -145,6 +145,10 @@ class ClientRestResource(JsonResource): password_policy.register_servlets(hs, client_resource) knock.register_servlets(hs, client_resource) appservice_ping.register_servlets(hs, client_resource) + if hs.config.server.enable_media_repo: + from synapse.rest.client import media + + media.register_servlets(hs, client_resource) # moving to /_synapse/admin if is_main_process: diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 0c089163c1..c0ae5dd66f 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -22,6 +22,7 @@ import logging import re +from typing import Optional from synapse.http.server import ( HttpServer, @@ -194,14 +195,76 @@ class UnstableThumbnailResource(RestServlet): self.media_repo.mark_recently_accessed(server_name, media_id) +class DownloadResource(RestServlet): + PATTERNS = [ + re.compile( + "/_matrix/client/v1/media/download/(?P[^/]*)/(?P[^/]*)(/(?P[^/]*))?$" + ) + ] + + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): + super().__init__() + self.media_repo = media_repo + self._is_mine_server_name = hs.is_mine_server_name + self.auth = hs.get_auth() + + async def on_GET( + self, + request: SynapseRequest, + server_name: str, + media_id: str, + file_name: Optional[str] = None, + ) -> None: + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + + await self.auth.get_user_by_req(request) + + set_cors_headers(request) + set_corp_headers(request) + request.setHeader( + b"Content-Security-Policy", + b"sandbox;" + b" default-src 'none';" + b" script-src 'none';" + b" plugin-types application/pdf;" + b" style-src 'unsafe-inline';" + b" media-src 'self';" + b" object-src 'self';", + ) + # Limited non-standard form of CSP for IE11 + request.setHeader(b"X-Content-Security-Policy", b"sandbox;") + request.setHeader(b"Referrer-Policy", b"no-referrer") + max_timeout_ms = parse_integer( + request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS + ) + max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS) + + if self._is_mine_server_name(server_name): + await self.media_repo.get_local_media( + request, media_id, file_name, max_timeout_ms + ) + else: + ip_address = request.getClientAddress().host + await self.media_repo.get_remote_media( + request, + server_name, + media_id, + file_name, + max_timeout_ms, + ip_address, + True, + ) + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - if hs.config.experimental.msc3916_authenticated_media_enabled: - media_repo = hs.get_media_repository() - if hs.config.media.url_preview_enabled: - UnstablePreviewURLServlet( - hs, media_repo, media_repo.media_storage - ).register(http_server) - UnstableMediaConfigResource(hs).register(http_server) - UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register( + media_repo = hs.get_media_repository() + if hs.config.media.url_preview_enabled: + UnstablePreviewURLServlet(hs, media_repo, media_repo.media_storage).register( http_server ) + UnstableMediaConfigResource(hs).register(http_server) + UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register( + http_server + ) + DownloadResource(hs, media_repo).register(http_server) diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py index 1628d58926..c32c626905 100644 --- a/synapse/rest/media/download_resource.py +++ b/synapse/rest/media/download_resource.py @@ -105,4 +105,5 @@ class DownloadResource(RestServlet): file_name, max_timeout_ms, ip_address, + False, ) diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index 2c396adbe3..142f73cfdb 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -36,10 +36,9 @@ from synapse.util import Clock from tests import unittest from tests.test_utils import SMALL_PNG -from tests.unittest import override_config -class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase): +class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) @@ -65,9 +64,6 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase ) self.media_repo = hs.get_media_repository() - @override_config( - {"experimental_features": {"msc3916_authenticated_media_enabled": True}} - ) def test_file_download(self) -> None: content = io.BytesIO(b"file_to_stream") content_uri = self.get_success( @@ -82,7 +78,7 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase # test with a text file channel = self.make_signed_federation_request( "GET", - f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}", + f"/_matrix/federation/v1/media/download/{content_uri.media_id}", ) self.pump() self.assertEqual(200, channel.code) @@ -106,7 +102,8 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase # check that the text file and expected value exist found_file = any( - "\r\nContent-Type: text/plain\r\n\r\nfile_to_stream" in field + "\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream" + in field for field in stripped ) self.assertTrue(found_file) @@ -124,7 +121,7 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase # test with an image file channel = self.make_signed_federation_request( "GET", - f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}", + f"/_matrix/federation/v1/media/download/{content_uri.media_id}", ) self.pump() self.assertEqual(200, channel.code) @@ -149,25 +146,3 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase # check that the png file exists and matches what was uploaded found_file = any(SMALL_PNG in field for field in stripped_bytes) self.assertTrue(found_file) - - @override_config( - {"experimental_features": {"msc3916_authenticated_media_enabled": False}} - ) - def test_disable_config(self) -> None: - content = io.BytesIO(b"file_to_stream") - content_uri = self.get_success( - self.media_repo.create_content( - "text/plain", - "test_upload", - content, - 46, - UserID.from_string("@user_id:whatever.org"), - ) - ) - channel = self.make_signed_federation_request( - "GET", - f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}", - ) - self.pump() - self.assertEqual(404, channel.code) - self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED") diff --git a/tests/http/test_client.py b/tests/http/test_client.py index a98091d711..721917f957 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -37,18 +37,155 @@ from synapse.http.client import ( BlocklistingAgentWrapper, BlocklistingReactorWrapper, BodyExceededMaxSize, + MultipartResponse, _DiscardBodyWithMaxSizeProtocol, + _MultipartParserProtocol, read_body_with_max_size, + read_multipart_response, ) from tests.server import FakeTransport, get_clock from tests.unittest import TestCase +class ReadMultipartResponseTests(TestCase): + data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_" + data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" + + redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" + + def _build_multipart_response( + self, response_length: Union[int, str], max_length: int + ) -> Tuple[ + BytesIO, + "Deferred[MultipartResponse]", + _MultipartParserProtocol, + ]: + """Start reading the body, returns the response, result and proto""" + response = Mock(length=response_length) + result = BytesIO() + boundary = "6067d4698f8d40a0a794ea7d7379d53a" + deferred = read_multipart_response(response, result, boundary, max_length) + + # Fish the protocol out of the response. + protocol = response.deliverBody.call_args[0][0] + protocol.transport = Mock() + + return result, deferred, protocol + + def _assert_error( + self, + deferred: "Deferred[MultipartResponse]", + protocol: _MultipartParserProtocol, + ) -> None: + """Ensure that the expected error is received.""" + assert isinstance(deferred.result, Failure) + self.assertIsInstance(deferred.result.value, BodyExceededMaxSize) + assert protocol.transport is not None + # type-ignore: presumably abortConnection has been replaced with a Mock. + protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined] + + def _cleanup_error(self, deferred: "Deferred[MultipartResponse]") -> None: + """Ensure that the error in the Deferred is handled gracefully.""" + called = [False] + + def errback(f: Failure) -> None: + called[0] = True + + deferred.addErrback(errback) + self.assertTrue(called[0]) + + def test_parse_file(self) -> None: + """ + Check that a multipart response containing a file is properly parsed + into the json/file parts, and the json and file are properly captured + """ + result, deferred, protocol = self._build_multipart_response(249, 250) + + # Start sending data. + protocol.dataReceived(self.data1) + protocol.dataReceived(self.data2) + # Close the connection. + protocol.connectionLost(Failure(ResponseDone())) + + multipart_response: MultipartResponse = deferred.result # type: ignore[assignment] + + self.assertEqual(multipart_response.json, b"{}") + self.assertEqual(result.getvalue(), b"file_to_stream") + self.assertEqual(multipart_response.length, len(b"file_to_stream")) + self.assertEqual(multipart_response.content_type, b"text/plain") + self.assertEqual( + multipart_response.disposition, b"inline; filename=test_upload" + ) + + def test_parse_redirect(self) -> None: + """ + check that a multipart response containing a redirect is properly parsed and redirect url is + returned + """ + result, deferred, protocol = self._build_multipart_response(249, 250) + + # Start sending data. + protocol.dataReceived(self.redirect_data) + # Close the connection. + protocol.connectionLost(Failure(ResponseDone())) + + multipart_response: MultipartResponse = deferred.result # type: ignore[assignment] + + self.assertEqual(multipart_response.json, b"{}") + self.assertEqual(result.getvalue(), b"") + self.assertEqual( + multipart_response.url, b"https://cdn.example.org/ab/c1/2345.txt" + ) + + def test_too_large(self) -> None: + """A response which is too large raises an exception.""" + result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180) + + # Start sending data. + protocol.dataReceived(self.data1) + + self.assertEqual(result.getvalue(), b"file_") + self._assert_error(deferred, protocol) + self._cleanup_error(deferred) + + def test_additional_data(self) -> None: + """A connection can receive data after being closed.""" + result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180) + + # Start sending data. + protocol.dataReceived(self.data1) + self._assert_error(deferred, protocol) + + # More data might have come in. + protocol.dataReceived(self.data2) + + self.assertEqual(result.getvalue(), b"file_") + self._assert_error(deferred, protocol) + self._cleanup_error(deferred) + + def test_content_length(self) -> None: + """The body shouldn't be read (at all) if the Content-Length header is too large.""" + result, deferred, protocol = self._build_multipart_response(250, 1) + + # Deferred shouldn't be called yet. + self.assertFalse(deferred.called) + + # Start sending data. + protocol.dataReceived(self.data1) + self._assert_error(deferred, protocol) + self._cleanup_error(deferred) + + # The data is never consumed. + self.assertEqual(result.getvalue(), b"") + + class ReadBodyWithMaxSizeTests(TestCase): - def _build_response( - self, length: Union[int, str] = UNKNOWN_LENGTH - ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]: + def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[ + BytesIO, + "Deferred[int]", + _DiscardBodyWithMaxSizeProtocol, + ]: """Start reading the body, returns the response, result and proto""" response = Mock(length=length) result = BytesIO() diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 46d20ce775..024086b775 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -129,7 +129,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): @attr.s(auto_attribs=True, slots=True, frozen=True) -class _TestImage: +class TestImage: """An image for testing thumbnailing with the expected results Attributes: @@ -158,7 +158,7 @@ class _TestImage: is_inline: bool = True -small_png = _TestImage( +small_png = TestImage( SMALL_PNG, b"image/png", b".png", @@ -175,7 +175,7 @@ small_png = _TestImage( ), ) -small_png_with_transparency = _TestImage( +small_png_with_transparency = TestImage( unhexlify( b"89504e470d0a1a0a0000000d49484452000000010000000101000" b"00000376ef9240000000274524e5300010194fdae0000000a4944" @@ -188,7 +188,7 @@ small_png_with_transparency = _TestImage( # different versions of Pillow. ) -small_lossless_webp = _TestImage( +small_lossless_webp = TestImage( unhexlify( b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700" ), @@ -196,7 +196,7 @@ small_lossless_webp = _TestImage( b".webp", ) -empty_file = _TestImage( +empty_file = TestImage( b"", b"image/gif", b".gif", @@ -204,7 +204,7 @@ empty_file = _TestImage( unable_to_thumbnail=True, ) -SVG = _TestImage( +SVG = TestImage( b""" @@ -236,7 +236,7 @@ urls = [ @parameterized_class(("test_image", "url"), itertools.product(test_images, urls)) class MediaRepoTests(unittest.HomeserverTestCase): servlets = [media.register_servlets] - test_image: ClassVar[_TestImage] + test_image: ClassVar[TestImage] hijack_auth = True user_id = "@test:user" url: ClassVar[str] diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 4927e45446..6fc4600c41 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -28,7 +28,7 @@ from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.rest import admin -from synapse.rest.client import login +from synapse.rest.client import login, media from synapse.server import HomeServer from synapse.util import Clock @@ -255,6 +255,238 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): return sum(len(files) for _, _, files in os.walk(path)) +class AuthenticatedMediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks running multiple media repos work correctly using autheticated media paths""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + media.register_servlets, + ] + + file_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + self.reactor.lookups["example.com"] = "1.2.3.4" + + def default_config(self) -> dict: + conf = super().default_config() + conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] + return conf + + def make_worker_hs( + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any + ) -> HomeServer: + worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs) + # Force the media paths onto the replication resource. + worker_hs.get_media_repository_resource().register_servlets( + self._hs_to_site[worker_hs].resource, worker_hs + ) + return worker_hs + + def _get_media_req( + self, hs: HomeServer, target: str, media_id: str + ) -> Tuple[FakeChannel, Request]: + """Request some remote media from the given HS by calling the download + API. + + This then triggers an outbound request from the HS to the target. + + Returns: + The channel for the *client* request and the *outbound* request for + the media which the caller should respond to. + """ + channel = make_request( + self.reactor, + self._hs_to_site[hs], + "GET", + f"/_matrix/client/v1/media/download/{target}/{media_id}", + shorthand=False, + access_token=self.access_token, + await_result=False, + ) + self.pump() + + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + + # build the test server + server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + server_tls_protocol = wrap_server_factory_for_tls( + server_factory, self.reactor, sanlist=[b"DNS:example.com"] + ).buildProtocol(None) + + # now, tell the client protocol factory to build the client protocol (it will be a + # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an + # HTTP11ClientProtocol) and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_tls_protocol, self.reactor, client_protocol) + ) + + # tell the server tls protocol to send its stuff back to the client, too + server_tls_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_tls_protocol) + ) + + # fish the test server back out of the server-side TLS protocol. + http_server: HTTPChannel = server_tls_protocol.wrappedProtocol + + # give the reactor a pump to get the TLS juices flowing. + self.reactor.pump((0.1,)) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + + self.assertEqual(request.method, b"GET") + self.assertEqual( + request.path, + f"/_matrix/federation/v1/media/download/{media_id}".encode(), + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] + ) + + return channel, request + + def test_basic(self) -> None: + """Test basic fetching of remote media from a single worker.""" + hs1 = self.make_worker_hs("synapse.app.generic_worker") + + channel, request = self._get_media_req(hs1, "example.com:443", "ABC123") + + request.setResponseCode(200) + request.responseHeaders.setRawHeaders( + b"Content-Type", + ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"], + ) + request.write(self.file_data) + request.finish() + + self.pump(0.1) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.result["body"], b"file_to_stream") + + def test_download_simple_file_race(self) -> None: + """Test that fetching remote media from two different processes at the + same time works. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_media() + + # Make two requests without responding to the outbound media requests. + channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123") + + # Respond to the first outbound media request and check that the client + # request is successful + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders( + b"Content-Type", + ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"], + ) + request1.write(self.file_data) + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], b"file_to_stream") + + # Now respond to the second with the same content. + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders( + b"Content-Type", + ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"], + ) + request2.write(self.file_data) + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], b"file_to_stream") + + # We expect only one new file to have been persisted. + self.assertEqual(start_count + 1, self._count_remote_media()) + + def test_download_image_race(self) -> None: + """Test that fetching remote *images* from two different processes at + the same time works. + + This checks that races generating thumbnails are handled correctly. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_thumbnails() + + channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1") + + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders( + b"Content-Type", + ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"], + ) + img_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: image/png\r\nContent-Disposition: inline; filename=test_img\r\n\r\n" + request1.write(img_data) + request1.write(SMALL_PNG) + request1.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n") + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], SMALL_PNG) + + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders( + b"Content-Type", + ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"], + ) + request2.write(img_data) + request2.write(SMALL_PNG) + request2.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n") + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], SMALL_PNG) + + # We expect only three new thumbnails to have been persisted. + self.assertEqual(start_count + 3, self._count_remote_thumbnails()) + + def _count_remote_media(self) -> int: + """Count the number of files in our remote media directory.""" + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_content" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + def _count_remote_thumbnails(self) -> int: + """Count the number of files in our remote thumbnails directory.""" + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_thumbnail" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + def _log_request(request: Request) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info("Completed request %s", request) diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index be4a289ec1..6b5af2dbb6 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -19,31 +19,54 @@ # # import base64 +import io import json import os import re -from typing import Any, Dict, Optional, Sequence, Tuple, Type +from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type +from unittest.mock import MagicMock, Mock, patch +from urllib import parse from urllib.parse import quote, urlencode +from parameterized import parameterized_class + +from twisted.internet import defer from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.defer import Deferred from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IAddress, IResolutionReceiver +from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor +from twisted.web.http_headers import Headers +from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from twisted.web.resource import Resource +from synapse.api.errors import HttpResponseException +from synapse.api.ratelimiting import Ratelimiter from synapse.config.oembed import OEmbedEndpointConfig +from synapse.http.client import MultipartResponse +from synapse.http.types import QueryParams +from synapse.logging.context import make_deferred_yieldable from synapse.media._base import FileInfo from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS from synapse.rest import admin from synapse.rest.client import login, media from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest -from tests.server import FakeTransport, ThreadedMemoryReactorClock +from tests.media.test_media_storage import ( + SVG, + TestImage, + empty_file, + small_lossless_webp, + small_png, + small_png_with_transparency, +) +from tests.server import FakeChannel, FakeTransport, ThreadedMemoryReactorClock from tests.test_utils import SMALL_PNG from tests.unittest import override_config @@ -1607,3 +1630,583 @@ class UnstableMediaConfigTest(unittest.HomeserverTestCase): self.assertEqual( channel.json_body["m.upload.size"], self.hs.config.media.max_upload_size ) + + +class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + config["media_storage_providers"] = [provider_config] + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.repo = hs.get_media_repository() + self.client = hs.get_federation_http_client() + self.store = hs.get_datastores().main + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + # mock actually reading file body + def read_multipart_response_30MiB(*args: Any, **kwargs: Any) -> Deferred: + d: Deferred = defer.Deferred() + d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None)) + return d + + def read_multipart_response_50MiB(*args: Any, **kwargs: Any) -> Deferred: + d: Deferred = defer.Deferred() + d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None)) + return d + + @patch( + "synapse.http.matrixfederationclient.read_multipart_response", + read_multipart_response_30MiB, + ) + def test_download_ratelimit_default(self) -> None: + """ + Test remote media download ratelimiting against default configuration - 500MB bucket + and 87kb/second drain rate + """ + + # mock out actually sending the request, returns a 30MiB response + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = 31457280 + resp.headers = Headers( + {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]} + ) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + # first request should go through + channel = self.make_request( + "GET", + "/_matrix/client/v1/media/download/remote.org/abc", + shorthand=False, + access_token=self.tok, + ) + assert channel.code == 200 + + # next 15 should go through + for i in range(15): + channel2 = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/remote.org/abc{i}", + shorthand=False, + access_token=self.tok, + ) + assert channel2.code == 200 + + # 17th will hit ratelimit + channel3 = self.make_request( + "GET", + "/_matrix/client/v1/media/download/remote.org/abcd", + shorthand=False, + access_token=self.tok, + ) + assert channel3.code == 429 + + # however, a request from a different IP will go through + channel4 = self.make_request( + "GET", + "/_matrix/client/v1/media/download/remote.org/abcde", + shorthand=False, + client_ip="187.233.230.159", + access_token=self.tok, + ) + assert channel4.code == 200 + + # at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another + # 30MiB download is authorized - The last download was blocked at 503,316,480. + # The next download will be authorized when bucket hits 492,830,720 + # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760 + # needs to drain before another download will be authorized, that will take ~= + # 2 minutes (10,485,760/89,088/60) + self.reactor.pump([2.0 * 60.0]) + + # enough has drained and next request goes through + channel5 = self.make_request( + "GET", + "/_matrix/client/v1/media/download/remote.org/abcdef", + shorthand=False, + access_token=self.tok, + ) + assert channel5.code == 200 + + @override_config( + { + "remote_media_download_per_second": "50M", + "remote_media_download_burst_count": "50M", + } + ) + @patch( + "synapse.http.matrixfederationclient.read_multipart_response", + read_multipart_response_50MiB, + ) + def test_download_rate_limit_config(self) -> None: + """ + Test that download rate limit config options are correctly picked up and applied + """ + + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = 52428800 + resp.headers = Headers( + {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]} + ) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + # first request should go through + channel = self.make_request( + "GET", + "/_matrix/client/v1/media/download/remote.org/abc", + shorthand=False, + access_token=self.tok, + ) + assert channel.code == 200 + + # immediate second request should fail + channel = self.make_request( + "GET", + "/_matrix/client/v1/media/download/remote.org/abcd", + shorthand=False, + access_token=self.tok, + ) + assert channel.code == 429 + + # advance half a second + self.reactor.pump([0.5]) + + # request still fails + channel = self.make_request( + "GET", + "/_matrix/client/v1/media/download/remote.org/abcde", + shorthand=False, + access_token=self.tok, + ) + assert channel.code == 429 + + # advance another half second + self.reactor.pump([0.5]) + + # enough has drained from bucket and request is successful + channel = self.make_request( + "GET", + "/_matrix/client/v1/media/download/remote.org/abcdef", + shorthand=False, + access_token=self.tok, + ) + assert channel.code == 200 + + @patch( + "synapse.http.matrixfederationclient.read_multipart_response", + read_multipart_response_30MiB, + ) + def test_download_ratelimit_max_size_sub(self) -> None: + """ + Test that if no content-length is provided, the default max size is applied instead + """ + + # mock out actually sending the request + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = UNKNOWN_LENGTH + resp.headers = Headers( + {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]} + ) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + # ten requests should go through using the max size (500MB/50MB) + for i in range(10): + channel2 = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/remote.org/abc{i}", + shorthand=False, + access_token=self.tok, + ) + assert channel2.code == 200 + + # eleventh will hit ratelimit + channel3 = self.make_request( + "GET", + "/_matrix/client/v1/media/download/remote.org/abcd", + shorthand=False, + access_token=self.tok, + ) + assert channel3.code == 429 + + def test_file_download(self) -> None: + content = io.BytesIO(b"file_to_stream") + content_uri = self.get_success( + self.repo.create_content( + "text/plain", + "test_upload", + content, + 46, + UserID.from_string("@user_id:whatever.org"), + ) + ) + # test with a text file + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/test/{content_uri.media_id}", + shorthand=False, + access_token=self.tok, + ) + self.pump() + self.assertEqual(200, channel.code) + + +test_images = [ + small_png, + small_png_with_transparency, + small_lossless_webp, + empty_file, + SVG, +] +input_values = [(x,) for x in test_images] + + +@parameterized_class(("test_image",), input_values) +class DownloadTestCase(unittest.HomeserverTestCase): + test_image: ClassVar[TestImage] + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.fetches: List[ + Tuple[ + "Deferred[Any]", + str, + str, + Optional[QueryParams], + ] + ] = [] + + def federation_get_file( + destination: str, + path: str, + output_stream: BinaryIO, + download_ratelimiter: Ratelimiter, + ip_address: Any, + max_size: int, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, + ignore_backoff: bool = False, + follow_redirects: bool = False, + ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]], bytes]]": + """A mock for MatrixFederationHttpClient.federation_get_file.""" + + def write_to( + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]] + ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]: + data, response = r + output_stream.write(data) + return response + + def write_err(f: Failure) -> Failure: + f.trap(HttpResponseException) + output_stream.write(f.value.response) + return f + + d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]] = ( + Deferred() + ) + self.fetches.append((d, destination, path, args)) + # Note that this callback changes the value held by d. + d_after_callback = d.addCallbacks(write_to, write_err) + return make_deferred_yieldable(d_after_callback) + + def get_file( + destination: str, + path: str, + output_stream: BinaryIO, + download_ratelimiter: Ratelimiter, + ip_address: Any, + max_size: int, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, + ignore_backoff: bool = False, + follow_redirects: bool = False, + ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": + """A mock for MatrixFederationHttpClient.get_file.""" + + def write_to( + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] + ) -> Tuple[int, Dict[bytes, List[bytes]]]: + data, response = r + output_stream.write(data) + return response + + def write_err(f: Failure) -> Failure: + f.trap(HttpResponseException) + output_stream.write(f.value.response) + return f + + d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() + self.fetches.append((d, destination, path, args)) + # Note that this callback changes the value held by d. + d_after_callback = d.addCallbacks(write_to, write_err) + return make_deferred_yieldable(d_after_callback) + + # Mock out the homeserver's MatrixFederationHttpClient + client = Mock() + client.federation_get_file = federation_get_file + client.get_file = get_file + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + + config = self.default_config() + config["media_store_path"] = self.media_store_path + config["max_image_pixels"] = 2000000 + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + config["media_storage_providers"] = [provider_config] + config["experimental_features"] = {"msc3916_authenticated_media_enabled": True} + + hs = self.setup_test_homeserver(config=config, federation_http_client=client) + + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.media_repo = hs.get_media_repository() + + self.remote = "example.com" + self.media_id = "12345" + + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + def _req( + self, content_disposition: Optional[bytes], include_content_type: bool = True + ) -> FakeChannel: + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}", + shorthand=False, + await_result=False, + access_token=self.tok, + ) + self.pump() + + # We've made one fetch, to example.com, using the federation media URL + self.assertEqual(len(self.fetches), 1) + self.assertEqual(self.fetches[0][1], "example.com") + self.assertEqual( + self.fetches[0][2], "/_matrix/federation/v1/media/download/" + self.media_id + ) + self.assertEqual( + self.fetches[0][3], + {"timeout_ms": "20000"}, + ) + + headers = { + b"Content-Length": [b"%d" % (len(self.test_image.data))], + } + + if include_content_type: + headers[b"Content-Type"] = [self.test_image.content_type] + + if content_disposition: + headers[b"Content-Disposition"] = [content_disposition] + + self.fetches[0][0].callback( + (self.test_image.data, (len(self.test_image.data), headers, b"{}")) + ) + + self.pump() + self.assertEqual(channel.code, 200) + + return channel + + def test_handle_missing_content_type(self) -> None: + channel = self._req( + b"attachment; filename=out" + self.test_image.extension, + include_content_type=False, + ) + headers = channel.headers + self.assertEqual(channel.code, 200) + self.assertEqual( + headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"] + ) + + def test_disposition_filename_ascii(self) -> None: + """ + If the filename is filename= then Synapse will decode it as an + ASCII string, and use filename= in the response. + """ + channel = self._req(b"attachment; filename=out" + self.test_image.extension) + + headers = channel.headers + self.assertEqual( + headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] + ) + self.assertEqual( + headers.getRawHeaders(b"Content-Disposition"), + [ + (b"inline" if self.test_image.is_inline else b"attachment") + + b"; filename=out" + + self.test_image.extension + ], + ) + + def test_disposition_filenamestar_utf8escaped(self) -> None: + """ + If the filename is filename=*utf8'' then Synapse will + correctly decode it as the UTF-8 string, and use filename* in the + response. + """ + filename = parse.quote("\u2603".encode()).encode("ascii") + channel = self._req( + b"attachment; filename*=utf-8''" + filename + self.test_image.extension + ) + + headers = channel.headers + self.assertEqual( + headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] + ) + self.assertEqual( + headers.getRawHeaders(b"Content-Disposition"), + [ + (b"inline" if self.test_image.is_inline else b"attachment") + + b"; filename*=utf-8''" + + filename + + self.test_image.extension + ], + ) + + def test_disposition_none(self) -> None: + """ + If there is no filename, Content-Disposition should only + be a disposition type. + """ + channel = self._req(None) + + headers = channel.headers + self.assertEqual( + headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] + ) + self.assertEqual( + headers.getRawHeaders(b"Content-Disposition"), + [b"inline" if self.test_image.is_inline else b"attachment"], + ) + + def test_x_robots_tag_header(self) -> None: + """ + Tests that the `X-Robots-Tag` header is present, which informs web crawlers + to not index, archive, or follow links in media. + """ + channel = self._req(b"attachment; filename=out" + self.test_image.extension) + + headers = channel.headers + self.assertEqual( + headers.getRawHeaders(b"X-Robots-Tag"), + [b"noindex, nofollow, noarchive, noimageindex"], + ) + + def test_cross_origin_resource_policy_header(self) -> None: + """ + Test that the Cross-Origin-Resource-Policy header is set to "cross-origin" + allowing web clients to embed media from the downloads API. + """ + channel = self._req(b"attachment; filename=out" + self.test_image.extension) + + headers = channel.headers + + self.assertEqual( + headers.getRawHeaders(b"Cross-Origin-Resource-Policy"), + [b"cross-origin"], + ) + + def test_unknown_federation_endpoint(self) -> None: + """ + Test that if the downloadd request to remote federation endpoint returns a 404 + we fall back to the _matrix/media endpoint + """ + channel = self.make_request( + "GET", + f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}", + shorthand=False, + await_result=False, + access_token=self.tok, + ) + self.pump() + + # We've made one fetch, to example.com, using the media URL, and asking + # the other server not to do a remote fetch + self.assertEqual(len(self.fetches), 1) + self.assertEqual(self.fetches[0][1], "example.com") + self.assertEqual( + self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}" + ) + + # The result which says the endpoint is unknown. + unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}' + self.fetches[0][0].errback( + HttpResponseException(404, "NOT FOUND", unknown_endpoint) + ) + + self.pump() + + # There should now be another request to the _matrix/media/v3/download URL. + self.assertEqual(len(self.fetches), 2) + self.assertEqual(self.fetches[1][1], "example.com") + self.assertEqual( + self.fetches[1][2], + f"/_matrix/media/v3/download/example.com/{self.media_id}", + ) + + headers = { + b"Content-Length": [b"%d" % (len(self.test_image.data))], + } + + self.fetches[1][0].callback( + (self.test_image.data, (len(self.test_image.data), headers)) + ) + + self.pump() + self.assertEqual(channel.code, 200) From 1609855ff8322e3d4d91f8aea322f9750ac24ba2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 1 Jul 2024 12:48:36 +0100 Subject: [PATCH 7/7] Limit size of presence EDUs (#17371) Otherwise they are unbounded. --------- Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/17371.misc | 1 + .../sender/per_destination_queue.py | 31 +++-- tests/federation/test_federation_sender.py | 119 ++++++++++++++++++ 3 files changed, 140 insertions(+), 11 deletions(-) create mode 100644 changelog.d/17371.misc diff --git a/changelog.d/17371.misc b/changelog.d/17371.misc new file mode 100644 index 0000000000..0fbf19f4fb --- /dev/null +++ b/changelog.d/17371.misc @@ -0,0 +1 @@ +Limit size of presence EDUs to 50 entries. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index d9f2f017ed..9f1c2fe22a 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -21,6 +21,7 @@ # import datetime import logging +from collections import OrderedDict from types import TracebackType from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type @@ -68,6 +69,10 @@ sent_edus_by_type = Counter( # If the retry interval is larger than this then we enter "catchup" mode CATCHUP_RETRY_INTERVAL = 60 * 60 * 1000 +# Limit how many presence states we add to each presence EDU, to ensure that +# they are bounded in size. +MAX_PRESENCE_STATES_PER_EDU = 50 + class PerDestinationQueue: """ @@ -144,7 +149,7 @@ class PerDestinationQueue: # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination - self._pending_presence: Dict[str, UserPresenceState] = {} + self._pending_presence: OrderedDict[str, UserPresenceState] = OrderedDict() # List of room_id -> receipt_type -> user_id -> receipt_dict, # @@ -399,7 +404,7 @@ class PerDestinationQueue: # through another mechanism, because this is all volatile! self._pending_edus = [] self._pending_edus_keyed = {} - self._pending_presence = {} + self._pending_presence.clear() self._pending_receipt_edus = [] self._start_catching_up() @@ -721,22 +726,26 @@ class _TransactionQueueManager: # Add presence EDU. if self.queue._pending_presence: + # Only send max 50 presence entries in the EDU, to bound the amount + # of data we're sending. + presence_to_add: List[JsonDict] = [] + while ( + self.queue._pending_presence + and len(presence_to_add) < MAX_PRESENCE_STATES_PER_EDU + ): + _, presence = self.queue._pending_presence.popitem(last=False) + presence_to_add.append( + format_user_presence_state(presence, self.queue._clock.time_msec()) + ) + pending_edus.append( Edu( origin=self.queue._server_name, destination=self.queue._destination, edu_type=EduTypes.PRESENCE, - content={ - "push": [ - format_user_presence_state( - presence, self.queue._clock.time_msec() - ) - for presence in self.queue._pending_presence.values() - ] - }, + content={"push": presence_to_add}, ) ) - self.queue._pending_presence = {} # Add read receipt EDUs. pending_edus.extend(self.queue._get_receipt_edus(force_flush=False, limit=5)) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 9073afc70e..6a8887fe74 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -27,6 +27,8 @@ from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms +from synapse.api.presence import UserPresenceState +from synapse.federation.sender.per_destination_queue import MAX_PRESENCE_STATES_PER_EDU from synapse.federation.units import Transaction from synapse.handlers.device import DeviceHandler from synapse.rest import admin @@ -266,6 +268,123 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ) +class FederationSenderPresenceTestCases(HomeserverTestCase): + """ + Test federation sending for presence updates. + """ + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.federation_transport_client = Mock(spec=["send_transaction"]) + self.federation_transport_client.send_transaction = AsyncMock() + hs = self.setup_test_homeserver( + federation_transport_client=self.federation_transport_client, + ) + + return hs + + def default_config(self) -> JsonDict: + config = super().default_config() + config["federation_sender_instances"] = None + return config + + def test_presence_simple(self) -> None: + "Test that sending a single presence update works" + + mock_send_transaction: AsyncMock = ( + self.federation_transport_client.send_transaction + ) + mock_send_transaction.return_value = {} + + sender = self.hs.get_federation_sender() + self.get_success( + sender.send_presence_to_destinations( + [UserPresenceState.default("@user:test")], + ["server"], + ) + ) + + self.pump() + + # expect a call to send_transaction + mock_send_transaction.assert_awaited_once() + + json_cb = mock_send_transaction.call_args[0][1] + data = json_cb() + self.assertEqual( + data["edus"], + [ + { + "edu_type": EduTypes.PRESENCE, + "content": { + "push": [ + { + "presence": "offline", + "user_id": "@user:test", + } + ] + }, + } + ], + ) + + def test_presence_batched(self) -> None: + """Test that sending lots of presence updates to a destination are + batched, rather than having them all sent in one EDU.""" + + mock_send_transaction: AsyncMock = ( + self.federation_transport_client.send_transaction + ) + mock_send_transaction.return_value = {} + + sender = self.hs.get_federation_sender() + + # We now send lots of presence updates to force the federation sender to + # batch the mup. + number_presence_updates_to_send = MAX_PRESENCE_STATES_PER_EDU * 2 + self.get_success( + sender.send_presence_to_destinations( + [ + UserPresenceState.default(f"@user{i}:test") + for i in range(number_presence_updates_to_send) + ], + ["server"], + ) + ) + + self.pump() + + # We should have seen at least one transcation be sent by now. + mock_send_transaction.assert_called() + + # We don't want to specify exactly how the presence EDUs get sent out, + # could be one per transaction or multiple per transaction. We just want + # to assert that a) each presence EDU has bounded number of updates, and + # b) that all updates get sent out. + presence_edus = [] + for transaction_call in mock_send_transaction.call_args_list: + json_cb = transaction_call[0][1] + data = json_cb() + + for edu in data["edus"]: + self.assertEqual(edu.get("edu_type"), EduTypes.PRESENCE) + presence_edus.append(edu) + + # A set of all user presence we see, this should end up matching the + # number we sent out above. + seen_users: Set[str] = set() + + for edu in presence_edus: + presence_states = edu["content"]["push"] + + # This is where we actually check that the number of presence + # updates is bounded. + self.assertLessEqual(len(presence_states), MAX_PRESENCE_STATES_PER_EDU) + + seen_users.update(p["user_id"] for p in presence_states) + + self.assertEqual(len(seen_users), number_presence_updates_to_send) + + class FederationSenderDevicesTestCases(HomeserverTestCase): """ Test federation sending to update devices.