move _update_join_states to be processed by a ScheduledTask

This commit is contained in:
Neil Johnson 2024-04-11 22:19:04 +01:00
parent 89f1092284
commit 74d1334182

View file

@ -20,7 +20,7 @@
#
import logging
import random
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union
from synapse.api.errors import (
AuthError,
@ -31,7 +31,15 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.types import (
JsonDict,
JsonMapping,
Requester,
ScheduledTask,
TaskStatus,
UserID,
create_requester,
)
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@ -43,6 +51,8 @@ logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000
UPDATE_JOIN_STATES_TASK_NAME = "update_join_states"
class ProfileHandler:
"""Handles fetching and updating user profile information.
@ -71,6 +81,11 @@ class ProfileHandler:
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
self._task_scheduler = hs.get_task_scheduler()
self._task_scheduler.register_action(
self._update_join_states, UPDATE_JOIN_STATES_TASK_NAME
)
async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict:
target_user = UserID.from_string(user_id)
@ -198,7 +213,13 @@ class ProfileHandler:
)
if propagate:
await self._update_join_states(requester, target_user)
await self._task_scheduler.schedule_task(
UPDATE_JOIN_STATES_TASK_NAME,
params={
"requester": requester.serialize(),
"target_user": target_user.to_string(),
},
)
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
if self.hs.is_mine(target_user):
@ -291,7 +312,13 @@ class ProfileHandler:
)
if propagate:
await self._update_join_states(requester, target_user)
await self._task_scheduler.schedule_task(
UPDATE_JOIN_STATES_TASK_NAME,
params={
"requester": requester.serialize(),
"target_user": target_user.to_string(),
},
)
@cached()
async def check_avatar_size_and_mime_type(self, mxc: str) -> bool:
@ -393,10 +420,21 @@ class ProfileHandler:
return response
async def _update_join_states(
self, requester: Requester, target_user: UserID
) -> None:
self, task: ScheduledTask
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
"""Updates join states following a change to display name or avatar
Args:
target_user: The owner of the queried profile. This is a str rather
than a UserID because the task_scheduler requires JSON serializable
parameters
requester: The user querying for the profile.
"""
assert task.params is not None
requester = Requester.deserialize(self.store, task.params["requester"])
target_user = UserID.from_string(task.params["target_user"])
if not self.hs.is_mine(target_user):
return
return TaskStatus.COMPLETE, None, None
await self.request_ratelimiter.ratelimit(requester)
@ -404,7 +442,7 @@ class ProfileHandler:
if requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
return
return TaskStatus.COMPLETE, None, None
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
@ -424,6 +462,7 @@ class ProfileHandler:
logger.warning(
"Failed to update join event for room %s - %s", room_id, str(e)
)
return TaskStatus.COMPLETE, None, None
async def check_profile_query_allowed(
self, target_user: UserID, requester: Optional[UserID] = None