De-localpart ProfileWorkerStore.get_profile_avatar_url()

Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
Sean Quah 2023-04-15 02:15:02 +01:00
parent e6c582095f
commit 96bb319d14
4 changed files with 46 additions and 15 deletions

View file

@ -197,7 +197,7 @@ class ProfileHandler:
if self.hs.is_mine(target_user):
try:
avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
target_user.to_string()
)
except StoreError as e:
if e.code == 404:
@ -380,7 +380,7 @@ class ProfileHandler:
if just_field is None or just_field == "avatar_url":
response["avatar_url"] = await self.store.get_profile_avatar_url(
user.localpart
user_id
)
except StoreError as e:
if e.code == 404:

View file

@ -78,13 +78,26 @@ class ProfileWorkerStore(SQLBaseStore):
else:
raise
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
async def get_profile_avatar_url(self, user_id: str) -> Optional[str]:
try:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"full_user_id": user_id},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
except StoreError as e:
if e.code == 404:
# Fall back to the `user_id` column.
user_localpart = UserID.from_string(user_id).localpart
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
else:
raise
async def create_profile(self, user_localpart: str) -> None:
await self.db_pool.simple_insert(

View file

@ -201,7 +201,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
(
self.get_success(
self.store.get_profile_avatar_url(self.frank.to_string())
)
),
"http://my.server/pic.gif",
)
@ -215,7 +219,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
(
self.get_success(
self.store.get_profile_avatar_url(self.frank.to_string())
)
),
"http://my.server/me.png",
)
@ -229,7 +237,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
(
self.get_success(
self.store.get_profile_avatar_url(self.frank.to_string())
)
),
)
def test_set_my_avatar_if_disabled(self) -> None:
@ -243,7 +255,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
(
self.get_success(
self.store.get_profile_avatar_url(self.frank.to_string())
)
),
"http://my.server/me.png",
)

View file

@ -66,7 +66,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
"http://my.site/here",
(
self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
self.store.get_profile_avatar_url(self.u_frank.to_string())
)
),
)
@ -77,5 +77,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
self.get_success(
self.store.get_profile_avatar_url(self.u_frank.to_string())
)
)