De-localpart ProfileWorkerStore.get_profile_displayname()

Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
Sean Quah 2023-04-15 02:02:01 +01:00
parent b375e2abd9
commit e6c582095f
7 changed files with 36 additions and 25 deletions

View file

@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async_helpers import delay_cancellation from synapse.util.async_helpers import delay_cancellation
@ -163,9 +162,7 @@ class AccountValidityHandler:
return return
try: try:
user_display_name = await self.store.get_profile_displayname( user_display_name = await self.store.get_profile_displayname(user_id)
UserID.from_string(user_id).localpart
)
if user_display_name is None: if user_display_name is None:
user_display_name = user_id user_display_name = user_id
except StoreError: except StoreError:

View file

@ -100,7 +100,7 @@ class ProfileHandler:
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
displayname = await self.store.get_profile_displayname( displayname = await self.store.get_profile_displayname(
target_user.localpart target_user.to_string()
) )
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
@ -364,7 +364,8 @@ class ProfileHandler:
Codes.FORBIDDEN, Codes.FORBIDDEN,
) )
user = UserID.from_string(args["user_id"]) user_id = args["user_id"]
user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver") raise SynapseError(400, "User is not hosted on this homeserver")
@ -374,7 +375,7 @@ class ProfileHandler:
try: try:
if just_field is None or just_field == "displayname": if just_field is None or just_field == "displayname":
response["displayname"] = await self.store.get_profile_displayname( response["displayname"] = await self.store.get_profile_displayname(
user.localpart user_id
) )
if just_field is None or just_field == "avatar_url": if just_field is None or just_field == "avatar_url":

View file

@ -37,7 +37,7 @@ from synapse.push.push_types import (
TemplateVars, TemplateVars,
) )
from synapse.storage.databases.main.event_push_actions import EmailPushAction from synapse.storage.databases.main.event_push_actions import EmailPushAction
from synapse.types import StateMap, UserID from synapse.types import StateMap
from synapse.types.state import StateFilter from synapse.types.state import StateFilter
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -246,9 +246,7 @@ class Mailer:
state_by_room = {} state_by_room = {}
try: try:
user_display_name = await self.store.get_profile_displayname( user_display_name = await self.store.get_profile_displayname(user_id)
UserID.from_string(user_id).localpart
)
if user_display_name is None: if user_display_name is None:
user_display_name = user_id user_display_name = user_id
except StoreError: except StoreError:

View file

@ -57,13 +57,26 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"] avatar_url=profile["avatar_url"], display_name=profile["displayname"]
) )
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: async def get_profile_displayname(self, user_id: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol( try:
table="profiles", return await self.db_pool.simple_select_one_onecol(
keyvalues={"user_id": user_localpart}, table="profiles",
retcol="displayname", keyvalues={"full_user_id": user_id},
desc="get_profile_displayname", retcol="displayname",
) desc="get_profile_displayname",
)
except StoreError as e:
if e.code == 404:
# Fall back to the `user_id` column.
user_localpart = UserID.from_string(user_id).localpart
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
else:
raise
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(

View file

@ -84,7 +84,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual( self.assertEqual(
( (
self.get_success( self.get_success(
self.store.get_profile_displayname(self.frank.localpart) self.store.get_profile_displayname(self.frank.to_string())
) )
), ),
"Frank Jr.", "Frank Jr.",
@ -100,7 +100,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual( self.assertEqual(
( (
self.get_success( self.get_success(
self.store.get_profile_displayname(self.frank.localpart) self.store.get_profile_displayname(self.frank.to_string())
) )
), ),
"Frank", "Frank",
@ -114,7 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
) )
self.assertIsNone( self.assertIsNone(
self.get_success(self.store.get_profile_displayname(self.frank.localpart)) self.get_success(self.store.get_profile_displayname(self.frank.to_string()))
) )
def test_set_my_name_if_disabled(self) -> None: def test_set_my_name_if_disabled(self) -> None:
@ -128,7 +128,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual( self.assertEqual(
( (
self.get_success( self.get_success(
self.store.get_profile_displayname(self.frank.localpart) self.store.get_profile_displayname(self.frank.to_string())
) )
), ),
"Frank", "Frank",

View file

@ -103,7 +103,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(email["added_at"], 0) self.assertEqual(email["added_at"], 0)
# Check that the displayname was assigned # Check that the displayname was assigned
displayname = self.get_success(self.store.get_profile_displayname("bob")) displayname = self.get_success(self.store.get_profile_displayname("@bob:test"))
self.assertEqual(displayname, "Bobberino") self.assertEqual(displayname, "Bobberino")
def test_can_register_admin_user(self) -> None: def test_can_register_admin_user(self) -> None:

View file

@ -37,7 +37,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
"Frank", "Frank",
( (
self.get_success( self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart) self.store.get_profile_displayname(self.u_frank.to_string())
) )
), ),
) )
@ -48,7 +48,9 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
) )
self.assertIsNone( self.assertIsNone(
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) self.get_success(
self.store.get_profile_displayname(self.u_frank.to_string())
)
) )
def test_avatar_url(self) -> None: def test_avatar_url(self) -> None: