diff --git a/Cargo.lock b/Cargo.lock index 7472e16291..1955c1a4e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -212,9 +212,9 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" diff --git a/changelog.d/17255.feature b/changelog.d/17255.feature new file mode 100644 index 0000000000..4093de1146 --- /dev/null +++ b/changelog.d/17255.feature @@ -0,0 +1 @@ +Add support for [MSC823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823) - Account suspension. \ No newline at end of file diff --git a/changelog.d/17333.misc b/changelog.d/17333.misc new file mode 100644 index 0000000000..d3ef0b3777 --- /dev/null +++ b/changelog.d/17333.misc @@ -0,0 +1 @@ +Handle device lists notifications for large accounts more efficiently in worker mode. diff --git a/changelog.d/17336.bugfix b/changelog.d/17336.bugfix new file mode 100644 index 0000000000..618834302e --- /dev/null +++ b/changelog.d/17336.bugfix @@ -0,0 +1 @@ +Fix email notification subject when invited to a space. diff --git a/changelog.d/17338.misc b/changelog.d/17338.misc new file mode 100644 index 0000000000..1a81bdef85 --- /dev/null +++ b/changelog.d/17338.misc @@ -0,0 +1 @@ +Do not block event sending/receiving while calculating large event auth chains. diff --git a/changelog.d/17339.misc b/changelog.d/17339.misc new file mode 100644 index 0000000000..1d7cb96c8b --- /dev/null +++ b/changelog.d/17339.misc @@ -0,0 +1 @@ +Tidy up `parse_integer` docs and call sites to reflect the fact that they require non-negative integers by default, and bring `parse_integer_from_args` default in alignment. Contributed by Denis Kasak (@dkasak). diff --git a/poetry.lock b/poetry.lock index 58981ff6e1..1bae0ea388 100644 --- a/poetry.lock +++ b/poetry.lock @@ -35,13 +35,13 @@ tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "p [[package]] name = "authlib" -version = "1.3.0" +version = "1.3.1" description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients." optional = true python-versions = ">=3.8" files = [ - {file = "Authlib-1.3.0-py2.py3-none-any.whl", hash = "sha256:9637e4de1fb498310a56900b3e2043a206b03cb11c05422014b0302cbc814be3"}, - {file = "Authlib-1.3.0.tar.gz", hash = "sha256:959ea62a5b7b5123c5059758296122b57cd2585ae2ed1c0622c21b371ffdae06"}, + {file = "Authlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:d35800b973099bbadc49b42b256ecb80041ad56b7fe1216a362c7943c088f377"}, + {file = "authlib-1.3.1.tar.gz", hash = "sha256:7ae843f03c06c5c0debd63c9db91f9fda64fa62a42a77419fa15fbb7e7a58917"}, ] [package.dependencies] @@ -1461,13 +1461,13 @@ test = ["lxml", "pytest (>=4.6)", "pytest-cov"] [[package]] name = "netaddr" -version = "1.2.1" +version = "1.3.0" description = "A network address manipulation library for Python" optional = false python-versions = ">=3.7" files = [ - {file = "netaddr-1.2.1-py3-none-any.whl", hash = "sha256:bd9e9534b0d46af328cf64f0e5a23a5a43fca292df221c85580b27394793496e"}, - {file = "netaddr-1.2.1.tar.gz", hash = "sha256:6eb8fedf0412c6d294d06885c110de945cf4d22d2b510d0404f4e06950857987"}, + {file = "netaddr-1.3.0-py3-none-any.whl", hash = "sha256:c2c6a8ebe5554ce33b7d5b3a306b71bbb373e000bbbf2350dd5213cc56e3dbbe"}, + {file = "netaddr-1.3.0.tar.gz", hash = "sha256:5c3c3d9895b551b763779ba7db7a03487dc1f8e3b385af819af341ae9ef6e48a"}, ] [package.extras] @@ -1488,13 +1488,13 @@ tests = ["Sphinx", "doubles", "flake8", "flake8-quotes", "gevent", "mock", "pyte [[package]] name = "packaging" -version = "24.0" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, - {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -2157,13 +2157,13 @@ rpds-py = ">=0.7.0" [[package]] name = "requests" -version = "2.31.0" +version = "2.32.2" description = "Python HTTP for Humans." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, + {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"}, + {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"}, ] [package.dependencies] @@ -2387,13 +2387,13 @@ doc = ["Sphinx", "sphinx-rtd-theme"] [[package]] name = "sentry-sdk" -version = "2.3.1" +version = "2.6.0" description = "Python client for Sentry (https://sentry.io)" optional = true python-versions = ">=3.6" files = [ - {file = "sentry_sdk-2.3.1-py2.py3-none-any.whl", hash = "sha256:c5aeb095ba226391d337dd42a6f9470d86c9fc236ecc71cfc7cd1942b45010c6"}, - {file = "sentry_sdk-2.3.1.tar.gz", hash = "sha256:139a71a19f5e9eb5d3623942491ce03cf8ebc14ea2e39ba3e6fe79560d8a5b1f"}, + {file = "sentry_sdk-2.6.0-py2.py3-none-any.whl", hash = "sha256:422b91cb49378b97e7e8d0e8d5a1069df23689d45262b86f54988a7db264e874"}, + {file = "sentry_sdk-2.6.0.tar.gz", hash = "sha256:65cc07e9c6995c5e316109f138570b32da3bd7ff8d0d0ee4aaf2628c3dd8127d"}, ] [package.dependencies] @@ -2598,22 +2598,22 @@ files = [ [[package]] name = "tornado" -version = "6.4" +version = "6.4.1" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." optional = true -python-versions = ">= 3.8" +python-versions = ">=3.8" files = [ - {file = "tornado-6.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:02ccefc7d8211e5a7f9e8bc3f9e5b0ad6262ba2fbb683a6443ecc804e5224ce0"}, - {file = "tornado-6.4-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:27787de946a9cffd63ce5814c33f734c627a87072ec7eed71f7fc4417bb16263"}, - {file = "tornado-6.4-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7894c581ecdcf91666a0912f18ce5e757213999e183ebfc2c3fdbf4d5bd764e"}, - {file = "tornado-6.4-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e43bc2e5370a6a8e413e1e1cd0c91bedc5bd62a74a532371042a18ef19e10579"}, - {file = "tornado-6.4-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0251554cdd50b4b44362f73ad5ba7126fc5b2c2895cc62b14a1c2d7ea32f212"}, - {file = "tornado-6.4-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fd03192e287fbd0899dd8f81c6fb9cbbc69194d2074b38f384cb6fa72b80e9c2"}, - {file = "tornado-6.4-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:88b84956273fbd73420e6d4b8d5ccbe913c65d31351b4c004ae362eba06e1f78"}, - {file = "tornado-6.4-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:71ddfc23a0e03ef2df1c1397d859868d158c8276a0603b96cf86892bff58149f"}, - {file = "tornado-6.4-cp38-abi3-win32.whl", hash = "sha256:6f8a6c77900f5ae93d8b4ae1196472d0ccc2775cc1dfdc9e7727889145c45052"}, - {file = "tornado-6.4-cp38-abi3-win_amd64.whl", hash = "sha256:10aeaa8006333433da48dec9fe417877f8bcc21f48dda8d661ae79da357b2a63"}, - {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"}, + {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:163b0aafc8e23d8cdc3c9dfb24c5368af84a81e3364745ccb4427669bf84aec8"}, + {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6d5ce3437e18a2b66fbadb183c1d3364fb03f2be71299e7d10dbeeb69f4b2a14"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e20b9113cd7293f164dc46fffb13535266e713cdb87bd2d15ddb336e96cfc4"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ae50a504a740365267b2a8d1a90c9fbc86b780a39170feca9bcc1787ff80842"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:613bf4ddf5c7a95509218b149b555621497a6cc0d46ac341b30bd9ec19eac7f3"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:25486eb223babe3eed4b8aecbac33b37e3dd6d776bc730ca14e1bf93888b979f"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:454db8a7ecfcf2ff6042dde58404164d969b6f5d58b926da15e6b23817950fc4"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a02a08cc7a9314b006f653ce40483b9b3c12cda222d6a46d4ac63bb6c9057698"}, + {file = "tornado-6.4.1-cp38-abi3-win32.whl", hash = "sha256:d9a566c40b89757c9aa8e6f032bcdb8ca8795d7c1a9762910c722b1635c9de4d"}, + {file = "tornado-6.4.1-cp38-abi3-win_amd64.whl", hash = "sha256:b24b8982ed444378d7f21d563f4180a2de31ced9d8d84443907a0a64da2072e7"}, + {file = "tornado-6.4.1.tar.gz", hash = "sha256:92d3ab53183d8c50f8204a51e6f91d18a15d5ef261e84d452800d4ff6fc504e9"}, ] [[package]] @@ -2917,13 +2917,13 @@ files = [ [[package]] name = "typing-extensions" -version = "4.11.0" +version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"}, - {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] @@ -2939,18 +2939,18 @@ files = [ [[package]] name = "urllib3" -version = "2.0.7" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"}, - {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 23e96da6a3..1b72727b75 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -433,6 +433,10 @@ class ExperimentalConfig(Config): ("experimental", "msc4108_delegation_endpoint"), ) + self.msc3823_account_suspension = experimental.get( + "msc3823_account_suspension", False + ) + self.msc3916_authenticated_media_enabled = experimental.get( "msc3916_authenticated_media_enabled", False ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 16d01efc67..5aa48230ec 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -642,6 +642,17 @@ class EventCreationHandler: """ await self.auth_blocking.check_auth_blocking(requester=requester) + if event_dict["type"] == EventTypes.Message: + requester_suspended = await self.store.get_user_suspended_status( + requester.user.to_string() + ) + if requester_suspended: + raise SynapseError( + 403, + "Sending messages while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version_id = event_dict["content"]["room_version"] maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index ab12951da8..08b8ff7afd 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -119,14 +119,15 @@ def parse_integer( default: value to use if the parameter is absent, defaults to None. required: whether to raise a 400 SynapseError if the parameter is absent, defaults to False. - negative: whether to allow negative integers, defaults to True. + negative: whether to allow negative integers, defaults to False (disallowing + negatives). Returns: An int value or the default. Raises: SynapseError: if the parameter is absent and required, if the parameter is present and not an integer, or if the - parameter is illegitimate negative. + parameter is illegitimately negative. """ args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore return parse_integer_from_args(args, name, default, required, negative) @@ -164,7 +165,7 @@ def parse_integer_from_args( name: str, default: Optional[int] = None, required: bool = False, - negative: bool = True, + negative: bool = False, ) -> Optional[int]: """Parse an integer parameter from the request string @@ -174,7 +175,8 @@ def parse_integer_from_args( default: value to use if the parameter is absent, defaults to None. required: whether to raise a 400 SynapseError if the parameter is absent, defaults to False. - negative: whether to allow negative integers, defaults to True. + negative: whether to allow negative integers, defaults to False (disallowing + negatives). Returns: An int value or the default. @@ -182,7 +184,7 @@ def parse_integer_from_args( Raises: SynapseError: if the parameter is absent and required, if the parameter is present and not an integer, or if the - parameter is illegitimate negative. + parameter is illegitimately negative. """ name_bytes = name.encode("ascii") diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 77cc69a71f..cf611bd90b 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -28,7 +28,7 @@ import jinja2 from markupsafe import Markup from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership, RoomTypes +from synapse.api.constants import EventContentFields, EventTypes, Membership, RoomTypes from synapse.api.errors import StoreError from synapse.config.emailconfig import EmailSubjectConfig from synapse.events import EventBase @@ -716,7 +716,8 @@ class Mailer: ) if ( create_event - and create_event.content.get("room_type") == RoomTypes.SPACE + and create_event.content.get(EventContentFields.ROOM_TYPE) + == RoomTypes.SPACE ): return self.email_subjects.invite_from_person_to_space % { "person": inviter_name, diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 2d6d49eed7..3dddbb70b4 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -114,13 +114,19 @@ class ReplicationDataHandler: """ all_room_ids: Set[str] = set() if stream_name == DeviceListsStream.NAME: - if any(row.entity.startswith("@") and not row.is_signature for row in rows): + if any(not row.is_signature and not row.hosts_calculated for row in rows): prev_token = self.store.get_device_stream_token() all_room_ids = await self.store.get_all_device_list_changes( prev_token, token ) self.store.device_lists_in_rooms_have_changed(all_room_ids, token) + # If we're sending federation we need to update the device lists + # outbound pokes stream change cache with updated hosts. + if self.send_handler and any(row.hosts_calculated for row in rows): + hosts = await self.store.get_destinations_for_device(token) + self.store.device_lists_outbound_pokes_have_changed(hosts, token) + self.store.process_replication_rows(stream_name, instance_name, token, rows) # NOTE: this must be called after process_replication_rows to ensure any # cache invalidations are first handled before any stream ID advances. @@ -433,12 +439,11 @@ class FederationSenderHandler: # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes. - hosts = { - row.entity - for row in rows - if not row.entity.startswith("@") and not row.is_signature - } - await self.federation_sender.send_device_messages(hosts, immediate=False) + if any(row.hosts_calculated for row in rows): + hosts = await self.store.get_destinations_for_device(token) + await self.federation_sender.send_device_messages( + hosts, immediate=False + ) elif stream_name == ToDeviceStream.NAME: # The to_device stream includes stuff to be pushed to both local diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 661206c841..d021904de7 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -549,10 +549,14 @@ class DeviceListsStream(_StreamFromIdGen): @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: - entity: str + user_id: str # Indicates that a user has signed their own device with their user-signing key is_signature: bool + # Indicates if this is a notification that we've calculated the hosts we + # need to send the update to. + hosts_calculated: bool + NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow @@ -594,13 +598,13 @@ class DeviceListsStream(_StreamFromIdGen): upper_limit_token = min(upper_limit_token, signatures_to_token) device_updates = [ - (stream_id, (entity, False)) - for stream_id, (entity,) in device_updates + (stream_id, (entity, False, hosts)) + for stream_id, (entity, hosts) in device_updates if stream_id <= upper_limit_token ] signatures_updates = [ - (stream_id, (entity, True)) + (stream_id, (entity, True, False)) for stream_id, (entity,) in signatures_updates if stream_id <= upper_limit_token ] diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6da1d79168..cdaee17451 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -101,6 +101,7 @@ from synapse.rest.admin.users import ( ResetPasswordRestServlet, SearchUsersRestServlet, ShadowBanRestServlet, + SuspendAccountRestServlet, UserAdminServlet, UserByExternalId, UserByThreePid, @@ -327,6 +328,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: BackgroundUpdateRestServlet(hs).register(http_server) BackgroundUpdateStartJobRestServlet(hs).register(http_server) ExperimentalFeaturesRestServlet(hs).register(http_server) + if hs.config.experimental.msc3823_account_suspension: + SuspendAccountRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 14ab4644cb..d85a04b825 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -61,8 +61,8 @@ class ListDestinationsRestServlet(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) - start = parse_integer(request, "from", default=0, negative=False) - limit = parse_integer(request, "limit", default=100, negative=False) + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) destination = parse_string(request, "destination") @@ -181,8 +181,8 @@ class DestinationMembershipRestServlet(RestServlet): if not await self._store.is_destination_known(destination): raise NotFoundError("Unknown destination") - start = parse_integer(request, "from", default=0, negative=False) - limit = parse_integer(request, "limit", default=100, negative=False) + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index a05b7252ec..ee6a681285 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -311,8 +311,8 @@ class DeleteMediaByDateSize(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - before_ts = parse_integer(request, "before_ts", required=True, negative=False) - size_gt = parse_integer(request, "size_gt", default=0, negative=False) + before_ts = parse_integer(request, "before_ts", required=True) + size_gt = parse_integer(request, "size_gt", default=0) keep_profiles = parse_boolean(request, "keep_profiles", default=True) if before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds @@ -377,8 +377,8 @@ class UserMediaRestServlet(RestServlet): if user is None: raise NotFoundError("Unknown user") - start = parse_integer(request, "from", default=0, negative=False) - limit = parse_integer(request, "limit", default=100, negative=False) + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) # If neither `order_by` nor `dir` is set, set the default order # to newest media is on top for backward compatibility. @@ -421,8 +421,8 @@ class UserMediaRestServlet(RestServlet): if user is None: raise NotFoundError("Unknown user") - start = parse_integer(request, "from", default=0, negative=False) - limit = parse_integer(request, "limit", default=100, negative=False) + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) # If neither `order_by` nor `dir` is set, set the default order # to newest media is on top for backward compatibility. diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index dc27a41dd9..0adc5b7005 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -63,10 +63,10 @@ class UserMediaStatisticsRestServlet(RestServlet): ), ) - start = parse_integer(request, "from", default=0, negative=False) - limit = parse_integer(request, "limit", default=100, negative=False) - from_ts = parse_integer(request, "from_ts", default=0, negative=False) - until_ts = parse_integer(request, "until_ts", negative=False) + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + from_ts = parse_integer(request, "from_ts", default=0) + until_ts = parse_integer(request, "until_ts") if until_ts is not None: if until_ts <= from_ts: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 5bf12c4979..ad515bd5a3 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -27,11 +27,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr +from synapse._pydantic_compat import HAS_PYDANTIC_V2 from synapse.api.constants import Direction, UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_and_validate_json_object_from_request, parse_boolean, parse_enum, parse_integer, @@ -49,10 +51,17 @@ from synapse.rest.client._base import client_patterns from synapse.storage.databases.main.registration import ExternalIDReuseException from synapse.storage.databases.main.stats import UserSortOrder from synapse.types import JsonDict, JsonMapping, UserID +from synapse.types.rest import RequestBodyModel if TYPE_CHECKING: from synapse.server import HomeServer +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import StrictBool +else: + from pydantic import StrictBool + + logger = logging.getLogger(__name__) @@ -90,8 +99,8 @@ class UsersRestServletV2(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - start = parse_integer(request, "from", default=0, negative=False) - limit = parse_integer(request, "limit", default=100, negative=False) + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) user_id = parse_string(request, "user_id") name = parse_string(request, "name", encoding="utf-8") @@ -732,6 +741,36 @@ class DeactivateAccountRestServlet(RestServlet): return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result} +class SuspendAccountRestServlet(RestServlet): + PATTERNS = admin_patterns("/suspend/(?P[^/]*)$") + + def __init__(self, hs: "HomeServer"): + self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.store = hs.get_datastores().main + + class PutBody(RequestBodyModel): + suspend: StrictBool + + async def on_PUT( + self, request: SynapseRequest, target_user_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester) + + if not self.is_mine(UserID.from_string(target_user_id)): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only suspend local users") + + if not await self.store.get_user_by_id(target_user_id): + raise NotFoundError("User not found") + + body = parse_and_validate_json_object_from_request(request, self.PutBody) + suspend = body.suspend + await self.store.set_user_suspended_status(target_user_id, suspend) + + return HTTPStatus.OK, {f"user_{target_user_id}_suspended": suspend} + + class AccountValidityRenewServlet(RestServlet): PATTERNS = admin_patterns("/account_validity/validity$") diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index 0323f6afa1..c1a80c5c3d 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -108,6 +108,19 @@ class ProfileDisplaynameRestServlet(RestServlet): propagate = _read_propagate(self.hs, request) + requester_suspended = ( + await self.hs.get_datastores().main.get_user_suspended_status( + requester.user.to_string() + ) + ) + + if requester_suspended: + raise SynapseError( + 403, + "Updating displayname while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + await self.profile_handler.set_displayname( user, requester, new_name, is_admin, propagate=propagate ) @@ -167,6 +180,19 @@ class ProfileAvatarURLRestServlet(RestServlet): propagate = _read_propagate(self.hs, request) + requester_suspended = ( + await self.hs.get_datastores().main.get_user_suspended_status( + requester.user.to_string() + ) + ) + + if requester_suspended: + raise SynapseError( + 403, + "Updating avatar URL while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + await self.profile_handler.set_avatar_url( user, requester, new_avatar_url, is_admin, propagate=propagate ) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index c98241f6ce..903c74f6d8 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -510,7 +510,7 @@ class PublicRoomListRestServlet(RestServlet): if server: raise e - limit: Optional[int] = parse_integer(request, "limit", 0, negative=False) + limit: Optional[int] = parse_integer(request, "limit", 0) since_token = parse_string(request, "since") if limit == 0: @@ -1120,6 +1120,20 @@ class RoomRedactEventRestServlet(TransactionRestServlet): ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) + requester_suspended = await self._store.get_user_suspended_status( + requester.user.to_string() + ) + + if requester_suspended: + event = await self._store.get_event(event_id, allow_none=True) + if event: + if event.sender != requester.user.to_string(): + raise SynapseError( + 403, + "You can only redact your own events while account is suspended.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + # Ensure the redacts property in the content matches the one provided in # the URL. room_version = await self._store.get_room_version(room_id) @@ -1430,16 +1444,7 @@ class RoomHierarchyRestServlet(RestServlet): requester = await self._auth.get_user_by_req(request, allow_guest=True) max_depth = parse_integer(request, "max_depth") - if max_depth is not None and max_depth < 0: - raise SynapseError( - 400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON - ) - limit = parse_integer(request, "limit") - if limit is not None and limit <= 0: - raise SynapseError( - 400, "'limit' must be a positive integer", Codes.BAD_JSON - ) return 200, await self._room_summary_handler.get_room_hierarchy( requester, diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 84699a2ee1..d0e015bf19 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -617,6 +617,17 @@ class EventsPersistenceStorageController: room_id, chunk ) + with Measure(self._clock, "calculate_chain_cover_index_for_events"): + # We now calculate chain ID/sequence numbers for any state events we're + # persisting. We ignore out of band memberships as we're not in the room + # and won't have their auth chain (we'll fix it up later if we join the + # room). + # + # See: docs/auth_chain_difference_algorithm.md + new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events( + room_id, [e for e, _ in chunk] + ) + await self.persist_events_store._persist_events_and_state_updates( room_id, chunk, @@ -624,6 +635,7 @@ class EventsPersistenceStorageController: new_forward_extremities=new_forward_extremities, use_negative_stream_ordering=backfilled, inhibit_local_membership_updates=backfilled, + new_event_links=new_event_links, ) return replaced_events diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 40187496e2..5eeca6165d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -164,22 +164,24 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): prefilled_cache=user_signature_stream_prefill, ) - ( - device_list_federation_prefill, - device_list_federation_list_id, - ) = self.db_pool.get_cache_dict( - db_conn, - "device_lists_outbound_pokes", - entity_column="destination", - stream_column="stream_id", - max_value=device_list_max, - limit=10000, - ) - self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", - device_list_federation_list_id, - prefilled_cache=device_list_federation_prefill, - ) + self._device_list_federation_stream_cache = None + if hs.should_send_federation(): + ( + device_list_federation_prefill, + device_list_federation_list_id, + ) = self.db_pool.get_cache_dict( + db_conn, + "device_lists_outbound_pokes", + entity_column="destination", + stream_column="stream_id", + max_value=device_list_max, + limit=10000, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", + device_list_federation_list_id, + prefilled_cache=device_list_federation_prefill, + ) if hs.config.worker.run_background_tasks: self._clock.looping_call( @@ -207,23 +209,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) -> None: for row in rows: if row.is_signature: - self._user_signature_stream_cache.entity_has_changed(row.entity, token) + self._user_signature_stream_cache.entity_has_changed(row.user_id, token) continue # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about # changes. - if row.entity.startswith("@"): - self._device_list_stream_cache.entity_has_changed(row.entity, token) - self.get_cached_devices_for_user.invalidate((row.entity,)) - self._get_cached_user_device.invalidate((row.entity,)) - self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) - - else: - self._device_list_federation_stream_cache.entity_has_changed( - row.entity, token + if not row.hosts_calculated: + self._device_list_stream_cache.entity_has_changed(row.user_id, token) + self.get_cached_devices_for_user.invalidate((row.user_id,)) + self._get_cached_user_device.invalidate((row.user_id,)) + self.get_device_list_last_stream_id_for_remote.invalidate( + (row.user_id,) ) + def device_lists_outbound_pokes_have_changed( + self, destinations: StrCollection, token: int + ) -> None: + assert self._device_list_federation_stream_cache is not None + + for destination in destinations: + self._device_list_federation_stream_cache.entity_has_changed( + destination, token + ) + def device_lists_in_rooms_have_changed( self, room_ids: StrCollection, token: int ) -> None: @@ -363,6 +372,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): EDU contents. """ now_stream_id = self.get_device_stream_token() + if from_stream_id == now_stream_id: + return now_stream_id, [] + + if self._device_list_federation_stream_cache is None: + raise Exception("Func can only be used on federation senders") has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) @@ -1018,10 +1032,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): # This query Does The Right Thing where it'll correctly apply the # bounds to the inner queries. sql = """ - SELECT stream_id, entity FROM ( - SELECT stream_id, user_id AS entity FROM device_lists_stream + SELECT stream_id, user_id, hosts FROM ( + SELECT stream_id, user_id, false AS hosts FROM device_lists_stream UNION ALL - SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes + SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes ) AS e WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC @@ -1577,6 +1591,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): get_device_list_changes_in_room_txn, ) + async def get_destinations_for_device(self, stream_id: int) -> StrCollection: + return await self.db_pool.simple_select_onecol( + table="device_lists_outbound_pokes", + keyvalues={"stream_id": stream_id}, + retcol="destination", + desc="get_destinations_for_device", + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__( @@ -2112,12 +2134,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): stream_ids: List[int], context: Optional[Dict[str, str]], ) -> None: - for host in hosts: - txn.call_after( - self._device_list_federation_stream_cache.entity_has_changed, - host, - stream_ids[-1], - ) + if self._device_list_federation_stream_cache: + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_ids[-1], + ) now = self._clock.time_msec() stream_id_iterator = iter(stream_ids) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 38d8785faa..9e6c9561ae 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -123,9 +123,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if stream_name == DeviceListsStream.NAME: for row in rows: assert isinstance(row, DeviceListsStream.DeviceListsStreamRow) - if row.entity.startswith("@"): + if not row.hosts_calculated: self._get_e2e_device_keys_for_federation_query_inner.invalidate( - (row.entity,) + (row.user_id,) ) super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index fb132ef090..24abab4a23 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -148,6 +148,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas 500000, "_event_auth_cache", size_callback=len ) + # Flag used by unit tests to disable fallback when there is no chain cover + # index. + self.tests_allow_no_chain_cover_index = True + self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) if isinstance(self.database_engine, PostgresEngine): @@ -220,8 +224,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) except _NoChainCoverIndex: # For whatever reason we don't actually have a chain cover index - # for the events in question, so we fall back to the old method. - pass + # for the events in question, so we fall back to the old method + # (except in tests) + if not self.tests_allow_no_chain_cover_index: + raise return await self.db_pool.runInteraction( "get_auth_chain_ids", @@ -271,7 +277,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas if events_missing_chain_info: # This can happen due to e.g. downgrade/upgrade of the server. We # raise an exception and fall back to the previous algorithm. - logger.info( + logger.error( "Unexpectedly found that events don't have chain IDs in room %s: %s", room_id, events_missing_chain_info, @@ -482,8 +488,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) except _NoChainCoverIndex: # For whatever reason we don't actually have a chain cover index - # for the events in question, so we fall back to the old method. - pass + # for the events in question, so we fall back to the old method + # (except in tests) + if not self.tests_allow_no_chain_cover_index: + raise return await self.db_pool.runInteraction( "get_auth_chain_difference", @@ -710,7 +718,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas if events_missing_chain_info - event_to_auth_ids.keys(): # Uh oh, we somehow haven't correctly done the chain cover index, # bail and fall back to the old method. - logger.info( + logger.error( "Unexpectedly found that events don't have chain IDs in room %s: %s", room_id, events_missing_chain_info - event_to_auth_ids.keys(), diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 66428e6c8e..1f7acdb859 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -34,7 +34,6 @@ from typing import ( Optional, Set, Tuple, - Union, cast, ) @@ -100,6 +99,23 @@ class DeltaState: return not self.to_delete and not self.to_insert and not self.no_longer_in_room +@attr.s(slots=True, auto_attribs=True) +class NewEventChainLinks: + """Information about new auth chain links that need to be added to the DB. + + Attributes: + chain_id, sequence_number: the IDs corresponding to the event being + inserted, and the starting point of the links + links: Lists the links that need to be added, 2-tuple of the chain + ID/sequence number of the end point of the link. + """ + + chain_id: int + sequence_number: int + + links: List[Tuple[int, int]] = attr.Factory(list) + + class PersistEventsStore: """Contains all the functions for writing events to the database. @@ -148,6 +164,7 @@ class PersistEventsStore: *, state_delta_for_room: Optional[DeltaState], new_forward_extremities: Optional[Set[str]], + new_event_links: Dict[str, NewEventChainLinks], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, ) -> None: @@ -217,6 +234,7 @@ class PersistEventsStore: inhibit_local_membership_updates=inhibit_local_membership_updates, state_delta_for_room=state_delta_for_room, new_forward_extremities=new_forward_extremities, + new_event_links=new_event_links, ) persist_event_counter.inc(len(events_and_contexts)) @@ -243,6 +261,87 @@ class PersistEventsStore: (room_id,), frozenset(new_forward_extremities) ) + async def calculate_chain_cover_index_for_events( + self, room_id: str, events: Collection[EventBase] + ) -> Dict[str, NewEventChainLinks]: + # Filter to state events, and ensure there are no duplicates. + state_events = [] + seen_events = set() + for event in events: + if not event.is_state() or event.event_id in seen_events: + continue + + state_events.append(event) + seen_events.add(event.event_id) + + if not state_events: + return {} + + return await self.db_pool.runInteraction( + "_calculate_chain_cover_index_for_events", + self.calculate_chain_cover_index_for_events_txn, + room_id, + state_events, + ) + + def calculate_chain_cover_index_for_events_txn( + self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase] + ) -> Dict[str, NewEventChainLinks]: + # We now calculate chain ID/sequence numbers for any state events we're + # persisting. We ignore out of band memberships as we're not in the room + # and won't have their auth chain (we'll fix it up later if we join the + # room). + # + # See: docs/auth_chain_difference_algorithm.md + + # We ignore legacy rooms that we aren't filling the chain cover index + # for. + row = self.db_pool.simple_select_one_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("room_id", "has_auth_chain_index"), + allow_none=True, + ) + if row is None or row[1] is False: + return {} + + # Filter out events that we've already calculated. + rows = self.db_pool.simple_select_many_txn( + txn, + table="event_auth_chains", + column="event_id", + iterable=[e.event_id for e in state_events], + keyvalues={}, + retcols=("event_id",), + ) + already_persisted_events = {event_id for event_id, in rows} + state_events = [ + event + for event in state_events + if event.event_id not in already_persisted_events + ] + + if not state_events: + return {} + + # We need to know the type/state_key and auth events of the events we're + # calculating chain IDs for. We don't rely on having the full Event + # instances as we'll potentially be pulling more events from the DB and + # we don't need the overhead of fetching/parsing the full event JSON. + event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events} + event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events} + event_to_room_id = {e.event_id: e.room_id for e in state_events} + + return self._calculate_chain_cover_index( + txn, + self.db_pool, + self.store.event_chain_id_gen, + event_to_room_id, + event_to_types, + event_to_auth_chain, + ) + async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: """Filter the supplied list of event_ids to get those which are prev_events of existing (non-outlier/rejected) events. @@ -358,6 +457,7 @@ class PersistEventsStore: inhibit_local_membership_updates: bool, state_delta_for_room: Optional[DeltaState], new_forward_extremities: Optional[Set[str]], + new_event_links: Dict[str, NewEventChainLinks], ) -> None: """Insert some number of room events into the necessary database tables. @@ -466,7 +566,9 @@ class PersistEventsStore: # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) - self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts]) + self._persist_event_auth_chain_txn( + txn, [e for e, _ in events_and_contexts], new_event_links + ) # _store_rejected_events_txn filters out any events which were # rejected, and returns the filtered list. @@ -496,7 +598,11 @@ class PersistEventsStore: self, txn: LoggingTransaction, events: List[EventBase], + new_event_links: Dict[str, NewEventChainLinks], ) -> None: + if new_event_links: + self._persist_chain_cover_index(txn, self.db_pool, new_event_links) + # We only care about state events, so this if there are no state events. if not any(e.is_state() for e in events): return @@ -519,60 +625,6 @@ class PersistEventsStore: ], ) - # We now calculate chain ID/sequence numbers for any state events we're - # persisting. We ignore out of band memberships as we're not in the room - # and won't have their auth chain (we'll fix it up later if we join the - # room). - # - # See: docs/auth_chain_difference_algorithm.md - - # We ignore legacy rooms that we aren't filling the chain cover index - # for. - rows = cast( - List[Tuple[str, Optional[Union[int, bool]]]], - self.db_pool.simple_select_many_txn( - txn, - table="rooms", - column="room_id", - iterable={event.room_id for event in events if event.is_state()}, - keyvalues={}, - retcols=("room_id", "has_auth_chain_index"), - ), - ) - rooms_using_chain_index = { - room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index - } - - state_events = { - event.event_id: event - for event in events - if event.is_state() and event.room_id in rooms_using_chain_index - } - - if not state_events: - return - - # We need to know the type/state_key and auth events of the events we're - # calculating chain IDs for. We don't rely on having the full Event - # instances as we'll potentially be pulling more events from the DB and - # we don't need the overhead of fetching/parsing the full event JSON. - event_to_types = { - e.event_id: (e.type, e.state_key) for e in state_events.values() - } - event_to_auth_chain = { - e.event_id: e.auth_event_ids() for e in state_events.values() - } - event_to_room_id = {e.event_id: e.room_id for e in state_events.values()} - - self._add_chain_cover_index( - txn, - self.db_pool, - self.store.event_chain_id_gen, - event_to_room_id, - event_to_types, - event_to_auth_chain, - ) - @classmethod def _add_chain_cover_index( cls, @@ -583,6 +635,35 @@ class PersistEventsStore: event_to_types: Dict[str, Tuple[str, str]], event_to_auth_chain: Dict[str, StrCollection], ) -> None: + """Calculate and persist the chain cover index for the given events. + + Args: + event_to_room_id: Event ID to the room ID of the event + event_to_types: Event ID to type and state_key of the event + event_to_auth_chain: Event ID to list of auth event IDs of the + event (events with no auth events can be excluded). + """ + + new_event_links = cls._calculate_chain_cover_index( + txn, + db_pool, + event_chain_id_gen, + event_to_room_id, + event_to_types, + event_to_auth_chain, + ) + cls._persist_chain_cover_index(txn, db_pool, new_event_links) + + @classmethod + def _calculate_chain_cover_index( + cls, + txn: LoggingTransaction, + db_pool: DatabasePool, + event_chain_id_gen: SequenceGenerator, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, StrCollection], + ) -> Dict[str, NewEventChainLinks]: """Calculate the chain cover index for the given events. Args: @@ -590,6 +671,10 @@ class PersistEventsStore: event_to_types: Event ID to type and state_key of the event event_to_auth_chain: Event ID to list of auth event IDs of the event (events with no auth events can be excluded). + + Returns: + A mapping with any new auth chain links we need to add, keyed by + event ID. """ # Map from event ID to chain ID/sequence number. @@ -708,11 +793,11 @@ class PersistEventsStore: room_id = event_to_room_id.get(event_id) if room_id: e_type, state_key = event_to_types[event_id] - db_pool.simple_insert_txn( + db_pool.simple_upsert_txn( txn, table="event_auth_chain_to_calculate", + keyvalues={"event_id": event_id}, values={ - "event_id": event_id, "room_id": room_id, "type": e_type, "state_key": state_key, @@ -724,7 +809,7 @@ class PersistEventsStore: break if not events_to_calc_chain_id_for: - return + return {} # Allocate chain ID/sequence numbers to each new event. new_chain_tuples = cls._allocate_chain_ids( @@ -739,23 +824,10 @@ class PersistEventsStore: ) chain_map.update(new_chain_tuples) - db_pool.simple_insert_many_txn( - txn, - table="event_auth_chains", - keys=("event_id", "chain_id", "sequence_number"), - values=[ - (event_id, c_id, seq) - for event_id, (c_id, seq) in new_chain_tuples.items() - ], - ) - - db_pool.simple_delete_many_txn( - txn, - table="event_auth_chain_to_calculate", - keyvalues={}, - column="event_id", - values=new_chain_tuples, - ) + to_return = { + event_id: NewEventChainLinks(chain_id, sequence_number) + for event_id, (chain_id, sequence_number) in new_chain_tuples.items() + } # Now we need to calculate any new links between chains caused by # the new events. @@ -825,10 +897,38 @@ class PersistEventsStore: auth_chain_id, auth_sequence_number = chain_map[auth_id] # Step 2a, add link between the event and auth event + to_return[event_id].links.append((auth_chain_id, auth_sequence_number)) chain_links.add_link( (chain_id, sequence_number), (auth_chain_id, auth_sequence_number) ) + return to_return + + @classmethod + def _persist_chain_cover_index( + cls, + txn: LoggingTransaction, + db_pool: DatabasePool, + new_event_links: Dict[str, NewEventChainLinks], + ) -> None: + db_pool.simple_insert_many_txn( + txn, + table="event_auth_chains", + keys=("event_id", "chain_id", "sequence_number"), + values=[ + (event_id, new_links.chain_id, new_links.sequence_number) + for event_id, new_links in new_event_links.items() + ], + ) + + db_pool.simple_delete_many_txn( + txn, + table="event_auth_chain_to_calculate", + keyvalues={}, + column="event_id", + values=new_event_links, + ) + db_pool.simple_insert_many_txn( txn, table="event_auth_chain_links", @@ -838,7 +938,16 @@ class PersistEventsStore: "target_chain_id", "target_sequence_number", ), - values=list(chain_links.get_additions()), + values=[ + ( + new_links.chain_id, + new_links.sequence_number, + target_chain_id, + target_sequence_number, + ) + for new_links in new_event_links.values() + for (target_chain_id, target_sequence_number) in new_links.links + ], ) @staticmethod diff --git a/synapse/streams/config.py b/synapse/streams/config.py index eeafe889de..9fee5bfb92 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -75,9 +75,6 @@ class PaginationConfig: raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) - if limit < 0: - raise SynapseError(400, "Limit must be 0 or above") - limit = min(limit, MAX_LIMIT) try: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index c5da1e9686..16bb4349f5 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -37,6 +37,7 @@ from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.media.filepath import MediaFilePaths +from synapse.rest import admin from synapse.rest.client import ( devices, login, @@ -5005,3 +5006,86 @@ class AllowCrossSigningReplacementTestCase(unittest.HomeserverTestCase): ) assert timestamp is not None self.assertGreater(timestamp, self.clock.time_msec()) + + +class UserSuspensionTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin = self.register_user("thomas", "hackme", True) + self.admin_tok = self.login("thomas", "hackme") + + self.bad_user = self.register_user("teresa", "hackme") + self.bad_user_tok = self.login("teresa", "hackme") + + self.store = hs.get_datastores().main + + @override_config({"experimental_features": {"msc3823_account_suspension": True}}) + def test_suspend_user(self) -> None: + # test that suspending user works + channel = self.make_request( + "PUT", + f"/_synapse/admin/v1/suspend/{self.bad_user}", + {"suspend": True}, + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body, {f"user_{self.bad_user}_suspended": True}) + + res = self.get_success(self.store.get_user_suspended_status(self.bad_user)) + self.assertEqual(True, res) + + # test that un-suspending user works + channel2 = self.make_request( + "PUT", + f"/_synapse/admin/v1/suspend/{self.bad_user}", + {"suspend": False}, + access_token=self.admin_tok, + ) + self.assertEqual(channel2.code, 200) + self.assertEqual(channel2.json_body, {f"user_{self.bad_user}_suspended": False}) + + res2 = self.get_success(self.store.get_user_suspended_status(self.bad_user)) + self.assertEqual(False, res2) + + # test that trying to un-suspend user who isn't suspended doesn't cause problems + channel3 = self.make_request( + "PUT", + f"/_synapse/admin/v1/suspend/{self.bad_user}", + {"suspend": False}, + access_token=self.admin_tok, + ) + self.assertEqual(channel3.code, 200) + self.assertEqual(channel3.json_body, {f"user_{self.bad_user}_suspended": False}) + + res3 = self.get_success(self.store.get_user_suspended_status(self.bad_user)) + self.assertEqual(False, res3) + + # test that trying to suspend user who is already suspended doesn't cause problems + channel4 = self.make_request( + "PUT", + f"/_synapse/admin/v1/suspend/{self.bad_user}", + {"suspend": True}, + access_token=self.admin_tok, + ) + self.assertEqual(channel4.code, 200) + self.assertEqual(channel4.json_body, {f"user_{self.bad_user}_suspended": True}) + + res4 = self.get_success(self.store.get_user_suspended_status(self.bad_user)) + self.assertEqual(True, res4) + + channel5 = self.make_request( + "PUT", + f"/_synapse/admin/v1/suspend/{self.bad_user}", + {"suspend": True}, + access_token=self.admin_tok, + ) + self.assertEqual(channel5.code, 200) + self.assertEqual(channel5.json_body, {f"user_{self.bad_user}_suspended": True}) + + res5 = self.get_success(self.store.get_user_suspended_status(self.bad_user)) + self.assertEqual(True, res5) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index d398cead1c..c559dfda83 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -3819,3 +3819,108 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): # Make sure the outlier event is not returned self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id) + + +class UserSuspensionTests(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + profile.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user1 = self.register_user("thomas", "hackme") + self.tok1 = self.login("thomas", "hackme") + + self.user2 = self.register_user("teresa", "hackme") + self.tok2 = self.login("teresa", "hackme") + + self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + self.store = hs.get_datastores().main + + def test_suspended_user_cannot_send_message_to_room(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user1, True)) + + channel = self.make_request( + "PUT", + f"/rooms/{self.room1}/send/m.room.message/1", + access_token=self.tok1, + content={"body": "hello", "msgtype": "m.text"}, + ) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + def test_suspended_user_cannot_change_profile_data(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user1, True)) + + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user1}/avatar_url", + access_token=self.tok1, + content={"avatar_url": "mxc://matrix.org/wefh34uihSDRGhw34"}, + shorthand=False, + ) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + channel2 = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.user1}/displayname", + access_token=self.tok1, + content={"displayname": "something offensive"}, + shorthand=False, + ) + self.assertEqual( + channel2.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + def test_suspended_user_cannot_redact_messages_other_than_their_own(self) -> None: + # first user sends message + self.make_request("POST", f"/rooms/{self.room1}/join", access_token=self.tok2) + res = self.helper.send_event( + self.room1, + "m.room.message", + {"body": "hello", "msgtype": "m.text"}, + tok=self.tok2, + ) + event_id = res["event_id"] + + # second user sends message + self.make_request("POST", f"/rooms/{self.room1}/join", access_token=self.tok1) + res2 = self.helper.send_event( + self.room1, + "m.room.message", + {"body": "bad_message", "msgtype": "m.text"}, + tok=self.tok1, + ) + event_id2 = res2["event_id"] + + # set the second user as suspended + self.get_success(self.store.set_user_suspended_status(self.user1, True)) + + # second user can't redact first user's message + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/rooms/{self.room1}/redact/{event_id}/1", + access_token=self.tok1, + content={"reason": "bogus"}, + shorthand=False, + ) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + # but can redact their own + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/rooms/{self.room1}/redact/{event_id2}/1", + access_token=self.tok1, + content={"reason": "bogus"}, + shorthand=False, + ) + self.assertEqual(channel.code, 200) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 7f975d04ff..ba01b038ab 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -36,6 +36,14 @@ class DeviceStoreTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main + def default_config(self) -> JsonDict: + config = super().default_config() + + # We 'enable' federation otherwise `get_device_updates_by_remote` will + # throw an exception. + config["federation_sender_instances"] = ["master"] + return config + def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None: """Add a device list change for the given device to `device_lists_outbound_pokes` table. diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 81feb3ec29..c4e216c308 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase): ) # Actually call the function that calculates the auth chain stuff. - persist_events_store._persist_event_auth_chain_txn(txn, events) + new_event_links = ( + persist_events_store.calculate_chain_cover_index_for_events_txn( + txn, events[0].room_id, [e for e in events if e.is_state()] + ) + ) + persist_events_store._persist_event_auth_chain_txn( + txn, events, new_event_links + ) self.get_success( persist_events_store.db_pool.runInteraction( diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 0a6253e22c..088f0d24f9 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): }, ) + events = [ + cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id])) + for event_id in AUTH_GRAPH + ] + new_event_links = ( + self.persist_events.calculate_chain_cover_index_for_events_txn( + txn, room_id, [e for e in events if e.is_state()] + ) + ) self.persist_events._persist_event_auth_chain_txn( txn, - [ - cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id])) - for event_id in AUTH_GRAPH - ], + events, + new_event_links, ) self.get_success( @@ -544,6 +551,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): rooms. """ + # We allow partial covers for this test + self.hs.get_datastores().main.tests_allow_no_chain_cover_index = True + room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -628,13 +638,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) # Insert all events apart from 'B' + events = [ + cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) + for event_id in auth_graph + if event_id != "b" + ] + new_event_links = ( + self.persist_events.calculate_chain_cover_index_for_events_txn( + txn, room_id, [e for e in events if e.is_state()] + ) + ) self.persist_events._persist_event_auth_chain_txn( txn, - [ - cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) - for event_id in auth_graph - if event_id != "b" - ], + events, + new_event_links, ) # Now we insert the event 'B' without a chain cover, by temporarily @@ -647,9 +664,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): updatevalues={"has_auth_chain_index": False}, ) + events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))] + new_event_links = ( + self.persist_events.calculate_chain_cover_index_for_events_txn( + txn, room_id, [e for e in events if e.is_state()] + ) + ) self.persist_events._persist_event_auth_chain_txn( - txn, - [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))], + txn, events, new_event_links ) self.store.db_pool.simple_update_txn( diff --git a/tests/unittest.py b/tests/unittest.py index 18963b9e32..a7c20556a0 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -344,6 +344,8 @@ class HomeserverTestCase(TestCase): self._hs_args = {"clock": self.clock, "reactor": self.reactor} self.hs = self.make_homeserver(self.reactor, self.clock) + self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False + # Honour the `use_frozen_dicts` config option. We have to do this # manually because this is taken care of in the app `start` code, which # we don't run. Plus we want to reset it on tearDown.