Add support for MSC3823 - Account Suspension (#17051)

This commit is contained in:
Shay 2024-05-01 09:45:17 -07:00 committed by GitHub
parent 0b358f8643
commit 37558d5e4c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 173 additions and 7 deletions

View file

@ -0,0 +1 @@
Add preliminary support for [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823) - Account Suspension.

View file

@ -127,7 +127,7 @@ BOOLEAN_COLUMNS = {
"redactions": ["have_censored"], "redactions": ["have_censored"],
"room_stats_state": ["is_federatable"], "room_stats_state": ["is_federatable"],
"rooms": ["is_public", "has_auth_chain_index"], "rooms": ["is_public", "has_auth_chain_index"],
"users": ["shadow_banned", "approved", "locked"], "users": ["shadow_banned", "approved", "locked", "suspended"],
"un_partial_stated_event_stream": ["rejection_status_changed"], "un_partial_stated_event_stream": ["rejection_status_changed"],
"users_who_share_rooms": ["share_private"], "users_who_share_rooms": ["share_private"],
"per_user_experimental_features": ["enabled"], "per_user_experimental_features": ["enabled"],

View file

@ -752,6 +752,36 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
and requester.user.to_string() == self._server_notices_mxid and requester.user.to_string() == self._server_notices_mxid
) )
requester_suspended = await self.store.get_user_suspended_status(
requester.user.to_string()
)
if action == Membership.INVITE and requester_suspended:
raise SynapseError(
403,
"Sending invites while account is suspended is not allowed.",
Codes.USER_ACCOUNT_SUSPENDED,
)
if target.to_string() != requester.user.to_string():
target_suspended = await self.store.get_user_suspended_status(
target.to_string()
)
else:
target_suspended = requester_suspended
if action == Membership.JOIN and target_suspended:
raise SynapseError(
403,
"Joining rooms while account is suspended is not allowed.",
Codes.USER_ACCOUNT_SUSPENDED,
)
if action == Membership.KNOCK and target_suspended:
raise SynapseError(
403,
"Knocking on rooms while account is suspended is not allowed.",
Codes.USER_ACCOUNT_SUSPENDED,
)
if ( if (
not self.allow_per_room_profiles and not is_requester_server_notices_user not self.allow_per_room_profiles and not is_requester_server_notices_user
) or requester.shadow_banned: ) or requester.shadow_banned:

View file

@ -236,7 +236,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
consent_server_notice_sent, appservice_id, creation_ts, user_type, consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved, COALESCE(approved, TRUE) AS approved,
COALESCE(locked, FALSE) AS locked COALESCE(locked, FALSE) AS locked,
suspended
FROM users FROM users
WHERE name = ? WHERE name = ?
""", """,
@ -261,6 +262,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
shadow_banned, shadow_banned,
approved, approved,
locked, locked,
suspended,
) = row ) = row
return UserInfo( return UserInfo(
@ -277,6 +279,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_type=user_type, user_type=user_type,
approved=bool(approved), approved=bool(approved),
locked=bool(locked), locked=bool(locked),
suspended=bool(suspended),
) )
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -1180,6 +1183,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Convert the potential integer into a boolean. # Convert the potential integer into a boolean.
return bool(res) return bool(res)
@cached()
async def get_user_suspended_status(self, user_id: str) -> bool:
"""
Determine whether the user's account is suspended.
Args:
user_id: The user ID of the user in question
Returns:
True if the user's account is suspended, false if it is not suspended or
if the user ID cannot be found.
"""
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="suspended",
allow_none=True,
desc="get_user_suspended",
)
return bool(res)
async def get_threepid_validation_session( async def get_threepid_validation_session(
self, self,
medium: Optional[str], medium: Optional[str],
@ -2213,6 +2237,35 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,))
async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None:
"""
Set whether the user's account is suspended in the `users` table.
Args:
user_id: The user ID of the user in question
suspended: True if the user is suspended, false if not
"""
await self.db_pool.runInteraction(
"set_user_suspended_status",
self.set_user_suspended_status_txn,
user_id,
suspended,
)
def set_user_suspended_status_txn(
self, txn: LoggingTransaction, user_id: str, suspended: bool
) -> None:
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"suspended": suspended},
)
self._invalidate_cache_and_stream(
txn, self.get_user_suspended_status, (user_id,)
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
async def set_user_locked_status(self, user_id: str, locked: bool) -> None: async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
"""Set the `locked` property for the provided user to the provided value. """Set the `locked` property for the provided user to the provided value.

View file

@ -19,7 +19,7 @@
# #
# #
SCHEMA_VERSION = 84 # remember to update the list below when updating SCHEMA_VERSION = 85 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema """Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the This should be incremented whenever the codebase changes its requirements on the
@ -136,6 +136,9 @@ Changes in SCHEMA_VERSION = 83
Changes in SCHEMA_VERSION = 84 Changes in SCHEMA_VERSION = 84
- No longer assumes that `event_auth_chain_links` holds transitive links, and - No longer assumes that `event_auth_chain_links` holds transitive links, and
so read operations must do graph traversal. so read operations must do graph traversal.
Changes in SCHEMA_VERSION = 85
- Add a column `suspended` to the `users` table
""" """

View file

@ -0,0 +1,14 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
ALTER TABLE users ADD COLUMN suspended BOOLEAN DEFAULT FALSE NOT NULL;

View file

@ -1156,6 +1156,7 @@ class UserInfo:
user_type: User type (None for normal user, 'support' and 'bot' other options). user_type: User type (None for normal user, 'support' and 'bot' other options).
approved: If the user has been "approved" to register on the server. approved: If the user has been "approved" to register on the server.
locked: Whether the user's account has been locked locked: Whether the user's account has been locked
suspended: Whether the user's account is currently suspended
""" """
user_id: UserID user_id: UserID
@ -1171,6 +1172,7 @@ class UserInfo:
is_shadow_banned: bool is_shadow_banned: bool
approved: bool approved: bool
locked: bool locked: bool
suspended: bool
class UserProfile(TypedDict): class UserProfile(TypedDict):

View file

@ -48,7 +48,16 @@ from synapse.appservice import ApplicationService
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, register, room, sync from synapse.rest.client import (
account,
directory,
knock,
login,
profile,
register,
room,
sync,
)
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
@ -733,7 +742,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None assert channel.resource_usage is not None
self.assertEqual(32, channel.resource_usage.db_txn_count) self.assertEqual(33, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None: def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id # POST with initial_state config key, expect new room id
@ -746,7 +755,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None assert channel.resource_usage is not None
self.assertEqual(34, channel.resource_usage.db_txn_count) self.assertEqual(35, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None: def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id # POST with visibility config key, expect new room id
@ -1154,6 +1163,7 @@ class RoomJoinTestCase(RoomBase):
admin.register_servlets, admin.register_servlets,
login.register_servlets, login.register_servlets,
room.register_servlets, room.register_servlets,
knock.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@ -1167,6 +1177,8 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.store = hs.get_datastores().main
def test_spam_checker_may_join_room_deprecated(self) -> None: def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called """Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed. and blocks room joins when needed.
@ -1317,6 +1329,57 @@ class RoomJoinTestCase(RoomBase):
expect_additional_fields=return_value[1], expect_additional_fields=return_value[1],
) )
def test_suspended_user_cannot_join_room(self) -> None:
# set the user as suspended
self.get_success(self.store.set_user_suspended_status(self.user2, True))
channel = self.make_request(
"POST", f"/join/{self.room1}", access_token=self.tok2
)
self.assertEqual(channel.code, 403)
self.assertEqual(
channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
)
channel = self.make_request(
"POST", f"/rooms/{self.room1}/join", access_token=self.tok2
)
self.assertEqual(channel.code, 403)
self.assertEqual(
channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
)
def test_suspended_user_cannot_knock_on_room(self) -> None:
# set the user as suspended
self.get_success(self.store.set_user_suspended_status(self.user2, True))
channel = self.make_request(
"POST",
f"/_matrix/client/v3/knock/{self.room1}",
access_token=self.tok2,
content={},
shorthand=False,
)
self.assertEqual(channel.code, 403)
self.assertEqual(
channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
)
def test_suspended_user_cannot_invite_to_room(self) -> None:
# set the user as suspended
self.get_success(self.store.set_user_suspended_status(self.user1, True))
# first user invites second user
channel = self.make_request(
"POST",
f"/rooms/{self.room1}/invite",
access_token=self.tok1,
content={"user_id": self.user2},
)
self.assertEqual(
channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
)
class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [

View file

@ -43,7 +43,6 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.assertEqual( self.assertEqual(
UserInfo( UserInfo(
# TODO(paul): Surely this field should be 'user_id', not 'name'
user_id=UserID.from_string(self.user_id), user_id=UserID.from_string(self.user_id),
is_admin=False, is_admin=False,
is_guest=False, is_guest=False,
@ -57,6 +56,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
locked=False, locked=False,
is_shadow_banned=False, is_shadow_banned=False,
approved=True, approved=True,
suspended=False,
), ),
(self.get_success(self.store.get_user_by_id(self.user_id))), (self.get_success(self.store.get_user_by_id(self.user_id))),
) )