mirror of
https://github.com/element-hq/synapse
synced 2024-07-15 12:54:05 +00:00
De-localpart ProfileWorkerStore.get_profile_displayname()
Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
parent
b375e2abd9
commit
e6c582095f
|
@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
|
|||
|
||||
from synapse.api.errors import AuthError, StoreError, SynapseError
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.types import UserID
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
|
||||
|
@ -163,9 +162,7 @@ class AccountValidityHandler:
|
|||
return
|
||||
|
||||
try:
|
||||
user_display_name = await self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
)
|
||||
user_display_name = await self.store.get_profile_displayname(user_id)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
except StoreError:
|
||||
|
|
|
@ -100,7 +100,7 @@ class ProfileHandler:
|
|||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
displayname = await self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
target_user.to_string()
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
|
@ -364,7 +364,8 @@ class ProfileHandler:
|
|||
Codes.FORBIDDEN,
|
||||
)
|
||||
|
||||
user = UserID.from_string(args["user_id"])
|
||||
user_id = args["user_id"]
|
||||
user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(user):
|
||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||
|
||||
|
@ -374,7 +375,7 @@ class ProfileHandler:
|
|||
try:
|
||||
if just_field is None or just_field == "displayname":
|
||||
response["displayname"] = await self.store.get_profile_displayname(
|
||||
user.localpart
|
||||
user_id
|
||||
)
|
||||
|
||||
if just_field is None or just_field == "avatar_url":
|
||||
|
|
|
@ -37,7 +37,7 @@ from synapse.push.push_types import (
|
|||
TemplateVars,
|
||||
)
|
||||
from synapse.storage.databases.main.event_push_actions import EmailPushAction
|
||||
from synapse.types import StateMap, UserID
|
||||
from synapse.types import StateMap
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
@ -246,9 +246,7 @@ class Mailer:
|
|||
state_by_room = {}
|
||||
|
||||
try:
|
||||
user_display_name = await self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
)
|
||||
user_display_name = await self.store.get_profile_displayname(user_id)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
except StoreError:
|
||||
|
|
|
@ -57,13 +57,26 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
||||
)
|
||||
|
||||
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
async def get_profile_displayname(self, user_id: str) -> Optional[str]:
|
||||
try:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"full_user_id": user_id},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
# Fall back to the `user_id` column.
|
||||
user_localpart = UserID.from_string(user_id).localpart
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
|
|
|
@ -84,7 +84,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
self.store.get_profile_displayname(self.frank.to_string())
|
||||
)
|
||||
),
|
||||
"Frank Jr.",
|
||||
|
@ -100,7 +100,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
self.store.get_profile_displayname(self.frank.to_string())
|
||||
)
|
||||
),
|
||||
"Frank",
|
||||
|
@ -114,7 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
self.assertIsNone(
|
||||
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
||||
self.get_success(self.store.get_profile_displayname(self.frank.to_string()))
|
||||
)
|
||||
|
||||
def test_set_my_name_if_disabled(self) -> None:
|
||||
|
@ -128,7 +128,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
self.store.get_profile_displayname(self.frank.to_string())
|
||||
)
|
||||
),
|
||||
"Frank",
|
||||
|
|
|
@ -103,7 +103,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
|||
self.assertEqual(email["added_at"], 0)
|
||||
|
||||
# Check that the displayname was assigned
|
||||
displayname = self.get_success(self.store.get_profile_displayname("bob"))
|
||||
displayname = self.get_success(self.store.get_profile_displayname("@bob:test"))
|
||||
self.assertEqual(displayname, "Bobberino")
|
||||
|
||||
def test_can_register_admin_user(self) -> None:
|
||||
|
|
|
@ -37,7 +37,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
|
|||
"Frank",
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.u_frank.localpart)
|
||||
self.store.get_profile_displayname(self.u_frank.to_string())
|
||||
)
|
||||
),
|
||||
)
|
||||
|
@ -48,7 +48,9 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
self.assertIsNone(
|
||||
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.u_frank.to_string())
|
||||
)
|
||||
)
|
||||
|
||||
def test_avatar_url(self) -> None:
|
||||
|
|
Loading…
Reference in a new issue