From 96bb319d14f8c16a1f7b712ccc672dfbfe51f59c Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Sat, 15 Apr 2023 02:15:02 +0100 Subject: [PATCH] De-localpart `ProfileWorkerStore.get_profile_avatar_url()` Signed-off-by: Sean Quah --- synapse/handlers/profile.py | 4 ++-- synapse/storage/databases/main/profile.py | 27 +++++++++++++++++------ tests/handlers/test_profile.py | 24 ++++++++++++++++---- tests/storage/test_profile.py | 6 +++-- 4 files changed, 46 insertions(+), 15 deletions(-) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 1c5bdb15f1..aa90e38f5c 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -197,7 +197,7 @@ class ProfileHandler: if self.hs.is_mine(target_user): try: avatar_url = await self.store.get_profile_avatar_url( - target_user.localpart + target_user.to_string() ) except StoreError as e: if e.code == 404: @@ -380,7 +380,7 @@ class ProfileHandler: if just_field is None or just_field == "avatar_url": response["avatar_url"] = await self.store.get_profile_avatar_url( - user.localpart + user_id ) except StoreError as e: if e.code == 404: diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index a5e2ea9f04..12f984d433 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -78,13 +78,26 @@ class ProfileWorkerStore(SQLBaseStore): else: raise - async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: - return await self.db_pool.simple_select_one_onecol( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcol="avatar_url", - desc="get_profile_avatar_url", - ) + async def get_profile_avatar_url(self, user_id: str) -> Optional[str]: + try: + return await self.db_pool.simple_select_one_onecol( + table="profiles", + keyvalues={"full_user_id": user_id}, + retcol="avatar_url", + desc="get_profile_avatar_url", + ) + except StoreError as e: + if e.code == 404: + # Fall back to the `user_id` column. + user_localpart = UserID.from_string(user_id).localpart + return await self.db_pool.simple_select_one_onecol( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcol="avatar_url", + desc="get_profile_avatar_url", + ) + else: + raise async def create_profile(self, user_localpart: str) -> None: await self.db_pool.simple_insert( diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d8b2797859..2cf3fd2119 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -201,7 +201,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), + ( + self.get_success( + self.store.get_profile_avatar_url(self.frank.to_string()) + ) + ), "http://my.server/pic.gif", ) @@ -215,7 +219,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), + ( + self.get_success( + self.store.get_profile_avatar_url(self.frank.to_string()) + ) + ), "http://my.server/me.png", ) @@ -229,7 +237,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), + ( + self.get_success( + self.store.get_profile_avatar_url(self.frank.to_string()) + ) + ), ) def test_set_my_avatar_if_disabled(self) -> None: @@ -243,7 +255,11 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), + ( + self.get_success( + self.store.get_profile_avatar_url(self.frank.to_string()) + ) + ), "http://my.server/me.png", ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 702430f513..136352f838 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -66,7 +66,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): "http://my.site/here", ( self.get_success( - self.store.get_profile_avatar_url(self.u_frank.localpart) + self.store.get_profile_avatar_url(self.u_frank.to_string()) ) ), ) @@ -77,5 +77,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart)) + self.get_success( + self.store.get_profile_avatar_url(self.u_frank.to_string()) + ) )