From c2000ab35b76288a625f598d2382d4e3f29f65f6 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Wed, 4 Aug 2021 13:40:25 +0300 Subject: [PATCH] Add `get_userinfo_by_id` method to `ModuleApi` (#9581) Makes it easier to fetch user details in for example spam checker modules, without needing to use api._store or figure out database interactions. Signed-off-by: Jason Robinson --- changelog.d/9581.feature | 1 + synapse/module_api/__init__.py | 12 +++++++- .../storage/databases/main/registration.py | 30 ++++++++++++++++++- synapse/types.py | 29 ++++++++++++++++++ tests/module_api/test_api.py | 10 +++++++ 5 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 changelog.d/9581.feature diff --git a/changelog.d/9581.feature b/changelog.d/9581.feature new file mode 100644 index 0000000000..fa1949cd4b --- /dev/null +++ b/changelog.d/9581.feature @@ -0,0 +1 @@ +Add `get_userinfo_by_id` method to ModuleApi. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 473812b8e2..1cc13fc97b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester from synapse.util import Clock from synapse.util.caches.descriptors import cached @@ -174,6 +174,16 @@ class ModuleApi: """The application name configured in the homeserver's configuration.""" return self._hs.config.email.email_app_name + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get user info by user_id + + Args: + user_id: Fully qualified user id. + Returns: + UserInfo object if a user was found, otherwise None + """ + return await self._store.get_userinfo_by_id(user_id) + async def get_user_by_req( self, req: SynapseRequest, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6ad1a0cf7f..14670c2881 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID +from synapse.types import UserID, UserInfo from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + """Deprecated: use get_userinfo_by_id instead""" return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, @@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_id", ) + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get a UserInfo object for a user by user ID. + + Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, + this method should be cached. + + Args: + user_id: The user to fetch user info for. + Returns: + `UserInfo` object if user found, otherwise `None`. + """ + user_data = await self.get_user_by_id(user_id) + if not user_data: + return None + return UserInfo( + appservice_id=user_data["appservice_id"], + consent_server_notice_sent=user_data["consent_server_notice_sent"], + consent_version=user_data["consent_version"], + creation_ts=user_data["creation_ts"], + is_admin=bool(user_data["admin"]), + is_deactivated=bool(user_data["deactivated"]), + is_guest=bool(user_data["is_guest"]), + is_shadow_banned=bool(user_data["shadow_banned"]), + user_id=UserID.from_string(user_data["name"]), + user_type=user_data["user_type"], + ) + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config diff --git a/synapse/types.py b/synapse/types.py index 429bb013d2..80fa903c4b 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -751,3 +751,32 @@ def get_verify_key_from_cross_signing_key(key_info): # and return that one key for key_id, key_data in keys.items(): return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class UserInfo: + """Holds information about a user. Result of get_userinfo_by_id. + + Attributes: + user_id: ID of the user. + appservice_id: Application service ID that created this user. + consent_server_notice_sent: Version of policy documents the user has been sent. + consent_version: Version of policy documents the user has consented to. + creation_ts: Creation timestamp of the user. + is_admin: True if the user is an admin. + is_deactivated: True if the user has been deactivated. + is_guest: True if the user is a guest user. + is_shadow_banned: True if the user has been shadow-banned. + user_type: User type (None for normal user, 'support' and 'bot' other options). + """ + + user_id: UserID + appservice_id: Optional[int] + consent_server_notice_sent: Optional[str] + consent_version: Optional[str] + user_type: Optional[str] + creation_ts: int + is_admin: bool + is_deactivated: bool + is_guest: bool + is_shadow_banned: bool diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 81d9e2f484..0b817cc701 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase): displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + def test_get_userinfo_by_id(self): + user_id = self.register_user("alice", "1234") + found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + self.assertEqual(found_user.user_id.to_string(), user_id) + self.assertIdentical(found_user.is_admin, False) + + def test_get_userinfo_by_id__no_user_found(self): + found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) + self.assertIsNone(found_user) + def test_sending_events_into_room(self): """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent