Refactor get_user_by_id (#16316)

This commit is contained in:
Erik Johnston 2023-09-14 12:46:30 +01:00 committed by GitHub
parent 032cf84f52
commit 954921736b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 108 additions and 123 deletions

1
changelog.d/16316.misc Normal file
View file

@ -0,0 +1 @@
Refactor `get_user_by_id`.

View file

@ -268,7 +268,7 @@ class InternalAuth(BaseAuth):
stored_user = await self.store.get_user_by_id(user_id) stored_user = await self.store.get_user_by_id(user_id)
if not stored_user: if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id) raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]: if not stored_user.is_guest:
raise InvalidClientTokenError( raise InvalidClientTokenError(
"Guest access token used for regular user" "Guest access token used for regular user"
) )

View file

@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth):
user_id = UserID(username, self._hostname) user_id = UserID(username, self._hostname)
# First try to find a user from the username claim # First try to find a user from the username claim
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string()) user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
if user_info is None: if user_info is None:
# If the user does not exist, we should create it on the fly # If the user does not exist, we should create it on the fly
# TODO: we could use SCIM to provision users ahead of time and listen # TODO: we could use SCIM to provision users ahead of time and listen

View file

@ -102,7 +102,7 @@ class AccountHandler:
""" """
status = {"exists": False} status = {"exists": False}
userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string()) userinfo = await self._main_store.get_user_by_id(user_id.to_string())
if userinfo is not None: if userinfo is not None:
status = { status = {

View file

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
from synapse.api.constants import Direction, Membership from synapse.api.constants import Direction, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
if TYPE_CHECKING: if TYPE_CHECKING:
@ -57,38 +57,30 @@ class AdminHandler:
async def get_user(self, user: UserID) -> Optional[JsonDict]: async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details""" """Function to get user details"""
user_info_dict = await self._store.get_user_by_id(user.to_string()) user_info: Optional[UserInfo] = await self._store.get_user_by_id(
if user_info_dict is None: user.to_string()
)
if user_info is None:
return None return None
# Restrict returned information to a known set of fields. This prevents additional user_info_dict = {
# fields added to get_user_by_id from modifying Synapse's external API surface. "name": user.to_string(),
user_info_to_return = { "admin": user_info.is_admin,
"name", "deactivated": user_info.is_deactivated,
"admin", "locked": user_info.locked,
"deactivated", "shadow_banned": user_info.is_shadow_banned,
"locked", "creation_ts": user_info.creation_ts,
"shadow_banned", "appservice_id": user_info.appservice_id,
"creation_ts", "consent_server_notice_sent": user_info.consent_server_notice_sent,
"appservice_id", "consent_version": user_info.consent_version,
"consent_server_notice_sent", "consent_ts": user_info.consent_ts,
"consent_version", "user_type": user_info.user_type,
"consent_ts", "is_guest": user_info.is_guest,
"user_type",
"is_guest",
"last_seen_ts",
} }
if self._msc3866_enabled: if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled. # Only include the approved flag if support for MSC3866 is enabled.
user_info_to_return.add("approved") user_info_dict["approved"] = user_info.approved
# Restrict returned keys to a known set.
user_info_dict = {
key: value
for key, value in user_info_dict.items()
if key in user_info_to_return
}
# Add additional user metadata # Add additional user metadata
profile = await self._store.get_profileinfo(user) profile = await self._store.get_profileinfo(user)
@ -105,6 +97,9 @@ class AdminHandler:
user_info_dict["external_ids"] = external_ids user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
user_info_dict["last_seen_ts"] = last_seen_ts
return user_info_dict return user_info_dict
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any: async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:

View file

@ -828,13 +828,13 @@ class EventCreationHandler:
u = await self.store.get_user_by_id(user_id) u = await self.store.get_user_by_id(user_id)
assert u is not None assert u is not None
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT): if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent # support and bot users are not required to consent
return return
if u["appservice_id"] is not None: if u.appservice_id is not None:
# users registered by an appservice are exempt # users registered by an appservice are exempt
return return
if u["consent_version"] == self.config.consent.user_consent_version: if u.consent_version == self.config.consent.user_consent_version:
return return
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart) consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)

View file

@ -572,7 +572,7 @@ class ModuleApi:
Returns: Returns:
UserInfo object if a user was found, otherwise None UserInfo object if a user was found, otherwise None
""" """
return await self._store.get_userinfo_by_id(user_id) return await self._store.get_user_by_id(user_id)
async def get_user_by_req( async def get_user_by_req(
self, self,
@ -1878,7 +1878,7 @@ class AccountDataManager:
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}") raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")
# Ensure the user exists, so we don't just write to users that aren't there. # Ensure the user exists, so we don't just write to users that aren't there.
if await self._store.get_userinfo_by_id(user_id) is None: if await self._store.get_user_by_id(user_id) is None:
raise ValueError(f"User {user_id} does not exist on this server.") raise ValueError(f"User {user_id} does not exist on this server.")
await self._handler.add_account_data_for_user(user_id, data_type, new_data) await self._handler.add_account_data_for_user(user_id, data_type, new_data)

View file

@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource):
if u is None: if u is None:
raise NotFoundError("Unknown user") raise NotFoundError("Unknown user")
has_consented = u["consent_version"] == version has_consented = u.consent_version == version
userhmac = userhmac_bytes.decode("ascii") userhmac = userhmac_bytes.decode("ascii")
try: try:

View file

@ -79,15 +79,15 @@ class ConsentServerNotices:
if u is None: if u is None:
return return
if u["is_guest"] and not self._send_to_guests: if u.is_guest and not self._send_to_guests:
# don't send to guests # don't send to guests
return return
if u["consent_version"] == self._current_consent_version: if u.consent_version == self._current_consent_version:
# user has already consented # user has already consented
return return
if u["consent_server_notice_sent"] == self._current_consent_version: if u.consent_server_notice_sent == self._current_consent_version:
# we've already sent a notice to the user # we've already sent a notice to the user
return return

View file

@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
} }
return list(results.values()) return list(results.values())
async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
"""Get the last seen timestamp for a user, if we have it."""
return await self.db_pool.simple_select_one_onecol(
table="user_ips",
keyvalues={"user_id": user_id},
retcol="MAX(last_seen)",
allow_none=True,
desc="get_last_seen_for_user_id",
)

View file

@ -16,7 +16,7 @@
import logging import logging
import random import random
import re import re
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr import attr
@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) )
@cached() @cached()
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]: async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Deprecated: use get_userinfo_by_id instead""" """Returns info about the user account, if it exists."""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# We could technically use simple_select_one here, but it would not perform # We could technically use simple_select_one here, but it would not perform
@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
txn.execute( txn.execute(
""" """
SELECT SELECT
name, password_hash, is_guest, admin, consent_version, consent_ts, name, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type, consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved, COALESCE(approved, TRUE) AS approved,
COALESCE(locked, FALSE) AS locked, last_seen_ts COALESCE(locked, FALSE) AS locked
FROM users FROM users
LEFT JOIN (
SELECT user_id, MAX(last_seen) AS last_seen_ts
FROM user_ips GROUP BY user_id
) ls ON users.name = ls.user_id
WHERE name = ? WHERE name = ?
""", """,
(user_id,), (user_id,),
@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_id", desc="get_user_by_id",
func=get_user_by_id_txn, func=get_user_by_id_txn,
) )
if row is None:
if row is not None:
# If we're using SQLite our boolean values will be integers. Because we
# present some of this data as is to e.g. server admins via REST APIs, we
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
boolean_columns = [
"admin",
"deactivated",
"shadow_banned",
"approved",
"locked",
]
for column in boolean_columns:
row[column] = bool(row[column])
return row
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID.
Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
this method should be cached.
Args:
user_id: The user to fetch user info for.
Returns:
`UserInfo` object if user found, otherwise `None`.
"""
user_data = await self.get_user_by_id(user_id)
if not user_data:
return None return None
return UserInfo( return UserInfo(
appservice_id=user_data["appservice_id"], appservice_id=row["appservice_id"],
consent_server_notice_sent=user_data["consent_server_notice_sent"], consent_server_notice_sent=row["consent_server_notice_sent"],
consent_version=user_data["consent_version"], consent_version=row["consent_version"],
creation_ts=user_data["creation_ts"], consent_ts=row["consent_ts"],
is_admin=bool(user_data["admin"]), creation_ts=row["creation_ts"],
is_deactivated=bool(user_data["deactivated"]), is_admin=bool(row["admin"]),
is_guest=bool(user_data["is_guest"]), is_deactivated=bool(row["deactivated"]),
is_shadow_banned=bool(user_data["shadow_banned"]), is_guest=bool(row["is_guest"]),
user_id=UserID.from_string(user_data["name"]), is_shadow_banned=bool(row["shadow_banned"]),
user_type=user_data["user_type"], user_id=UserID.from_string(row["name"]),
last_seen_ts=user_data["last_seen_ts"], user_type=row["user_type"],
approved=bool(row["approved"]),
locked=bool(row["locked"]),
) )
async def is_trial_user(self, user_id: str) -> bool: async def is_trial_user(self, user_id: str) -> bool:
@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
now = self._clock.time_msec() now = self._clock.time_msec()
days = self.config.server.mau_appservice_trial_days.get( days = self.config.server.mau_appservice_trial_days.get(
info["appservice_id"], self.config.server.mau_trial_days info.appservice_id, self.config.server.mau_trial_days
) )
trial_duration_ms = days * 24 * 60 * 60 * 1000 trial_duration_ms = days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
return is_trial return is_trial
@cached() @cached()

View file

@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(
@attr.s(auto_attribs=True, frozen=True, slots=True) @attr.s(auto_attribs=True, frozen=True, slots=True)
class UserInfo: class UserInfo:
"""Holds information about a user. Result of get_userinfo_by_id. """Holds information about a user. Result of get_user_by_id.
Attributes: Attributes:
user_id: ID of the user. user_id: ID of the user.
appservice_id: Application service ID that created this user. appservice_id: Application service ID that created this user.
consent_server_notice_sent: Version of policy documents the user has been sent. consent_server_notice_sent: Version of policy documents the user has been sent.
consent_version: Version of policy documents the user has consented to. consent_version: Version of policy documents the user has consented to.
consent_ts: Time the user consented
creation_ts: Creation timestamp of the user. creation_ts: Creation timestamp of the user.
is_admin: True if the user is an admin. is_admin: True if the user is an admin.
is_deactivated: True if the user has been deactivated. is_deactivated: True if the user has been deactivated.
is_guest: True if the user is a guest user. is_guest: True if the user is a guest user.
is_shadow_banned: True if the user has been shadow-banned. is_shadow_banned: True if the user has been shadow-banned.
user_type: User type (None for normal user, 'support' and 'bot' other options). user_type: User type (None for normal user, 'support' and 'bot' other options).
last_seen_ts: Last activity timestamp of the user. approved: If the user has been "approved" to register on the server.
locked: Whether the user's account has been locked
""" """
user_id: UserID user_id: UserID
appservice_id: Optional[int] appservice_id: Optional[int]
consent_server_notice_sent: Optional[str] consent_server_notice_sent: Optional[str]
consent_version: Optional[str] consent_version: Optional[str]
consent_ts: Optional[int]
user_type: Optional[str] user_type: Optional[str]
creation_ts: int creation_ts: int
is_admin: bool is_admin: bool
is_deactivated: bool is_deactivated: bool
is_guest: bool is_guest: bool
is_shadow_banned: bool is_shadow_banned: bool
last_seen_ts: Optional[int] approved: bool
locked: bool
class UserProfile(TypedDict): class UserProfile(TypedDict):

View file

@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
app_service.is_interested_in_user = Mock(return_value=True) app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) class FakeUserInfo:
is_guest = False
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
def test_get_guest_user_from_macaroon(self) -> None: def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True}) class FakeUserInfo:
is_guest = True
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"

View file

@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError from synapse.api.errors import ThreepidValidationError
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID, UserInfo
from synapse.util import Clock from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEqual( self.assertEqual(
{ UserInfo(
# TODO(paul): Surely this field should be 'user_id', not 'name' # TODO(paul): Surely this field should be 'user_id', not 'name'
"name": self.user_id, user_id=UserID.from_string(self.user_id),
"password_hash": self.pwhash, is_admin=False,
"admin": 0, is_guest=False,
"is_guest": 0, consent_server_notice_sent=None,
"consent_version": None, consent_ts=None,
"consent_ts": None, consent_version=None,
"consent_server_notice_sent": None, appservice_id=None,
"appservice_id": None, creation_ts=0,
"creation_ts": 0, user_type=None,
"user_type": None, is_deactivated=False,
"deactivated": 0, locked=False,
"locked": 0, is_shadow_banned=False,
"shadow_banned": 0, approved=True,
"approved": 1, ),
"last_seen_ts": None,
},
(self.get_success(self.store.get_user_by_id(self.user_id))), (self.get_success(self.store.get_user_by_id(self.user_id))),
) )
@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id)) user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user assert user
self.assertEqual(user["consent_version"], "1") self.assertEqual(user.consent_version, "1")
self.assertGreater(user["consent_ts"], before_consent) self.assertIsNotNone(user.consent_ts)
self.assertLess(user["consent_ts"], self.clock.time_msec()) assert user.consent_ts is not None
self.assertGreater(user.consent_ts, before_consent)
self.assertLess(user.consent_ts, self.clock.time_msec())
def test_add_tokens(self) -> None: def test_add_tokens(self) -> None:
self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.get_success(self.store.register_user(self.user_id, self.pwhash))
@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id)) user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None assert user is not None
self.assertTrue(user["approved"]) self.assertTrue(user.approved)
approved = self.get_success(self.store.is_user_approved(self.user_id)) approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved) self.assertTrue(approved)
@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id)) user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None assert user is not None
self.assertFalse(user["approved"]) self.assertFalse(user.approved)
approved = self.get_success(self.store.is_user_approved(self.user_id)) approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertFalse(approved) self.assertFalse(approved)
@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id)) user = self.get_success(self.store.get_user_by_id(self.user_id))
self.assertIsNotNone(user) self.assertIsNotNone(user)
assert user is not None assert user is not None
self.assertEqual(user["approved"], 1) self.assertEqual(user.approved, 1)
approved = self.get_success(self.store.is_user_approved(self.user_id)) approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved) self.assertTrue(approved)