mirror of
https://github.com/element-hq/synapse
synced 2024-07-07 12:32:50 +00:00
Merge branch 'develop' into madlittlemods/msc3575-sliding-sync-e2ee
Conflicts: synapse/handlers/sync.py
This commit is contained in:
commit
514aba5810
1
changelog.d/17147.feature
Normal file
1
changelog.d/17147.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add the ability to auto-accept invites on the behalf of users. See the [`auto_accept_invites`](https://element-hq.github.io/synapse/latest/usage/configuration/config_documentation.html#auto-accept-invites) config option for details.
|
1
changelog.d/17176.misc
Normal file
1
changelog.d/17176.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Log exceptions when failing to auto-join new user according to the `auto_join_rooms` option.
|
1
changelog.d/17204.doc
Normal file
1
changelog.d/17204.doc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update OIDC documentation: by default Matrix doesn't query userinfo endpoint, then claims should be put on id_token.
|
1
changelog.d/17211.misc
Normal file
1
changelog.d/17211.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Reduce work of calculating outbound device lists updates.
|
1
changelog.d/17216.misc
Normal file
1
changelog.d/17216.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve performance of calculating device lists changes in `/sync`.
|
1
changelog.d/17219.feature
Normal file
1
changelog.d/17219.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add logging to tasks managed by the task scheduler, showing CPU and database usage.
|
|
@ -525,6 +525,8 @@ oidc_providers:
|
||||||
(`Options > Security > ID Token signature algorithm` and `Options > Security >
|
(`Options > Security > ID Token signature algorithm` and `Options > Security >
|
||||||
Access Token signature algorithm`)
|
Access Token signature algorithm`)
|
||||||
- Scopes: OpenID, Email and Profile
|
- Scopes: OpenID, Email and Profile
|
||||||
|
- Force claims into `id_token`
|
||||||
|
(`Options > Advanced > Force claims to be returned in ID Token`)
|
||||||
- Allowed redirection addresses for login (`Options > Basic > Allowed
|
- Allowed redirection addresses for login (`Options > Basic > Allowed
|
||||||
redirection addresses for login` ) :
|
redirection addresses for login` ) :
|
||||||
`[synapse public baseurl]/_synapse/client/oidc/callback`
|
`[synapse public baseurl]/_synapse/client/oidc/callback`
|
||||||
|
|
|
@ -4595,3 +4595,32 @@ background_updates:
|
||||||
min_batch_size: 10
|
min_batch_size: 10
|
||||||
default_batch_size: 50
|
default_batch_size: 50
|
||||||
```
|
```
|
||||||
|
---
|
||||||
|
## Auto Accept Invites
|
||||||
|
Configuration settings related to automatically accepting invites.
|
||||||
|
|
||||||
|
---
|
||||||
|
### `auto_accept_invites`
|
||||||
|
|
||||||
|
Automatically accepting invites controls whether users are presented with an invite request or if they
|
||||||
|
are instead automatically joined to a room when receiving an invite. Set the `enabled` sub-option to true to
|
||||||
|
enable auto-accepting invites. Defaults to false.
|
||||||
|
This setting has the following sub-options:
|
||||||
|
* `enabled`: Whether to run the auto-accept invites logic. Defaults to false.
|
||||||
|
* `only_for_direct_messages`: Whether invites should be automatically accepted for all room types, or only
|
||||||
|
for direct messages. Defaults to false.
|
||||||
|
* `only_from_local_users`: Whether to only automatically accept invites from users on this homeserver. Defaults to false.
|
||||||
|
* `worker_to_run_on`: Which worker to run this module on. This must match the "worker_name".
|
||||||
|
|
||||||
|
NOTE: Care should be taken not to enable this setting if the `synapse_auto_accept_invite` module is enabled and installed.
|
||||||
|
The two modules will compete to perform the same task and may result in undesired behaviour. For example, multiple join
|
||||||
|
events could be generated from a single invite.
|
||||||
|
|
||||||
|
Example configuration:
|
||||||
|
```yaml
|
||||||
|
auto_accept_invites:
|
||||||
|
enabled: true
|
||||||
|
only_for_direct_messages: true
|
||||||
|
only_from_local_users: true
|
||||||
|
worker_to_run_on: "worker_1"
|
||||||
|
```
|
||||||
|
|
|
@ -68,6 +68,7 @@ from synapse.config._base import format_config_error
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.config.server import ListenerConfig, ManholeConfig, TCPListenerConfig
|
from synapse.config.server import ListenerConfig, ManholeConfig, TCPListenerConfig
|
||||||
from synapse.crypto import context_factory
|
from synapse.crypto import context_factory
|
||||||
|
from synapse.events.auto_accept_invites import InviteAutoAccepter
|
||||||
from synapse.events.presence_router import load_legacy_presence_router
|
from synapse.events.presence_router import load_legacy_presence_router
|
||||||
from synapse.handlers.auth import load_legacy_password_auth_providers
|
from synapse.handlers.auth import load_legacy_password_auth_providers
|
||||||
from synapse.http.site import SynapseSite
|
from synapse.http.site import SynapseSite
|
||||||
|
@ -582,6 +583,11 @@ async def start(hs: "HomeServer") -> None:
|
||||||
m = module(config, module_api)
|
m = module(config, module_api)
|
||||||
logger.info("Loaded module %s", m)
|
logger.info("Loaded module %s", m)
|
||||||
|
|
||||||
|
if hs.config.auto_accept_invites.enabled:
|
||||||
|
# Start the local auto_accept_invites module.
|
||||||
|
m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
|
||||||
|
logger.info("Loaded local module %s", m)
|
||||||
|
|
||||||
load_legacy_spam_checkers(hs)
|
load_legacy_spam_checkers(hs)
|
||||||
load_legacy_third_party_event_rules(hs)
|
load_legacy_third_party_event_rules(hs)
|
||||||
load_legacy_presence_router(hs)
|
load_legacy_presence_router(hs)
|
||||||
|
|
|
@ -23,6 +23,7 @@ from synapse.config import ( # noqa: F401
|
||||||
api,
|
api,
|
||||||
appservice,
|
appservice,
|
||||||
auth,
|
auth,
|
||||||
|
auto_accept_invites,
|
||||||
background_updates,
|
background_updates,
|
||||||
cache,
|
cache,
|
||||||
captcha,
|
captcha,
|
||||||
|
@ -120,6 +121,7 @@ class RootConfig:
|
||||||
federation: federation.FederationConfig
|
federation: federation.FederationConfig
|
||||||
retention: retention.RetentionConfig
|
retention: retention.RetentionConfig
|
||||||
background_updates: background_updates.BackgroundUpdateConfig
|
background_updates: background_updates.BackgroundUpdateConfig
|
||||||
|
auto_accept_invites: auto_accept_invites.AutoAcceptInvitesConfig
|
||||||
|
|
||||||
config_classes: List[Type["Config"]] = ...
|
config_classes: List[Type["Config"]] = ...
|
||||||
config_files: List[str]
|
config_files: List[str]
|
||||||
|
|
43
synapse/config/auto_accept_invites.py
Normal file
43
synapse/config/auto_accept_invites.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
#
|
||||||
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
#
|
||||||
|
# Copyright (C) 2024 New Vector, Ltd
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# See the GNU Affero General Public License for more details:
|
||||||
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
#
|
||||||
|
# Originally licensed under the Apache License, Version 2.0:
|
||||||
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||||
|
#
|
||||||
|
# [This file includes modifications made by New Vector Limited]
|
||||||
|
#
|
||||||
|
#
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
from ._base import Config
|
||||||
|
|
||||||
|
|
||||||
|
class AutoAcceptInvitesConfig(Config):
|
||||||
|
section = "auto_accept_invites"
|
||||||
|
|
||||||
|
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
|
||||||
|
auto_accept_invites_config = config.get("auto_accept_invites") or {}
|
||||||
|
|
||||||
|
self.enabled = auto_accept_invites_config.get("enabled", False)
|
||||||
|
|
||||||
|
self.accept_invites_only_for_direct_messages = auto_accept_invites_config.get(
|
||||||
|
"only_for_direct_messages", False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.accept_invites_only_from_local_users = auto_accept_invites_config.get(
|
||||||
|
"only_from_local_users", False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.worker_to_run_on = auto_accept_invites_config.get("worker_to_run_on")
|
|
@ -23,6 +23,7 @@ from .account_validity import AccountValidityConfig
|
||||||
from .api import ApiConfig
|
from .api import ApiConfig
|
||||||
from .appservice import AppServiceConfig
|
from .appservice import AppServiceConfig
|
||||||
from .auth import AuthConfig
|
from .auth import AuthConfig
|
||||||
|
from .auto_accept_invites import AutoAcceptInvitesConfig
|
||||||
from .background_updates import BackgroundUpdateConfig
|
from .background_updates import BackgroundUpdateConfig
|
||||||
from .cache import CacheConfig
|
from .cache import CacheConfig
|
||||||
from .captcha import CaptchaConfig
|
from .captcha import CaptchaConfig
|
||||||
|
@ -105,4 +106,5 @@ class HomeServerConfig(RootConfig):
|
||||||
RedisConfig,
|
RedisConfig,
|
||||||
ExperimentalConfig,
|
ExperimentalConfig,
|
||||||
BackgroundUpdateConfig,
|
BackgroundUpdateConfig,
|
||||||
|
AutoAcceptInvitesConfig,
|
||||||
]
|
]
|
||||||
|
|
196
synapse/events/auto_accept_invites.py
Normal file
196
synapse/events/auto_accept_invites.py
Normal file
|
@ -0,0 +1,196 @@
|
||||||
|
#
|
||||||
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
#
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C
|
||||||
|
# Copyright (C) 2024 New Vector, Ltd
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# See the GNU Affero General Public License for more details:
|
||||||
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
#
|
||||||
|
# Originally licensed under the Apache License, Version 2.0:
|
||||||
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||||
|
#
|
||||||
|
# [This file includes modifications made by New Vector Limited]
|
||||||
|
#
|
||||||
|
#
|
||||||
|
import logging
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
|
||||||
|
from synapse.module_api import EventBase, ModuleApi, run_as_background_process
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InviteAutoAccepter:
|
||||||
|
def __init__(self, config: AutoAcceptInvitesConfig, api: ModuleApi):
|
||||||
|
# Keep a reference to the Module API.
|
||||||
|
self._api = api
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
if not self._config.enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
should_run_on_this_worker = config.worker_to_run_on == self._api.worker_name
|
||||||
|
|
||||||
|
if not should_run_on_this_worker:
|
||||||
|
logger.info(
|
||||||
|
"Not accepting invites on this worker (configured: %r, here: %r)",
|
||||||
|
config.worker_to_run_on,
|
||||||
|
self._api.worker_name,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Accepting invites on this worker (here: %r)", self._api.worker_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register the callback.
|
||||||
|
self._api.register_third_party_rules_callbacks(
|
||||||
|
on_new_event=self.on_new_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_new_event(self, event: EventBase, *args: Any) -> None:
|
||||||
|
"""Listens for new events, and if the event is an invite for a local user then
|
||||||
|
automatically accepts it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The incoming event.
|
||||||
|
"""
|
||||||
|
# Check if the event is an invite for a local user.
|
||||||
|
is_invite_for_local_user = (
|
||||||
|
event.type == EventTypes.Member
|
||||||
|
and event.is_state()
|
||||||
|
and event.membership == Membership.INVITE
|
||||||
|
and self._api.is_mine(event.state_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only accept invites for direct messages if the configuration mandates it.
|
||||||
|
is_direct_message = event.content.get("is_direct", False)
|
||||||
|
is_allowed_by_direct_message_rules = (
|
||||||
|
not self._config.accept_invites_only_for_direct_messages
|
||||||
|
or is_direct_message is True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only accept invites from remote users if the configuration mandates it.
|
||||||
|
is_from_local_user = self._api.is_mine(event.sender)
|
||||||
|
is_allowed_by_local_user_rules = (
|
||||||
|
not self._config.accept_invites_only_from_local_users
|
||||||
|
or is_from_local_user is True
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
is_invite_for_local_user
|
||||||
|
and is_allowed_by_direct_message_rules
|
||||||
|
and is_allowed_by_local_user_rules
|
||||||
|
):
|
||||||
|
# Make the user join the room. We run this as a background process to circumvent a race condition
|
||||||
|
# that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
|
||||||
|
run_as_background_process(
|
||||||
|
"retry_make_join",
|
||||||
|
self._retry_make_join,
|
||||||
|
event.state_key,
|
||||||
|
event.state_key,
|
||||||
|
event.room_id,
|
||||||
|
"join",
|
||||||
|
bg_start_span=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_direct_message:
|
||||||
|
# Mark this room as a direct message!
|
||||||
|
await self._mark_room_as_direct_message(
|
||||||
|
event.state_key, event.sender, event.room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _mark_room_as_direct_message(
|
||||||
|
self, user_id: str, dm_user_id: str, room_id: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Marks a room (`room_id`) as a direct message with the counterparty `dm_user_id`
|
||||||
|
from the perspective of the user `user_id`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: the user for whom the membership is changing
|
||||||
|
dm_user_id: the user performing the membership change
|
||||||
|
room_id: room id of the room the user is invited to
|
||||||
|
"""
|
||||||
|
|
||||||
|
# This is a dict of User IDs to tuples of Room IDs
|
||||||
|
# (get_global will return a frozendict of tuples as it freezes the data,
|
||||||
|
# but we should accept either frozen or unfrozen variants.)
|
||||||
|
# Be careful: we convert the outer frozendict into a dict here,
|
||||||
|
# but the contents of the dict are still frozen (tuples in lieu of lists,
|
||||||
|
# etc.)
|
||||||
|
dm_map: Dict[str, Tuple[str, ...]] = dict(
|
||||||
|
await self._api.account_data_manager.get_global(
|
||||||
|
user_id, AccountDataTypes.DIRECT
|
||||||
|
)
|
||||||
|
or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
if dm_user_id not in dm_map:
|
||||||
|
dm_map[dm_user_id] = (room_id,)
|
||||||
|
else:
|
||||||
|
dm_rooms_for_user = dm_map[dm_user_id]
|
||||||
|
assert isinstance(dm_rooms_for_user, (tuple, list))
|
||||||
|
|
||||||
|
dm_map[dm_user_id] = tuple(dm_rooms_for_user) + (room_id,)
|
||||||
|
|
||||||
|
await self._api.account_data_manager.put_global(
|
||||||
|
user_id, AccountDataTypes.DIRECT, dm_map
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _retry_make_join(
|
||||||
|
self, sender: str, target: str, room_id: str, new_membership: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
A function to retry sending the `make_join` request with an increasing backoff. This is
|
||||||
|
implemented to work around a race condition when receiving invites over federation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sender: the user performing the membership change
|
||||||
|
target: the user for whom the membership is changing
|
||||||
|
room_id: room id of the room to join to
|
||||||
|
new_membership: the type of membership event (in this case will be "join")
|
||||||
|
"""
|
||||||
|
|
||||||
|
sleep = 0
|
||||||
|
retries = 0
|
||||||
|
join_event = None
|
||||||
|
|
||||||
|
while retries < 5:
|
||||||
|
try:
|
||||||
|
await self._api.sleep(sleep)
|
||||||
|
join_event = await self._api.update_room_membership(
|
||||||
|
sender=sender,
|
||||||
|
target=target,
|
||||||
|
room_id=room_id,
|
||||||
|
new_membership=new_membership,
|
||||||
|
)
|
||||||
|
except SynapseError as e:
|
||||||
|
if e.code == HTTPStatus.FORBIDDEN:
|
||||||
|
logger.debug(
|
||||||
|
f"Update_room_membership was forbidden. This can sometimes be expected for remote invites. Exception: {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
f"Update_room_membership raised the following unexpected (SynapseError) exception: {e}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warn(
|
||||||
|
f"Update_room_membership raised the following unexpected exception: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
sleep = 2**retries
|
||||||
|
retries += 1
|
||||||
|
|
||||||
|
if join_event is not None:
|
||||||
|
break
|
|
@ -159,20 +159,32 @@ class DeviceWorkerHandler:
|
||||||
|
|
||||||
@cancellable
|
@cancellable
|
||||||
async def get_device_changes_in_shared_rooms(
|
async def get_device_changes_in_shared_rooms(
|
||||||
self, user_id: str, room_ids: StrCollection, from_token: StreamToken
|
self,
|
||||||
|
user_id: str,
|
||||||
|
room_ids: StrCollection,
|
||||||
|
from_token: StreamToken,
|
||||||
|
now_token: Optional[StreamToken] = None,
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""Get the set of users whose devices have changed who share a room with
|
"""Get the set of users whose devices have changed who share a room with
|
||||||
the given user.
|
the given user.
|
||||||
"""
|
"""
|
||||||
|
now_device_lists_key = self.store.get_device_stream_token()
|
||||||
|
if now_token:
|
||||||
|
now_device_lists_key = now_token.device_list_key
|
||||||
|
|
||||||
changed_users = await self.store.get_device_list_changes_in_rooms(
|
changed_users = await self.store.get_device_list_changes_in_rooms(
|
||||||
room_ids, from_token.device_list_key
|
room_ids,
|
||||||
|
from_token.device_list_key,
|
||||||
|
now_device_lists_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
if changed_users is not None:
|
if changed_users is not None:
|
||||||
# We also check if the given user has changed their device. If
|
# We also check if the given user has changed their device. If
|
||||||
# they're in no rooms then the above query won't include them.
|
# they're in no rooms then the above query won't include them.
|
||||||
changed = await self.store.get_users_whose_devices_changed(
|
changed = await self.store.get_users_whose_devices_changed(
|
||||||
from_token.device_list_key, [user_id]
|
from_token.device_list_key,
|
||||||
|
[user_id],
|
||||||
|
to_key=now_device_lists_key,
|
||||||
)
|
)
|
||||||
changed_users.update(changed)
|
changed_users.update(changed)
|
||||||
return changed_users
|
return changed_users
|
||||||
|
@ -190,7 +202,9 @@ class DeviceWorkerHandler:
|
||||||
tracked_users.add(user_id)
|
tracked_users.add(user_id)
|
||||||
|
|
||||||
changed = await self.store.get_users_whose_devices_changed(
|
changed = await self.store.get_users_whose_devices_changed(
|
||||||
from_token.device_list_key, tracked_users
|
from_token.device_list_key,
|
||||||
|
tracked_users,
|
||||||
|
to_key=now_device_lists_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
return changed
|
return changed
|
||||||
|
@ -892,6 +906,13 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
context=opentracing_context,
|
context=opentracing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self.store.mark_redundant_device_lists_pokes(
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
room_id=room_id,
|
||||||
|
converted_upto_stream_id=stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Notify replication that we've updated the device list stream.
|
# Notify replication that we've updated the device list stream.
|
||||||
self.notifier.notify_replication()
|
self.notifier.notify_replication()
|
||||||
|
|
||||||
|
|
|
@ -590,7 +590,7 @@ class RegistrationHandler:
|
||||||
# moving away from bare excepts is a good thing to do.
|
# moving away from bare excepts is a good thing to do.
|
||||||
logger.error("Failed to join new user to %r: %r", r, e)
|
logger.error("Failed to join new user to %r: %r", r, e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to join new user to %r: %r", r, e)
|
logger.error("Failed to join new user to %r: %r", r, e, exc_info=True)
|
||||||
|
|
||||||
async def _auto_join_rooms(self, user_id: str) -> None:
|
async def _auto_join_rooms(self, user_id: str) -> None:
|
||||||
"""Automatically joins users to auto join rooms - creating the room in the first place
|
"""Automatically joins users to auto join rooms - creating the room in the first place
|
||||||
|
|
|
@ -817,7 +817,7 @@ class SsoHandler:
|
||||||
server_name = profile["avatar_url"].split("/")[-2]
|
server_name = profile["avatar_url"].split("/")[-2]
|
||||||
media_id = profile["avatar_url"].split("/")[-1]
|
media_id = profile["avatar_url"].split("/")[-1]
|
||||||
if self._is_mine_server_name(server_name):
|
if self._is_mine_server_name(server_name):
|
||||||
media = await self._media_repo.store.get_local_media(media_id)
|
media = await self._media_repo.store.get_local_media(media_id) # type: ignore[has-type]
|
||||||
if media is not None and upload_name == media.upload_name:
|
if media is not None and upload_name == media.upload_name:
|
||||||
logger.info("skipping saving the user avatar")
|
logger.info("skipping saving the user avatar")
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -2113,38 +2113,14 @@ class SyncHandler:
|
||||||
|
|
||||||
# Step 1a, check for changes in devices of users we share a room
|
# Step 1a, check for changes in devices of users we share a room
|
||||||
# with
|
# with
|
||||||
#
|
users_that_have_changed = (
|
||||||
# We do this in two different ways depending on what we have cached.
|
await self._device_handler.get_device_changes_in_shared_rooms(
|
||||||
# If we already have a list of all the user that have changed since
|
user_id,
|
||||||
# the last sync then it's likely more efficient to compare the rooms
|
joined_room_ids,
|
||||||
# they're in with the rooms the syncing user is in.
|
from_token=since_token,
|
||||||
#
|
now_token=sync_result_builder.now_token,
|
||||||
# If we don't have that info cached then we get all the users that
|
|
||||||
# share a room with our user and check if those users have changed.
|
|
||||||
cache_result = self.store.get_cached_device_list_changes(
|
|
||||||
since_token.device_list_key
|
|
||||||
)
|
|
||||||
if cache_result.hit:
|
|
||||||
changed_users = cache_result.entities
|
|
||||||
|
|
||||||
result = await self.store.get_rooms_for_users(changed_users)
|
|
||||||
|
|
||||||
for changed_user_id, entries in result.items():
|
|
||||||
# Check if the changed user shares any rooms with the user,
|
|
||||||
# or if the changed user is the syncing user (as we always
|
|
||||||
# want to include device list updates of their own devices).
|
|
||||||
if user_id == changed_user_id or any(
|
|
||||||
rid in joined_room_ids for rid in entries
|
|
||||||
):
|
|
||||||
users_that_have_changed.add(changed_user_id)
|
|
||||||
else:
|
|
||||||
users_that_have_changed = (
|
|
||||||
await self._device_handler.get_device_changes_in_shared_rooms(
|
|
||||||
user_id,
|
|
||||||
joined_room_ids,
|
|
||||||
from_token=since_token,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Step 1b, check for newly joined rooms
|
# Step 1b, check for newly joined rooms
|
||||||
for room_id in newly_joined_rooms:
|
for room_id in newly_joined_rooms:
|
||||||
|
|
|
@ -112,6 +112,15 @@ class ReplicationDataHandler:
|
||||||
token: stream token for this batch of rows
|
token: stream token for this batch of rows
|
||||||
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
|
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
|
||||||
"""
|
"""
|
||||||
|
all_room_ids: Set[str] = set()
|
||||||
|
if stream_name == DeviceListsStream.NAME:
|
||||||
|
if any(row.entity.startswith("@") and not row.is_signature for row in rows):
|
||||||
|
prev_token = self.store.get_device_stream_token()
|
||||||
|
all_room_ids = await self.store.get_all_device_list_changes(
|
||||||
|
prev_token, token
|
||||||
|
)
|
||||||
|
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
|
||||||
|
|
||||||
self.store.process_replication_rows(stream_name, instance_name, token, rows)
|
self.store.process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
# NOTE: this must be called after process_replication_rows to ensure any
|
# NOTE: this must be called after process_replication_rows to ensure any
|
||||||
# cache invalidations are first handled before any stream ID advances.
|
# cache invalidations are first handled before any stream ID advances.
|
||||||
|
@ -146,12 +155,6 @@ class ReplicationDataHandler:
|
||||||
StreamKeyType.TO_DEVICE, token, users=entities
|
StreamKeyType.TO_DEVICE, token, users=entities
|
||||||
)
|
)
|
||||||
elif stream_name == DeviceListsStream.NAME:
|
elif stream_name == DeviceListsStream.NAME:
|
||||||
all_room_ids: Set[str] = set()
|
|
||||||
for row in rows:
|
|
||||||
if row.entity.startswith("@") and not row.is_signature:
|
|
||||||
room_ids = await self.store.get_rooms_for_user(row.entity)
|
|
||||||
all_room_ids.update(room_ids)
|
|
||||||
|
|
||||||
# `all_room_ids` can be large, so let's wake up those streams in batches
|
# `all_room_ids` can be large, so let's wake up those streams in batches
|
||||||
for batched_room_ids in batch_iter(all_room_ids, 100):
|
for batched_room_ids in batch_iter(all_room_ids, 100):
|
||||||
self.notifier.on_new_event(
|
self.notifier.on_new_event(
|
||||||
|
|
|
@ -70,10 +70,7 @@ from synapse.types import (
|
||||||
from synapse.util import json_decoder, json_encoder
|
from synapse.util import json_decoder, json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.stream_change_cache import (
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
AllEntitiesChangedResult,
|
|
||||||
StreamChangeCache,
|
|
||||||
)
|
|
||||||
from synapse.util.cancellation import cancellable
|
from synapse.util.cancellation import cancellable
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
from synapse.util.stringutils import shortstr
|
from synapse.util.stringutils import shortstr
|
||||||
|
@ -132,6 +129,20 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
prefilled_cache=device_list_prefill,
|
prefilled_cache=device_list_prefill,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
|
||||||
|
db_conn,
|
||||||
|
"device_lists_changes_in_room",
|
||||||
|
entity_column="room_id",
|
||||||
|
stream_column="stream_id",
|
||||||
|
max_value=device_list_max,
|
||||||
|
limit=10000,
|
||||||
|
)
|
||||||
|
self._device_list_room_stream_cache = StreamChangeCache(
|
||||||
|
"DeviceListRoomStreamChangeCache",
|
||||||
|
min_device_list_room_id,
|
||||||
|
prefilled_cache=device_list_room_prefill,
|
||||||
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
user_signature_stream_prefill,
|
user_signature_stream_prefill,
|
||||||
user_signature_stream_list_id,
|
user_signature_stream_list_id,
|
||||||
|
@ -209,6 +220,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
row.entity, token
|
row.entity, token
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def device_lists_in_rooms_have_changed(
|
||||||
|
self, room_ids: StrCollection, token: int
|
||||||
|
) -> None:
|
||||||
|
"Record that device lists have changed in rooms"
|
||||||
|
for room_id in room_ids:
|
||||||
|
self._device_list_room_stream_cache.entity_has_changed(room_id, token)
|
||||||
|
|
||||||
def get_device_stream_token(self) -> int:
|
def get_device_stream_token(self) -> int:
|
||||||
return self._device_list_id_gen.get_current_token()
|
return self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
|
@ -832,16 +850,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
)
|
)
|
||||||
return {device[0]: db_to_json(device[1]) for device in devices}
|
return {device[0]: db_to_json(device[1]) for device in devices}
|
||||||
|
|
||||||
def get_cached_device_list_changes(
|
|
||||||
self,
|
|
||||||
from_key: int,
|
|
||||||
) -> AllEntitiesChangedResult:
|
|
||||||
"""Get set of users whose devices have changed since `from_key`, or None
|
|
||||||
if that information is not in our cache.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self._device_list_stream_cache.get_all_entities_changed(from_key)
|
|
||||||
|
|
||||||
@cancellable
|
@cancellable
|
||||||
async def get_all_devices_changed(
|
async def get_all_devices_changed(
|
||||||
self,
|
self,
|
||||||
|
@ -1457,7 +1465,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
|
|
||||||
@cancellable
|
@cancellable
|
||||||
async def get_device_list_changes_in_rooms(
|
async def get_device_list_changes_in_rooms(
|
||||||
self, room_ids: Collection[str], from_id: int
|
self, room_ids: Collection[str], from_id: int, to_id: int
|
||||||
) -> Optional[Set[str]]:
|
) -> Optional[Set[str]]:
|
||||||
"""Return the set of users whose devices have changed in the given rooms
|
"""Return the set of users whose devices have changed in the given rooms
|
||||||
since the given stream ID.
|
since the given stream ID.
|
||||||
|
@ -1473,9 +1481,15 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
if min_stream_id > from_id:
|
if min_stream_id > from_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
|
||||||
|
room_ids, from_id
|
||||||
|
)
|
||||||
|
if not changed_room_ids:
|
||||||
|
return set()
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
SELECT DISTINCT user_id FROM device_lists_changes_in_room
|
SELECT DISTINCT user_id FROM device_lists_changes_in_room
|
||||||
WHERE {clause} AND stream_id >= ?
|
WHERE {clause} AND stream_id > ? AND stream_id <= ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_device_list_changes_in_rooms_txn(
|
def _get_device_list_changes_in_rooms_txn(
|
||||||
|
@ -1487,11 +1501,12 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
return {user_id for user_id, in txn}
|
return {user_id for user_id, in txn}
|
||||||
|
|
||||||
changes = set()
|
changes = set()
|
||||||
for chunk in batch_iter(room_ids, 1000):
|
for chunk in batch_iter(changed_room_ids, 1000):
|
||||||
clause, args = make_in_list_sql_clause(
|
clause, args = make_in_list_sql_clause(
|
||||||
self.database_engine, "room_id", chunk
|
self.database_engine, "room_id", chunk
|
||||||
)
|
)
|
||||||
args.append(from_id)
|
args.append(from_id)
|
||||||
|
args.append(to_id)
|
||||||
|
|
||||||
changes |= await self.db_pool.runInteraction(
|
changes |= await self.db_pool.runInteraction(
|
||||||
"get_device_list_changes_in_rooms",
|
"get_device_list_changes_in_rooms",
|
||||||
|
@ -1502,6 +1517,34 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
|
|
||||||
return changes
|
return changes
|
||||||
|
|
||||||
|
async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
|
||||||
|
"""Return the set of rooms where devices have changed since the given
|
||||||
|
stream ID.
|
||||||
|
|
||||||
|
Will raise an exception if the given stream ID is too old.
|
||||||
|
"""
|
||||||
|
|
||||||
|
min_stream_id = await self._get_min_device_lists_changes_in_room()
|
||||||
|
|
||||||
|
if min_stream_id > from_id:
|
||||||
|
raise Exception("stream ID is too old")
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT DISTINCT room_id FROM device_lists_changes_in_room
|
||||||
|
WHERE stream_id > ? AND stream_id <= ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_all_device_list_changes_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Set[str]:
|
||||||
|
txn.execute(sql, (from_id, to_id))
|
||||||
|
return {room_id for room_id, in txn}
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_all_device_list_changes",
|
||||||
|
_get_all_device_list_changes_txn,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_device_list_changes_in_room(
|
async def get_device_list_changes_in_room(
|
||||||
self, room_id: str, min_stream_id: int
|
self, room_id: str, min_stream_id: int
|
||||||
) -> Collection[Tuple[str, str]]:
|
) -> Collection[Tuple[str, str]]:
|
||||||
|
@ -1962,8 +2005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
async def add_device_change_to_streams(
|
async def add_device_change_to_streams(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
device_ids: Collection[str],
|
device_ids: StrCollection,
|
||||||
room_ids: Collection[str],
|
room_ids: StrCollection,
|
||||||
) -> Optional[int]:
|
) -> Optional[int]:
|
||||||
"""Persist that a user's devices have been updated, and which hosts
|
"""Persist that a user's devices have been updated, and which hosts
|
||||||
(if any) should be poked.
|
(if any) should be poked.
|
||||||
|
@ -2118,12 +2161,36 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def mark_redundant_device_lists_pokes(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
room_id: str,
|
||||||
|
converted_upto_stream_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""If we've calculated the outbound pokes for a given room/device list
|
||||||
|
update, mark any subsequent changes as already converted"""
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
UPDATE device_lists_changes_in_room
|
||||||
|
SET converted_to_destinations = true
|
||||||
|
WHERE stream_id > ? AND user_id = ? AND device_id = ?
|
||||||
|
AND room_id = ? AND NOT converted_to_destinations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def mark_redundant_device_lists_pokes_txn(txn: LoggingTransaction) -> None:
|
||||||
|
txn.execute(sql, (converted_upto_stream_id, user_id, device_id, room_id))
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"mark_redundant_device_lists_pokes", mark_redundant_device_lists_pokes_txn
|
||||||
|
)
|
||||||
|
|
||||||
def _add_device_outbound_room_poke_txn(
|
def _add_device_outbound_room_poke_txn(
|
||||||
self,
|
self,
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
device_ids: Iterable[str],
|
device_ids: StrCollection,
|
||||||
room_ids: Collection[str],
|
room_ids: StrCollection,
|
||||||
stream_ids: List[int],
|
stream_ids: List[int],
|
||||||
context: Dict[str, str],
|
context: Dict[str, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -2161,6 +2228,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
txn.call_after(
|
||||||
|
self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
|
||||||
|
)
|
||||||
|
|
||||||
async def get_uncoverted_outbound_room_pokes(
|
async def get_uncoverted_outbound_room_pokes(
|
||||||
self, start_stream_id: int, start_room_id: str, limit: int = 10
|
self, start_stream_id: int, start_room_id: str, limit: int = 10
|
||||||
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
|
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
|
||||||
|
|
|
@ -24,7 +24,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
|
||||||
|
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.logging.context import nested_logging_context
|
from synapse.logging.context import (
|
||||||
|
ContextResourceUsage,
|
||||||
|
LoggingContext,
|
||||||
|
nested_logging_context,
|
||||||
|
set_current_context,
|
||||||
|
)
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.metrics.background_process_metrics import (
|
from synapse.metrics.background_process_metrics import (
|
||||||
run_as_background_process,
|
run_as_background_process,
|
||||||
|
@ -81,6 +86,8 @@ class TaskScheduler:
|
||||||
MAX_CONCURRENT_RUNNING_TASKS = 5
|
MAX_CONCURRENT_RUNNING_TASKS = 5
|
||||||
# Time from the last task update after which we will log a warning
|
# Time from the last task update after which we will log a warning
|
||||||
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
|
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
|
||||||
|
# Report a running task's status and usage every so often.
|
||||||
|
OCCASIONAL_REPORT_INTERVAL_MS = 5 * 60 * 1000 # 5 minutes
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._hs = hs
|
self._hs = hs
|
||||||
|
@ -346,6 +353,33 @@ class TaskScheduler:
|
||||||
assert task.id not in self._running_tasks
|
assert task.id not in self._running_tasks
|
||||||
await self._store.delete_scheduled_task(task.id)
|
await self._store.delete_scheduled_task(task.id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _log_task_usage(
|
||||||
|
state: str, task: ScheduledTask, usage: ContextResourceUsage, active_time: float
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Log a line describing the state and usage of a task.
|
||||||
|
The log line is inspired by / a copy of the request log line format,
|
||||||
|
but with irrelevant fields removed.
|
||||||
|
|
||||||
|
active_time: Time that the task has been running for, in seconds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Task %s: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
|
||||||
|
" [%d dbevts] %r, %r",
|
||||||
|
state,
|
||||||
|
active_time,
|
||||||
|
usage.ru_utime,
|
||||||
|
usage.ru_stime,
|
||||||
|
usage.db_sched_duration_sec,
|
||||||
|
usage.db_txn_duration_sec,
|
||||||
|
int(usage.db_txn_count),
|
||||||
|
usage.evt_db_fetch_count,
|
||||||
|
task.resource_id,
|
||||||
|
task.params,
|
||||||
|
)
|
||||||
|
|
||||||
async def _launch_task(self, task: ScheduledTask) -> None:
|
async def _launch_task(self, task: ScheduledTask) -> None:
|
||||||
"""Launch a scheduled task now.
|
"""Launch a scheduled task now.
|
||||||
|
|
||||||
|
@ -360,8 +394,32 @@ class TaskScheduler:
|
||||||
)
|
)
|
||||||
function = self._actions[task.action]
|
function = self._actions[task.action]
|
||||||
|
|
||||||
|
def _occasional_report(
|
||||||
|
task_log_context: LoggingContext, start_time: float
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Helper to log a 'Task continuing' line every so often.
|
||||||
|
"""
|
||||||
|
|
||||||
|
current_time = self._clock.time()
|
||||||
|
calling_context = set_current_context(task_log_context)
|
||||||
|
try:
|
||||||
|
usage = task_log_context.get_resource_usage()
|
||||||
|
TaskScheduler._log_task_usage(
|
||||||
|
"continuing", task, usage, current_time - start_time
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
set_current_context(calling_context)
|
||||||
|
|
||||||
async def wrapper() -> None:
|
async def wrapper() -> None:
|
||||||
with nested_logging_context(task.id):
|
with nested_logging_context(task.id) as log_context:
|
||||||
|
start_time = self._clock.time()
|
||||||
|
occasional_status_call = self._clock.looping_call(
|
||||||
|
_occasional_report,
|
||||||
|
TaskScheduler.OCCASIONAL_REPORT_INTERVAL_MS,
|
||||||
|
log_context,
|
||||||
|
start_time,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
(status, result, error) = await function(task)
|
(status, result, error) = await function(task)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -383,6 +441,13 @@ class TaskScheduler:
|
||||||
)
|
)
|
||||||
self._running_tasks.remove(task.id)
|
self._running_tasks.remove(task.id)
|
||||||
|
|
||||||
|
current_time = self._clock.time()
|
||||||
|
usage = log_context.get_resource_usage()
|
||||||
|
TaskScheduler._log_task_usage(
|
||||||
|
status.value, task, usage, current_time - start_time
|
||||||
|
)
|
||||||
|
occasional_status_call.stop()
|
||||||
|
|
||||||
# Try launch a new task since we've finished with this one.
|
# Try launch a new task since we've finished with this one.
|
||||||
self._clock.call_later(0.1, self._launch_scheduled_tasks)
|
self._clock.call_later(0.1, self._launch_scheduled_tasks)
|
||||||
|
|
||||||
|
|
657
tests/events/test_auto_accept_invites.py
Normal file
657
tests/events/test_auto_accept_invites.py
Normal file
|
@ -0,0 +1,657 @@
|
||||||
|
#
|
||||||
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
#
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C
|
||||||
|
# Copyright (C) 2024 New Vector, Ltd
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# See the GNU Affero General Public License for more details:
|
||||||
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
#
|
||||||
|
# Originally licensed under the Apache License, Version 2.0:
|
||||||
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||||
|
#
|
||||||
|
# [This file includes modifications made by New Vector Limited]
|
||||||
|
#
|
||||||
|
#
|
||||||
|
import asyncio
|
||||||
|
from asyncio import Future
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Any, Awaitable, Dict, List, Optional, Tuple, TypeVar, cast
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
|
||||||
|
from synapse.events.auto_accept_invites import InviteAutoAccepter
|
||||||
|
from synapse.federation.federation_base import event_from_pdu_json
|
||||||
|
from synapse.handlers.sync import JoinedSyncResult, SyncRequestKey, SyncVersion
|
||||||
|
from synapse.module_api import ModuleApi
|
||||||
|
from synapse.rest import admin
|
||||||
|
from synapse.rest.client import login, room
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import StreamToken, create_requester
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
from tests.handlers.test_sync import generate_sync_config
|
||||||
|
from tests.unittest import (
|
||||||
|
FederatingHomeserverTestCase,
|
||||||
|
HomeserverTestCase,
|
||||||
|
TestCase,
|
||||||
|
override_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
|
||||||
|
"""
|
||||||
|
Integration test cases for auto-accepting invites.
|
||||||
|
"""
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
room.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
hs = self.setup_test_homeserver()
|
||||||
|
self.handler = hs.get_federation_handler()
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
return hs
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.sync_handler = self.hs.get_sync_handler()
|
||||||
|
self.module_api = hs.get_module_api()
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
|
[
|
||||||
|
[False],
|
||||||
|
[True],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"auto_accept_invites": {
|
||||||
|
"enabled": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_auto_accept_invites(self, direct_room: bool) -> None:
|
||||||
|
"""Test that a user automatically joins a room when invited, if the
|
||||||
|
module is enabled.
|
||||||
|
"""
|
||||||
|
# A local user who sends an invite
|
||||||
|
inviting_user_id = self.register_user("inviter", "pass")
|
||||||
|
inviting_user_tok = self.login("inviter", "pass")
|
||||||
|
|
||||||
|
# A local user who receives an invite
|
||||||
|
invited_user_id = self.register_user("invitee", "pass")
|
||||||
|
self.login("invitee", "pass")
|
||||||
|
|
||||||
|
# Create a room and send an invite to the other user
|
||||||
|
room_id = self.helper.create_room_as(
|
||||||
|
inviting_user_id,
|
||||||
|
is_public=False,
|
||||||
|
tok=inviting_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.invite(
|
||||||
|
room_id,
|
||||||
|
inviting_user_id,
|
||||||
|
invited_user_id,
|
||||||
|
tok=inviting_user_tok,
|
||||||
|
extra_data={"is_direct": direct_room},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that the invite receiving user has automatically joined the room when syncing
|
||||||
|
join_updates, _ = sync_join(self, invited_user_id)
|
||||||
|
self.assertEqual(len(join_updates), 1)
|
||||||
|
|
||||||
|
join_update: JoinedSyncResult = join_updates[0]
|
||||||
|
self.assertEqual(join_update.room_id, room_id)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"auto_accept_invites": {
|
||||||
|
"enabled": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_module_not_enabled(self) -> None:
|
||||||
|
"""Test that a user does not automatically join a room when invited,
|
||||||
|
if the module is not enabled.
|
||||||
|
"""
|
||||||
|
# A local user who sends an invite
|
||||||
|
inviting_user_id = self.register_user("inviter", "pass")
|
||||||
|
inviting_user_tok = self.login("inviter", "pass")
|
||||||
|
|
||||||
|
# A local user who receives an invite
|
||||||
|
invited_user_id = self.register_user("invitee", "pass")
|
||||||
|
self.login("invitee", "pass")
|
||||||
|
|
||||||
|
# Create a room and send an invite to the other user
|
||||||
|
room_id = self.helper.create_room_as(
|
||||||
|
inviting_user_id, is_public=False, tok=inviting_user_tok
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.invite(
|
||||||
|
room_id,
|
||||||
|
inviting_user_id,
|
||||||
|
invited_user_id,
|
||||||
|
tok=inviting_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that the invite receiving user has not automatically joined the room when syncing
|
||||||
|
join_updates, _ = sync_join(self, invited_user_id)
|
||||||
|
self.assertEqual(len(join_updates), 0)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"auto_accept_invites": {
|
||||||
|
"enabled": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_invite_from_remote_user(self) -> None:
|
||||||
|
"""Test that an invite from a remote user results in the invited user
|
||||||
|
automatically joining the room.
|
||||||
|
"""
|
||||||
|
# A remote user who sends the invite
|
||||||
|
remote_server = "otherserver"
|
||||||
|
remote_user = "@otheruser:" + remote_server
|
||||||
|
|
||||||
|
# A local user who creates the room
|
||||||
|
creator_user_id = self.register_user("creator", "pass")
|
||||||
|
creator_user_tok = self.login("creator", "pass")
|
||||||
|
|
||||||
|
# A local user who receives an invite
|
||||||
|
invited_user_id = self.register_user("invitee", "pass")
|
||||||
|
self.login("invitee", "pass")
|
||||||
|
|
||||||
|
room_id = self.helper.create_room_as(
|
||||||
|
room_creator=creator_user_id, tok=creator_user_tok
|
||||||
|
)
|
||||||
|
room_version = self.get_success(self.store.get_room_version(room_id))
|
||||||
|
|
||||||
|
invite_event = event_from_pdu_json(
|
||||||
|
{
|
||||||
|
"type": EventTypes.Member,
|
||||||
|
"content": {"membership": "invite"},
|
||||||
|
"room_id": room_id,
|
||||||
|
"sender": remote_user,
|
||||||
|
"state_key": invited_user_id,
|
||||||
|
"depth": 32,
|
||||||
|
"prev_events": [],
|
||||||
|
"auth_events": [],
|
||||||
|
"origin_server_ts": self.clock.time_msec(),
|
||||||
|
},
|
||||||
|
room_version,
|
||||||
|
)
|
||||||
|
self.get_success(
|
||||||
|
self.handler.on_invite_request(
|
||||||
|
remote_server,
|
||||||
|
invite_event,
|
||||||
|
invite_event.room_version,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that the invite receiving user has automatically joined the room when syncing
|
||||||
|
join_updates, _ = sync_join(self, invited_user_id)
|
||||||
|
self.assertEqual(len(join_updates), 1)
|
||||||
|
|
||||||
|
join_update: JoinedSyncResult = join_updates[0]
|
||||||
|
self.assertEqual(join_update.room_id, room_id)
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
|
[
|
||||||
|
[False, False],
|
||||||
|
[True, True],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"auto_accept_invites": {
|
||||||
|
"enabled": True,
|
||||||
|
"only_for_direct_messages": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_accept_invite_direct_message(
|
||||||
|
self,
|
||||||
|
direct_room: bool,
|
||||||
|
expect_auto_join: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Tests that, if the module is configured to only accept DM invites, invites to DM rooms are still
|
||||||
|
automatically accepted. Otherwise they are rejected.
|
||||||
|
"""
|
||||||
|
# A local user who sends an invite
|
||||||
|
inviting_user_id = self.register_user("inviter", "pass")
|
||||||
|
inviting_user_tok = self.login("inviter", "pass")
|
||||||
|
|
||||||
|
# A local user who receives an invite
|
||||||
|
invited_user_id = self.register_user("invitee", "pass")
|
||||||
|
self.login("invitee", "pass")
|
||||||
|
|
||||||
|
# Create a room and send an invite to the other user
|
||||||
|
room_id = self.helper.create_room_as(
|
||||||
|
inviting_user_id,
|
||||||
|
is_public=False,
|
||||||
|
tok=inviting_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.helper.invite(
|
||||||
|
room_id,
|
||||||
|
inviting_user_id,
|
||||||
|
invited_user_id,
|
||||||
|
tok=inviting_user_tok,
|
||||||
|
extra_data={"is_direct": direct_room},
|
||||||
|
)
|
||||||
|
|
||||||
|
if expect_auto_join:
|
||||||
|
# Check that the invite receiving user has automatically joined the room when syncing
|
||||||
|
join_updates, _ = sync_join(self, invited_user_id)
|
||||||
|
self.assertEqual(len(join_updates), 1)
|
||||||
|
|
||||||
|
join_update: JoinedSyncResult = join_updates[0]
|
||||||
|
self.assertEqual(join_update.room_id, room_id)
|
||||||
|
else:
|
||||||
|
# Check that the invite receiving user has not automatically joined the room when syncing
|
||||||
|
join_updates, _ = sync_join(self, invited_user_id)
|
||||||
|
self.assertEqual(len(join_updates), 0)
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
|
[
|
||||||
|
[False, True],
|
||||||
|
[True, False],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"auto_accept_invites": {
|
||||||
|
"enabled": True,
|
||||||
|
"only_from_local_users": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_accept_invite_local_user(
|
||||||
|
self, remote_inviter: bool, expect_auto_join: bool
|
||||||
|
) -> None:
|
||||||
|
"""Tests that, if the module is configured to only accept invites from local users, invites
|
||||||
|
from local users are still automatically accepted. Otherwise they are rejected.
|
||||||
|
"""
|
||||||
|
# A local user who sends an invite
|
||||||
|
creator_user_id = self.register_user("inviter", "pass")
|
||||||
|
creator_user_tok = self.login("inviter", "pass")
|
||||||
|
|
||||||
|
# A local user who receives an invite
|
||||||
|
invited_user_id = self.register_user("invitee", "pass")
|
||||||
|
self.login("invitee", "pass")
|
||||||
|
|
||||||
|
# Create a room and send an invite to the other user
|
||||||
|
room_id = self.helper.create_room_as(
|
||||||
|
creator_user_id, is_public=False, tok=creator_user_tok
|
||||||
|
)
|
||||||
|
|
||||||
|
if remote_inviter:
|
||||||
|
room_version = self.get_success(self.store.get_room_version(room_id))
|
||||||
|
|
||||||
|
# A remote user who sends the invite
|
||||||
|
remote_server = "otherserver"
|
||||||
|
remote_user = "@otheruser:" + remote_server
|
||||||
|
|
||||||
|
invite_event = event_from_pdu_json(
|
||||||
|
{
|
||||||
|
"type": EventTypes.Member,
|
||||||
|
"content": {"membership": "invite"},
|
||||||
|
"room_id": room_id,
|
||||||
|
"sender": remote_user,
|
||||||
|
"state_key": invited_user_id,
|
||||||
|
"depth": 32,
|
||||||
|
"prev_events": [],
|
||||||
|
"auth_events": [],
|
||||||
|
"origin_server_ts": self.clock.time_msec(),
|
||||||
|
},
|
||||||
|
room_version,
|
||||||
|
)
|
||||||
|
self.get_success(
|
||||||
|
self.handler.on_invite_request(
|
||||||
|
remote_server,
|
||||||
|
invite_event,
|
||||||
|
invite_event.room_version,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.helper.invite(
|
||||||
|
room_id,
|
||||||
|
creator_user_id,
|
||||||
|
invited_user_id,
|
||||||
|
tok=creator_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
if expect_auto_join:
|
||||||
|
# Check that the invite receiving user has automatically joined the room when syncing
|
||||||
|
join_updates, _ = sync_join(self, invited_user_id)
|
||||||
|
self.assertEqual(len(join_updates), 1)
|
||||||
|
|
||||||
|
join_update: JoinedSyncResult = join_updates[0]
|
||||||
|
self.assertEqual(join_update.room_id, room_id)
|
||||||
|
else:
|
||||||
|
# Check that the invite receiving user has not automatically joined the room when syncing
|
||||||
|
join_updates, _ = sync_join(self, invited_user_id)
|
||||||
|
self.assertEqual(len(join_updates), 0)
|
||||||
|
|
||||||
|
|
||||||
|
_request_key = 0
|
||||||
|
|
||||||
|
|
||||||
|
def generate_request_key() -> SyncRequestKey:
|
||||||
|
global _request_key
|
||||||
|
_request_key += 1
|
||||||
|
return ("request_key", _request_key)
|
||||||
|
|
||||||
|
|
||||||
|
def sync_join(
|
||||||
|
testcase: HomeserverTestCase,
|
||||||
|
user_id: str,
|
||||||
|
since_token: Optional[StreamToken] = None,
|
||||||
|
) -> Tuple[List[JoinedSyncResult], StreamToken]:
|
||||||
|
"""Perform a sync request for the given user and return the user join updates
|
||||||
|
they've received, as well as the next_batch token.
|
||||||
|
|
||||||
|
This method assumes testcase.sync_handler points to the homeserver's sync handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
testcase: The testcase that is currently being run.
|
||||||
|
user_id: The ID of the user to generate a sync response for.
|
||||||
|
since_token: An optional token to indicate from at what point to sync from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing a list of join updates, and the sync response's
|
||||||
|
next_batch token.
|
||||||
|
"""
|
||||||
|
requester = create_requester(user_id)
|
||||||
|
sync_config = generate_sync_config(requester.user.to_string())
|
||||||
|
sync_result = testcase.get_success(
|
||||||
|
testcase.hs.get_sync_handler().wait_for_sync_for_user(
|
||||||
|
requester,
|
||||||
|
sync_config,
|
||||||
|
SyncVersion.SYNC_V2,
|
||||||
|
generate_request_key(),
|
||||||
|
since_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return sync_result.joined, sync_result.next_batch
|
||||||
|
|
||||||
|
|
||||||
|
class InviteAutoAccepterInternalTestCase(TestCase):
|
||||||
|
"""
|
||||||
|
Test cases which exercise the internals of the InviteAutoAccepter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.module = create_module()
|
||||||
|
self.user_id = "@peter:test"
|
||||||
|
self.invitee = "@lesley:test"
|
||||||
|
self.remote_invitee = "@thomas:remote"
|
||||||
|
|
||||||
|
# We know our module API is a mock, but mypy doesn't.
|
||||||
|
self.mocked_update_membership: Mock = self.module._api.update_room_membership # type: ignore[assignment]
|
||||||
|
|
||||||
|
async def test_accept_invite_with_failures(self) -> None:
|
||||||
|
"""Tests that receiving an invite for a local user makes the module attempt to
|
||||||
|
make the invitee join the room. This test verifies that it works if the call to
|
||||||
|
update membership returns exceptions before successfully completing and returning an event.
|
||||||
|
"""
|
||||||
|
invite = MockEvent(
|
||||||
|
sender="@inviter:test",
|
||||||
|
state_key="@invitee:test",
|
||||||
|
type="m.room.member",
|
||||||
|
content={"membership": "invite"},
|
||||||
|
)
|
||||||
|
|
||||||
|
join_event = MockEvent(
|
||||||
|
sender="someone",
|
||||||
|
state_key="someone",
|
||||||
|
type="m.room.member",
|
||||||
|
content={"membership": "join"},
|
||||||
|
)
|
||||||
|
# the first two calls raise an exception while the third call is successful
|
||||||
|
self.mocked_update_membership.side_effect = [
|
||||||
|
SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
|
||||||
|
SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
|
||||||
|
make_awaitable(join_event),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Stop mypy from complaining that we give on_new_event a MockEvent rather than an
|
||||||
|
# EventBase.
|
||||||
|
await self.module.on_new_event(event=invite) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
await self.retry_assertions(
|
||||||
|
self.mocked_update_membership,
|
||||||
|
3,
|
||||||
|
sender=invite.state_key,
|
||||||
|
target=invite.state_key,
|
||||||
|
room_id=invite.room_id,
|
||||||
|
new_membership="join",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_accept_invite_failures(self) -> None:
|
||||||
|
"""Tests that receiving an invite for a local user makes the module attempt to
|
||||||
|
make the invitee join the room. This test verifies that if the update_membership call
|
||||||
|
fails consistently, _retry_make_join will break the loop after the set number of retries and
|
||||||
|
execution will continue.
|
||||||
|
"""
|
||||||
|
invite = MockEvent(
|
||||||
|
sender=self.user_id,
|
||||||
|
state_key=self.invitee,
|
||||||
|
type="m.room.member",
|
||||||
|
content={"membership": "invite"},
|
||||||
|
)
|
||||||
|
self.mocked_update_membership.side_effect = SynapseError(
|
||||||
|
HTTPStatus.FORBIDDEN, "Forbidden"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop mypy from complaining that we give on_new_event a MockEvent rather than an
|
||||||
|
# EventBase.
|
||||||
|
await self.module.on_new_event(event=invite) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
await self.retry_assertions(
|
||||||
|
self.mocked_update_membership,
|
||||||
|
5,
|
||||||
|
sender=invite.state_key,
|
||||||
|
target=invite.state_key,
|
||||||
|
room_id=invite.room_id,
|
||||||
|
new_membership="join",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_not_state(self) -> None:
|
||||||
|
"""Tests that receiving an invite that's not a state event does nothing."""
|
||||||
|
invite = MockEvent(
|
||||||
|
sender=self.user_id, type="m.room.member", content={"membership": "invite"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop mypy from complaining that we give on_new_event a MockEvent rather than an
|
||||||
|
# EventBase.
|
||||||
|
await self.module.on_new_event(event=invite) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
self.mocked_update_membership.assert_not_called()
|
||||||
|
|
||||||
|
async def test_not_invite(self) -> None:
|
||||||
|
"""Tests that receiving a membership update that's not an invite does nothing."""
|
||||||
|
invite = MockEvent(
|
||||||
|
sender=self.user_id,
|
||||||
|
state_key=self.user_id,
|
||||||
|
type="m.room.member",
|
||||||
|
content={"membership": "join"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop mypy from complaining that we give on_new_event a MockEvent rather than an
|
||||||
|
# EventBase.
|
||||||
|
await self.module.on_new_event(event=invite) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
self.mocked_update_membership.assert_not_called()
|
||||||
|
|
||||||
|
async def test_not_membership(self) -> None:
|
||||||
|
"""Tests that receiving a state event that's not a membership update does
|
||||||
|
nothing.
|
||||||
|
"""
|
||||||
|
invite = MockEvent(
|
||||||
|
sender=self.user_id,
|
||||||
|
state_key=self.user_id,
|
||||||
|
type="org.matrix.test",
|
||||||
|
content={"foo": "bar"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop mypy from complaining that we give on_new_event a MockEvent rather than an
|
||||||
|
# EventBase.
|
||||||
|
await self.module.on_new_event(event=invite) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
self.mocked_update_membership.assert_not_called()
|
||||||
|
|
||||||
|
def test_config_parse(self) -> None:
|
||||||
|
"""Tests that a correct configuration parses."""
|
||||||
|
config = {
|
||||||
|
"auto_accept_invites": {
|
||||||
|
"enabled": True,
|
||||||
|
"only_for_direct_messages": True,
|
||||||
|
"only_from_local_users": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsed_config = AutoAcceptInvitesConfig()
|
||||||
|
parsed_config.read_config(config)
|
||||||
|
|
||||||
|
self.assertTrue(parsed_config.enabled)
|
||||||
|
self.assertTrue(parsed_config.accept_invites_only_for_direct_messages)
|
||||||
|
self.assertTrue(parsed_config.accept_invites_only_from_local_users)
|
||||||
|
|
||||||
|
def test_runs_on_only_one_worker(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests that the module only runs on the specified worker.
|
||||||
|
"""
|
||||||
|
# By default, we run on the main process...
|
||||||
|
main_module = create_module(
|
||||||
|
config_override={"auto_accept_invites": {"enabled": True}}, worker_name=None
|
||||||
|
)
|
||||||
|
cast(
|
||||||
|
Mock, main_module._api.register_third_party_rules_callbacks
|
||||||
|
).assert_called_once()
|
||||||
|
|
||||||
|
# ...and not on other workers (like synchrotrons)...
|
||||||
|
sync_module = create_module(worker_name="synchrotron42")
|
||||||
|
cast(
|
||||||
|
Mock, sync_module._api.register_third_party_rules_callbacks
|
||||||
|
).assert_not_called()
|
||||||
|
|
||||||
|
# ...unless we configured them to be the designated worker.
|
||||||
|
specified_module = create_module(
|
||||||
|
config_override={
|
||||||
|
"auto_accept_invites": {
|
||||||
|
"enabled": True,
|
||||||
|
"worker_to_run_on": "account_data1",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
worker_name="account_data1",
|
||||||
|
)
|
||||||
|
cast(
|
||||||
|
Mock, specified_module._api.register_third_party_rules_callbacks
|
||||||
|
).assert_called_once()
|
||||||
|
|
||||||
|
async def retry_assertions(
|
||||||
|
self, mock: Mock, call_count: int, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
This is a hacky way to ensure that the assertions are not called before the other coroutine
|
||||||
|
has a chance to call `update_room_membership`. It catches the exception caused by a failure,
|
||||||
|
and sleeps the thread before retrying, up until 5 tries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
call_count: the number of times the mock should have been called
|
||||||
|
mock: the mocked function we want to assert on
|
||||||
|
kwargs: keyword arguments to assert that the mock was called with
|
||||||
|
"""
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < 5:
|
||||||
|
try:
|
||||||
|
# Check that the mocked method is called the expected amount of times and with the right
|
||||||
|
# arguments to attempt to make the user join the room.
|
||||||
|
mock.assert_called_with(**kwargs)
|
||||||
|
self.assertEqual(call_count, mock.call_count)
|
||||||
|
break
|
||||||
|
except AssertionError as e:
|
||||||
|
i += 1
|
||||||
|
if i == 5:
|
||||||
|
# we've used up the tries, force the test to fail as we've already caught the exception
|
||||||
|
self.fail(e)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True)
|
||||||
|
class MockEvent:
|
||||||
|
"""Mocks an event. Only exposes properties the module uses."""
|
||||||
|
|
||||||
|
sender: str
|
||||||
|
type: str
|
||||||
|
content: Dict[str, Any]
|
||||||
|
room_id: str = "!someroom"
|
||||||
|
state_key: Optional[str] = None
|
||||||
|
|
||||||
|
def is_state(self) -> bool:
|
||||||
|
"""Checks if the event is a state event by checking if it has a state key."""
|
||||||
|
return self.state_key is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def membership(self) -> str:
|
||||||
|
"""Extracts the membership from the event. Should only be called on an event
|
||||||
|
that's a membership event, and will raise a KeyError otherwise.
|
||||||
|
"""
|
||||||
|
membership: str = self.content["membership"]
|
||||||
|
return membership
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
TV = TypeVar("TV")
|
||||||
|
|
||||||
|
|
||||||
|
async def make_awaitable(value: T) -> T:
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def make_multiple_awaitable(result: TV) -> Awaitable[TV]:
|
||||||
|
"""
|
||||||
|
Makes an awaitable, suitable for mocking an `async` function.
|
||||||
|
This uses Futures as they can be awaited multiple times so can be returned
|
||||||
|
to multiple callers.
|
||||||
|
"""
|
||||||
|
future: Future[TV] = Future()
|
||||||
|
future.set_result(result)
|
||||||
|
return future
|
||||||
|
|
||||||
|
|
||||||
|
def create_module(
|
||||||
|
config_override: Optional[Dict[str, Any]] = None, worker_name: Optional[str] = None
|
||||||
|
) -> InviteAutoAccepter:
|
||||||
|
# Create a mock based on the ModuleApi spec, but override some mocked functions
|
||||||
|
# because some capabilities are needed for running the tests.
|
||||||
|
module_api = Mock(spec=ModuleApi)
|
||||||
|
module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test"
|
||||||
|
module_api.worker_name = worker_name
|
||||||
|
module_api.sleep.return_value = make_multiple_awaitable(None)
|
||||||
|
|
||||||
|
if config_override is None:
|
||||||
|
config_override = {}
|
||||||
|
|
||||||
|
config = AutoAcceptInvitesConfig()
|
||||||
|
config.read_config(config_override)
|
||||||
|
|
||||||
|
return InviteAutoAccepter(config, module_api)
|
|
@ -170,6 +170,7 @@ class RestHelper:
|
||||||
targ: Optional[str] = None,
|
targ: Optional[str] = None,
|
||||||
expect_code: int = HTTPStatus.OK,
|
expect_code: int = HTTPStatus.OK,
|
||||||
tok: Optional[str] = None,
|
tok: Optional[str] = None,
|
||||||
|
extra_data: Optional[dict] = None,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
return self.change_membership(
|
return self.change_membership(
|
||||||
room=room,
|
room=room,
|
||||||
|
@ -178,6 +179,7 @@ class RestHelper:
|
||||||
tok=tok,
|
tok=tok,
|
||||||
membership=Membership.INVITE,
|
membership=Membership.INVITE,
|
||||||
expect_code=expect_code,
|
expect_code=expect_code,
|
||||||
|
extra_data=extra_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def join(
|
def join(
|
||||||
|
|
|
@ -85,6 +85,7 @@ from twisted.web.server import Request, Site
|
||||||
|
|
||||||
from synapse.config.database import DatabaseConnectionConfig
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from synapse.events.auto_accept_invites import InviteAutoAccepter
|
||||||
from synapse.events.presence_router import load_legacy_presence_router
|
from synapse.events.presence_router import load_legacy_presence_router
|
||||||
from synapse.handlers.auth import load_legacy_password_auth_providers
|
from synapse.handlers.auth import load_legacy_password_auth_providers
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
|
@ -1156,6 +1157,11 @@ def setup_test_homeserver(
|
||||||
for module, module_config in hs.config.modules.loaded_modules:
|
for module, module_config in hs.config.modules.loaded_modules:
|
||||||
module(config=module_config, api=module_api)
|
module(config=module_config, api=module_api)
|
||||||
|
|
||||||
|
if hs.config.auto_accept_invites.enabled:
|
||||||
|
# Start the local auto_accept_invites module.
|
||||||
|
m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
|
||||||
|
logger.info("Loaded local module %s", m)
|
||||||
|
|
||||||
load_legacy_spam_checkers(hs)
|
load_legacy_spam_checkers(hs)
|
||||||
load_legacy_third_party_event_rules(hs)
|
load_legacy_third_party_event_rules(hs)
|
||||||
load_legacy_presence_router(hs)
|
load_legacy_presence_router(hs)
|
||||||
|
|
Loading…
Reference in a new issue