diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 4aa4ebf7e4..e0efc93f2e 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.types import UserID from synapse.util import stringutils from synapse.util.async_helpers import delay_cancellation @@ -163,9 +162,7 @@ class AccountValidityHandler: return try: - user_display_name = await self.store.get_profile_displayname( - UserID.from_string(user_id).localpart - ) + user_display_name = await self.store.get_profile_displayname(user_id) if user_display_name is None: user_display_name = user_id except StoreError: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 4fa5a8611f..1c5bdb15f1 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -100,7 +100,7 @@ class ProfileHandler: if self.hs.is_mine(target_user): try: displayname = await self.store.get_profile_displayname( - target_user.localpart + target_user.to_string() ) except StoreError as e: if e.code == 404: @@ -364,7 +364,8 @@ class ProfileHandler: Codes.FORBIDDEN, ) - user = UserID.from_string(args["user_id"]) + user_id = args["user_id"] + user = UserID.from_string(user_id) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -374,7 +375,7 @@ class ProfileHandler: try: if just_field is None or just_field == "displayname": response["displayname"] = await self.store.get_profile_displayname( - user.localpart + user_id ) if just_field is None or just_field == "avatar_url": diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 491a09b71d..bf9cd4109c 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -37,7 +37,7 @@ from synapse.push.push_types import ( TemplateVars, ) from synapse.storage.databases.main.event_push_actions import EmailPushAction -from synapse.types import StateMap, UserID +from synapse.types import StateMap from synapse.types.state import StateFilter from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client @@ -246,9 +246,7 @@ class Mailer: state_by_room = {} try: - user_display_name = await self.store.get_profile_displayname( - UserID.from_string(user_id).localpart - ) + user_display_name = await self.store.get_profile_displayname(user_id) if user_display_name is None: user_display_name = user_id except StoreError: diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 23021a1f1f..a5e2ea9f04 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -57,13 +57,26 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: - return await self.db_pool.simple_select_one_onecol( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcol="displayname", - desc="get_profile_displayname", - ) + async def get_profile_displayname(self, user_id: str) -> Optional[str]: + try: + return await self.db_pool.simple_select_one_onecol( + table="profiles", + keyvalues={"full_user_id": user_id}, + retcol="displayname", + desc="get_profile_displayname", + ) + except StoreError as e: + if e.code == 404: + # Fall back to the `user_id` column. + user_localpart = UserID.from_string(user_id).localpart + return await self.db_pool.simple_select_one_onecol( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcol="displayname", + desc="get_profile_displayname", + ) + else: + raise async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 7c174782da..d8b2797859 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -84,7 +84,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual( ( self.get_success( - self.store.get_profile_displayname(self.frank.localpart) + self.store.get_profile_displayname(self.frank.to_string()) ) ), "Frank Jr.", @@ -100,7 +100,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual( ( self.get_success( - self.store.get_profile_displayname(self.frank.localpart) + self.store.get_profile_displayname(self.frank.to_string()) ) ), "Frank", @@ -114,7 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - self.get_success(self.store.get_profile_displayname(self.frank.localpart)) + self.get_success(self.store.get_profile_displayname(self.frank.to_string())) ) def test_set_my_name_if_disabled(self) -> None: @@ -128,7 +128,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual( ( self.get_success( - self.store.get_profile_displayname(self.frank.localpart) + self.store.get_profile_displayname(self.frank.to_string()) ) ), "Frank", diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 758b4bc38b..23364b8f7e 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -103,7 +103,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): self.assertEqual(email["added_at"], 0) # Check that the displayname was assigned - displayname = self.get_success(self.store.get_profile_displayname("bob")) + displayname = self.get_success(self.store.get_profile_displayname("@bob:test")) self.assertEqual(displayname, "Bobberino") def test_can_register_admin_user(self) -> None: diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index a019d06e09..702430f513 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -37,7 +37,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): "Frank", ( self.get_success( - self.store.get_profile_displayname(self.u_frank.localpart) + self.store.get_profile_displayname(self.u_frank.to_string()) ) ), ) @@ -48,7 +48,9 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) + self.get_success( + self.store.get_profile_displayname(self.u_frank.to_string()) + ) ) def test_avatar_url(self) -> None: