De-localpart ProfileWorkerStore.get_profile_displayname()

Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
Sean Quah 2023-04-15 02:02:01 +01:00
parent b375e2abd9
commit e6c582095f
7 changed files with 36 additions and 25 deletions

View file

@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
from synapse.util.async_helpers import delay_cancellation
@ -163,9 +162,7 @@ class AccountValidityHandler:
return
try:
user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
user_display_name = await self.store.get_profile_displayname(user_id)
if user_display_name is None:
user_display_name = user_id
except StoreError:

View file

@ -100,7 +100,7 @@ class ProfileHandler:
if self.hs.is_mine(target_user):
try:
displayname = await self.store.get_profile_displayname(
target_user.localpart
target_user.to_string()
)
except StoreError as e:
if e.code == 404:
@ -364,7 +364,8 @@ class ProfileHandler:
Codes.FORBIDDEN,
)
user = UserID.from_string(args["user_id"])
user_id = args["user_id"]
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
@ -374,7 +375,7 @@ class ProfileHandler:
try:
if just_field is None or just_field == "displayname":
response["displayname"] = await self.store.get_profile_displayname(
user.localpart
user_id
)
if just_field is None or just_field == "avatar_url":

View file

@ -37,7 +37,7 @@ from synapse.push.push_types import (
TemplateVars,
)
from synapse.storage.databases.main.event_push_actions import EmailPushAction
from synapse.types import StateMap, UserID
from synapse.types import StateMap
from synapse.types.state import StateFilter
from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client
@ -246,9 +246,7 @@ class Mailer:
state_by_room = {}
try:
user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
user_display_name = await self.store.get_profile_displayname(user_id)
if user_display_name is None:
user_display_name = user_id
except StoreError:

View file

@ -57,13 +57,26 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
async def get_profile_displayname(self, user_id: str) -> Optional[str]:
try:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"full_user_id": user_id},
retcol="displayname",
desc="get_profile_displayname",
)
except StoreError as e:
if e.code == 404:
# Fall back to the `user_id` column.
user_localpart = UserID.from_string(user_id).localpart
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
else:
raise
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(

View file

@ -84,7 +84,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(
(
self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
self.store.get_profile_displayname(self.frank.to_string())
)
),
"Frank Jr.",
@ -100,7 +100,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(
(
self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
self.store.get_profile_displayname(self.frank.to_string())
)
),
"Frank",
@ -114,7 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
self.get_success(self.store.get_profile_displayname(self.frank.to_string()))
)
def test_set_my_name_if_disabled(self) -> None:
@ -128,7 +128,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual(
(
self.get_success(
self.store.get_profile_displayname(self.frank.localpart)
self.store.get_profile_displayname(self.frank.to_string())
)
),
"Frank",

View file

@ -103,7 +103,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(email["added_at"], 0)
# Check that the displayname was assigned
displayname = self.get_success(self.store.get_profile_displayname("bob"))
displayname = self.get_success(self.store.get_profile_displayname("@bob:test"))
self.assertEqual(displayname, "Bobberino")
def test_can_register_admin_user(self) -> None:

View file

@ -37,7 +37,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
"Frank",
(
self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
self.store.get_profile_displayname(self.u_frank.to_string())
)
),
)
@ -48,7 +48,9 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
self.get_success(
self.store.get_profile_displayname(self.u_frank.to_string())
)
)
def test_avatar_url(self) -> None: