mirror of
https://github.com/element-hq/synapse
synced 2024-06-30 10:13:29 +00:00
Compare commits
5 commits
658b086f07
...
a2cfcecfaf
Author | SHA1 | Date | |
---|---|---|---|
|
a2cfcecfaf | ||
|
b1e98eec0e | ||
|
7c23b77f3e | ||
|
77e1414b08 | ||
|
222f5b120b |
2
changelog.d/17326.feature
Normal file
2
changelog.d/17326.feature
Normal file
|
@ -0,0 +1,2 @@
|
|||
Add initial implementation of delayed events (Futures) as proposed by [MSC4140](https://github.com/matrix-org/matrix-spec-proposals/pull/4140).
|
||||
|
|
@ -1662,6 +1662,19 @@ rc_registration_token_validity:
|
|||
burst_count: 6
|
||||
```
|
||||
---
|
||||
### `rc_future_token_validity`
|
||||
|
||||
This option checks the validity of future tokens that ratelimits requests based on
|
||||
the client's IP address.
|
||||
Defaults to `per_second: 0.1`, `burst_count: 5`.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
rc_future_token_validity:
|
||||
per_second: 0.3
|
||||
burst_count: 6
|
||||
```
|
||||
---
|
||||
### `rc_login`
|
||||
|
||||
This option specifies several limits for login:
|
||||
|
|
|
@ -441,6 +441,24 @@ class ExperimentalConfig(Config):
|
|||
"msc3916_authenticated_media_enabled", False
|
||||
)
|
||||
|
||||
# MSC4140: Delayed events (Futures)
|
||||
# The maximum allowed delay for timeout futures.
|
||||
try:
|
||||
self.msc4140_max_future_timeout_duration = int(
|
||||
experimental["msc4140_max_future_timeout_duration"]
|
||||
)
|
||||
if self.msc4140_max_future_timeout_duration < 0:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise ConfigError(
|
||||
"Timeout duration must be a positive integer",
|
||||
("experimental", "msc4140_max_future_timeout_duration"),
|
||||
)
|
||||
except KeyError:
|
||||
self.msc4140_max_future_timeout_duration = (
|
||||
10 * 365 * 24 * 60 * 60 * 1000
|
||||
) # 10 years
|
||||
|
||||
# MSC4151: Report room API (Client-Server API)
|
||||
self.msc4151_enabled: bool = experimental.get("msc4151_enabled", False)
|
||||
|
||||
|
|
|
@ -124,6 +124,13 @@ class RatelimitConfig(Config):
|
|||
defaults={"per_second": 0.1, "burst_count": 5},
|
||||
)
|
||||
|
||||
# Ratelimit requests to send/cancel/refresh futures (MSC4140):
|
||||
self.rc_future_token_validity = RatelimitSettings.parse(
|
||||
config,
|
||||
"rc_future_token_validity",
|
||||
defaults={"per_second": 0.1, "burst_count": 5},
|
||||
)
|
||||
|
||||
# It is reasonable to login with a bunch of devices at once (i.e. when
|
||||
# setting up an account), but it is *not* valid to continually be
|
||||
# logging into new devices.
|
||||
|
|
329
synapse/handlers/futures.py
Normal file
329
synapse/handlers/futures.py
Normal file
|
@ -0,0 +1,329 @@
|
|||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
#
|
||||
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from twisted.internet.interfaces import IDelayedCall
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import Codes, NotFoundError, ShadowBanError, SynapseError
|
||||
from synapse.logging.opentracing import set_tag
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.databases.main.futures import (
|
||||
EventType,
|
||||
FutureID,
|
||||
FutureTokenType,
|
||||
StateKey,
|
||||
Timeout,
|
||||
Timestamp,
|
||||
)
|
||||
from synapse.types import JsonDict, Requester, RoomID, UserID, create_requester
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FuturesHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self.config = hs.config
|
||||
self.clock = hs.get_clock()
|
||||
self.request_ratelimiter = hs.get_request_ratelimiter()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
|
||||
self._hostname = hs.hostname
|
||||
|
||||
self._futures: Dict[FutureID, IDelayedCall] = {}
|
||||
|
||||
async def _schedule_db_futures() -> None:
|
||||
all_future_timestamps = await self.store.get_all_future_timestamps()
|
||||
|
||||
# Get the time after awaiting to increase accuracy.
|
||||
# For even more accuracy, could get the time on each loop iteration.
|
||||
current_ts = self._get_current_ts()
|
||||
|
||||
for future_id, ts in all_future_timestamps:
|
||||
timeout = _get_timeout_between(current_ts, ts)
|
||||
if timeout > 0:
|
||||
self._schedule_future(future_id, timeout)
|
||||
else:
|
||||
logger.info("Scheduling timeout for future_id %s now", future_id)
|
||||
run_as_background_process(
|
||||
"_send_future",
|
||||
self._send_future,
|
||||
future_id,
|
||||
)
|
||||
|
||||
self._initialized_from_db = run_as_background_process(
|
||||
"_schedule_db_futures", _schedule_db_futures
|
||||
)
|
||||
|
||||
async def add_future(
|
||||
self,
|
||||
requester: Requester,
|
||||
*,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: Optional[str],
|
||||
origin_server_ts: Optional[int],
|
||||
content: JsonDict,
|
||||
timeout: Optional[int],
|
||||
group_id: Optional[str],
|
||||
is_custom_endpoint: bool = False, # TODO: Remove this eventually
|
||||
) -> JsonDict:
|
||||
"""Creates a new future, and if it is a timeout future, schedules it to be sent.
|
||||
|
||||
Params:
|
||||
requester: The initial requester of the future.
|
||||
room_id: The room where the future event should be sent.
|
||||
event_type: The event type of the future event.
|
||||
state_key: The state key of the future event, or None if it is not a state event.
|
||||
origin_server_ts: The custom timestamp to send the future event with.
|
||||
If None, the timestamp will be the actual time when the future is sent.
|
||||
content: The content of the future event.
|
||||
timeout: How long (in milliseconds) to wait before automatically sending the future.
|
||||
If None, the future will never be automatically sent.
|
||||
group_id: The future group this future belongs to.
|
||||
If None, the future will be added to a new group with an auto-generated ID.
|
||||
|
||||
Returns:
|
||||
A mapping of tokens that can be used in requests to send, cancel, or refresh the future.
|
||||
"""
|
||||
if timeout is not None:
|
||||
max_timeout = self.config.experimental.msc4140_max_future_timeout_duration
|
||||
if timeout > max_timeout:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
f"'{'future_timeout' if not is_custom_endpoint else 'timeout'}' is too large",
|
||||
Codes.INVALID_PARAM,
|
||||
{
|
||||
"max_timeout_duration": max_timeout,
|
||||
},
|
||||
)
|
||||
else:
|
||||
if group_id is None:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"At least one of 'future_timeout' and 'future_group_id' is required",
|
||||
Codes.MISSING_PARAM,
|
||||
)
|
||||
|
||||
await self.request_ratelimiter.ratelimit(requester)
|
||||
await self._initialized_from_db
|
||||
|
||||
if (
|
||||
timeout is None
|
||||
and group_id is not None
|
||||
and not await self.store.has_group_id(requester.user, group_id)
|
||||
):
|
||||
raise SynapseError(
|
||||
HTTPStatus.CONFLICT,
|
||||
# TODO: Think of a better wording for this
|
||||
"The given 'future_group_id' does not exist and cannot be created with an action future",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
(
|
||||
future_id,
|
||||
group_id,
|
||||
send_token,
|
||||
cancel_token,
|
||||
refresh_token,
|
||||
) = await self.store.add_future(
|
||||
user_id=requester.user,
|
||||
group_id=group_id,
|
||||
room_id=RoomID.from_string(room_id),
|
||||
event_type=event_type,
|
||||
state_key=state_key,
|
||||
origin_server_ts=origin_server_ts,
|
||||
content=content,
|
||||
timeout=timeout,
|
||||
timestamp=timeout + self.clock.time_msec() if timeout is not None else None,
|
||||
)
|
||||
|
||||
if timeout is not None:
|
||||
self._schedule_future(future_id, Timeout(timeout))
|
||||
|
||||
ret = {
|
||||
"future_group_id": group_id,
|
||||
"send_token": send_token,
|
||||
"cancel_token": cancel_token,
|
||||
}
|
||||
# TODO: type hint for non-None refresh_token return value when timeout argument is non-None
|
||||
if refresh_token is not None:
|
||||
ret["refresh_token"] = refresh_token
|
||||
return ret
|
||||
|
||||
async def use_future_token(self, token: str) -> None:
|
||||
"""Executes the appropriate action for the given token.
|
||||
|
||||
Params:
|
||||
token: The token of the future to act on.
|
||||
"""
|
||||
await self._initialized_from_db
|
||||
|
||||
future_id, future_token_type = await self.store.get_future_by_token(token)
|
||||
if future_token_type == FutureTokenType.SEND:
|
||||
await self._send_future(future_id)
|
||||
elif future_token_type == FutureTokenType.CANCEL:
|
||||
await self._cancel_future(future_id)
|
||||
elif future_token_type == FutureTokenType.REFRESH:
|
||||
await self._refresh_future(future_id)
|
||||
|
||||
async def _send_future(self, future_id: FutureID) -> None:
|
||||
await self._initialized_from_db
|
||||
|
||||
args = await self.store.pop_future_event(future_id)
|
||||
|
||||
self._unschedule_future(future_id)
|
||||
await self._send_event(*args)
|
||||
|
||||
async def _send_future_on_timeout(self, future_id: FutureID) -> None:
|
||||
await self._initialized_from_db
|
||||
|
||||
try:
|
||||
args = await self.store.pop_future_event(future_id)
|
||||
except NotFoundError:
|
||||
logger.warning(
|
||||
"future_id %s missing from DB on timeout, so it must have been cancelled/sent",
|
||||
future_id,
|
||||
)
|
||||
else:
|
||||
await self._send_event(*args)
|
||||
|
||||
async def _cancel_future(self, future_id: FutureID) -> None:
|
||||
await self._initialized_from_db
|
||||
|
||||
await self.store.remove_future(future_id)
|
||||
|
||||
self._unschedule_future(future_id)
|
||||
|
||||
async def _refresh_future(self, future_id: FutureID) -> None:
|
||||
await self._initialized_from_db
|
||||
|
||||
timeout = await self.store.update_future_timestamp(
|
||||
future_id,
|
||||
self._get_current_ts(),
|
||||
)
|
||||
|
||||
self._unschedule_future(future_id)
|
||||
self._schedule_future(future_id, timeout)
|
||||
|
||||
def _schedule_future(
|
||||
self,
|
||||
future_id: FutureID,
|
||||
timeout: Timeout,
|
||||
) -> None:
|
||||
"""NOTE: Should not be called with a future_id that isn't in the DB, or with a negative timeout."""
|
||||
delay_sec = timeout / 1000
|
||||
|
||||
logger.info(
|
||||
"Scheduling timeout for future_id %d in %.3fs", future_id, delay_sec
|
||||
)
|
||||
|
||||
self._futures[future_id] = self.clock.call_later(
|
||||
delay_sec,
|
||||
run_as_background_process,
|
||||
"_send_future_on_timeout",
|
||||
self._send_future_on_timeout,
|
||||
future_id,
|
||||
)
|
||||
|
||||
def _unschedule_future(self, future_id: FutureID) -> None:
|
||||
try:
|
||||
delayed_call = self._futures.pop(future_id)
|
||||
self.clock.cancel_call_later(delayed_call)
|
||||
except KeyError:
|
||||
logger.debug("future_id %s was not mapped to a delayed call", future_id)
|
||||
|
||||
async def get_all_futures_for_user(self, requester: Requester) -> List[JsonDict]:
|
||||
"""Return all pending futures requested by the given user."""
|
||||
await self.request_ratelimiter.ratelimit(requester)
|
||||
await self._initialized_from_db
|
||||
return await self.store.get_all_futures_for_user(requester.user)
|
||||
|
||||
# TODO: Consider getting a list of all timeout futures that were in this one's group, and cancel their delayed calls
|
||||
async def _send_event(
|
||||
self,
|
||||
user_localpart: str,
|
||||
room_id: RoomID,
|
||||
event_type: EventType,
|
||||
state_key: Optional[StateKey],
|
||||
origin_server_ts: Optional[Timestamp],
|
||||
content: JsonDict,
|
||||
# TODO: support guests
|
||||
# is_guest: bool,
|
||||
txn_id: Optional[str] = None,
|
||||
) -> None:
|
||||
user_id = UserID(user_localpart, self._hostname)
|
||||
requester = create_requester(
|
||||
user_id,
|
||||
# is_guest=is_guest,
|
||||
)
|
||||
|
||||
try:
|
||||
if state_key is not None and event_type == EventTypes.Member:
|
||||
membership = content.get("membership", None)
|
||||
event_id, _ = await self.room_member_handler.update_membership(
|
||||
requester,
|
||||
target=UserID.from_string(state_key),
|
||||
room_id=str(room_id),
|
||||
action=membership,
|
||||
content=content,
|
||||
origin_server_ts=origin_server_ts,
|
||||
)
|
||||
else:
|
||||
event_dict: JsonDict = {
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
"room_id": room_id.to_string(),
|
||||
"sender": user_id.to_string(),
|
||||
}
|
||||
|
||||
if state_key is not None:
|
||||
event_dict["state_key"] = state_key
|
||||
|
||||
if origin_server_ts is not None:
|
||||
event_dict["origin_server_ts"] = origin_server_ts
|
||||
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
event_dict,
|
||||
txn_id=txn_id,
|
||||
)
|
||||
event_id = event.event_id
|
||||
except ShadowBanError:
|
||||
event_id = "$" + random_string(43)
|
||||
|
||||
set_tag("event_id", event_id)
|
||||
|
||||
def _get_current_ts(self) -> Timestamp:
|
||||
return Timestamp(self.clock.time_msec())
|
||||
|
||||
|
||||
def _get_timeout_between(from_ts: Timestamp, to_ts: Timestamp) -> Timeout:
|
||||
return Timeout(to_ts - from_ts)
|
|
@ -34,6 +34,7 @@ from synapse.rest.client import (
|
|||
directory,
|
||||
events,
|
||||
filter,
|
||||
futures,
|
||||
initial_sync,
|
||||
keys,
|
||||
knock,
|
||||
|
@ -103,6 +104,7 @@ class ClientRestResource(JsonResource):
|
|||
events.register_servlets(hs, client_resource)
|
||||
|
||||
room.register_servlets(hs, client_resource)
|
||||
futures.register_servlets(hs, client_resource)
|
||||
login.register_servlets(hs, client_resource)
|
||||
profile.register_servlets(hs, client_resource)
|
||||
presence.register_servlets(hs, client_resource)
|
||||
|
|
90
synapse/rest/client/futures.py
Normal file
90
synapse/rest/client/futures.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
#
|
||||
|
||||
""" This module contains REST servlets to do with futures: /future/<paths> """
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: Needs unit testing
|
||||
class FuturesTokenServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
r"/org\.matrix\.msc4140/future/(?P<token>[^/]*)$", releases=(), v1=False
|
||||
)
|
||||
CATEGORY = "Future management requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.futures_handler = hs.get_futures_handler()
|
||||
|
||||
self._ratelimiter = Ratelimiter(
|
||||
store=hs.get_datastores().main,
|
||||
clock=hs.get_clock(),
|
||||
cfg=hs.config.ratelimiting.rc_future_token_validity,
|
||||
)
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, token: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
# Ratelimit by address, since this is an unauthenticated request
|
||||
ratelimit_key = (request.getClientAddress().host,)
|
||||
# Check if we should be ratelimited due to too many previous failed attempts
|
||||
await self._ratelimiter.ratelimit(None, ratelimit_key, update=False)
|
||||
|
||||
try:
|
||||
await self.futures_handler.use_future_token(token)
|
||||
return 200, {}
|
||||
except NotFoundError as e:
|
||||
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
|
||||
await self._ratelimiter.can_do_action(None, ratelimit_key)
|
||||
# TODO: Decide if the error code should be left at 404, instead of 410 as per the MSC
|
||||
e.code = 410
|
||||
raise
|
||||
|
||||
|
||||
# TODO: Needs unit testing
|
||||
class FuturesServlet(RestServlet):
|
||||
PATTERNS = client_patterns(r"/org\.matrix\.msc4140/future$", releases=(), v1=False)
|
||||
CATEGORY = "Future management requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.futures_handler = hs.get_futures_handler()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, List[JsonDict]]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
return 200, await self.futures_handler.get_all_futures_for_user(requester)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
FuturesTokenServlet(hs).register(http_server)
|
||||
FuturesServlet(hs).register(http_server)
|
|
@ -2,7 +2,7 @@
|
|||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
# Copyright (C) 2023-2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
|
@ -36,6 +36,7 @@ from synapse.api.constants import Direction, EventTypes, Membership
|
|||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
InvalidAPICallError,
|
||||
InvalidClientCredentialsError,
|
||||
MissingClientTokenError,
|
||||
ShadowBanError,
|
||||
|
@ -64,7 +65,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
Requester,
|
||||
StrCollection,
|
||||
StreamToken,
|
||||
ThirdPartyInstanceID,
|
||||
UserID,
|
||||
)
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util.cancellation import cancellable
|
||||
from synapse.util.stringutils import parse_and_validate_server_name, random_string
|
||||
|
@ -409,6 +417,261 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
)
|
||||
|
||||
|
||||
# TODO: Needs unit testing
|
||||
class RoomStateFutureRestServlet(RestServlet):
|
||||
CATEGORY = "Future sending requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.futures_handler = hs.get_futures_handler()
|
||||
|
||||
def register(self, http_server: HttpServer) -> None:
|
||||
# /org.matrix.msc4140/rooms/$roomid/state_future/$eventtype
|
||||
no_state_key = r"/org\.matrix\.msc4140/rooms/(?P<room_id>[^/]*)/state_future/(?P<event_type>[^/]*)$"
|
||||
|
||||
# /org.matrix.msc4140/rooms/$roomid/state_future/$eventtype/$statekey
|
||||
state_key = (
|
||||
r"/org\.matrix\.msc4140/rooms/(?P<room_id>[^/]*)/state_future/"
|
||||
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)"
|
||||
)
|
||||
http_server.register_paths(
|
||||
"PUT",
|
||||
client_patterns(state_key, releases=(), v1=False),
|
||||
self.on_PUT,
|
||||
self.__class__.__name__,
|
||||
)
|
||||
http_server.register_paths(
|
||||
"PUT",
|
||||
client_patterns(no_state_key, releases=(), v1=False),
|
||||
self.on_PUT_no_state_key,
|
||||
self.__class__.__name__,
|
||||
)
|
||||
|
||||
def on_PUT_no_state_key(
|
||||
self, request: SynapseRequest, room_id: str, event_type: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
return self.on_PUT(request, room_id, event_type, "")
|
||||
|
||||
async def on_PUT(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: str,
|
||||
txn_id: Optional[str] = None,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
if txn_id:
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
ret = await self.futures_handler.add_future(
|
||||
requester,
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
state_key=state_key,
|
||||
origin_server_ts=(
|
||||
parse_integer(request, "ts") if requester.app_service else None
|
||||
),
|
||||
content=parse_json_object_from_request(request),
|
||||
timeout=parse_integer(request, "future_timeout"),
|
||||
group_id=parse_string(request, "future_group_id"),
|
||||
)
|
||||
|
||||
for k, v in ret.items():
|
||||
set_tag(k, v)
|
||||
return 200, ret
|
||||
|
||||
|
||||
# TODO: Needs unit testing
|
||||
class RoomSendFutureRestServlet(TransactionRestServlet):
|
||||
CATEGORY = "Future sending requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.auth = hs.get_auth()
|
||||
self.futures_handler = hs.get_futures_handler()
|
||||
|
||||
def register(self, http_server: HttpServer) -> None:
|
||||
# /org.matrix.msc4140/rooms/$roomid/send_future/$event_type[/$txn_id]
|
||||
PATTERNS = r"/org\.matrix\.msc4140/rooms/(?P<room_id>[^/]*)/send_future/(?P<event_type>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server, releases=(), v1=False)
|
||||
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
ret = await self.futures_handler.add_future(
|
||||
requester,
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
state_key=None,
|
||||
origin_server_ts=(
|
||||
parse_integer(request, "ts") if requester.app_service else None
|
||||
),
|
||||
content=parse_json_object_from_request(request),
|
||||
timeout=parse_integer(request, "future_timeout"),
|
||||
group_id=parse_string(request, "future_group_id"),
|
||||
)
|
||||
|
||||
for k, v in ret.items():
|
||||
set_tag(k, v)
|
||||
return 200, ret
|
||||
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
return await self._do(request, requester, room_id, event_type)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request,
|
||||
requester,
|
||||
self._do,
|
||||
request,
|
||||
requester,
|
||||
room_id,
|
||||
event_type,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Remove in favour of the Room{Send,State}FutureRestServlets. Otherwise, this needs unit testing
|
||||
class RoomFutureRestServlet(TransactionRestServlet):
|
||||
CATEGORY = "Future sending requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.auth = hs.get_auth()
|
||||
self.futures_handler = hs.get_futures_handler()
|
||||
|
||||
def register(self, http_server: HttpServer) -> None:
|
||||
# NOTE: Difference from MSC = remove /send part of path, to avoid ambiguity with regular event sending
|
||||
# /org.matrix.msc4140/rooms/$roomid/future[/$txn_id]
|
||||
PATTERNS = r"/org\.matrix\.msc4140/rooms/(?P<room_id>[^/]*)/future"
|
||||
register_txn_path(self, PATTERNS, http_server, releases=(), v1=False)
|
||||
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
try:
|
||||
timeout = int(body["timeout"])
|
||||
except KeyError:
|
||||
raise SynapseError(400, "'timeout' is missing", Codes.MISSING_PARAM)
|
||||
except Exception:
|
||||
raise SynapseError(400, "'timeout' is not an integer", Codes.INVALID_PARAM)
|
||||
|
||||
try:
|
||||
timeout_item: JsonDict = {
|
||||
str(k): v for k, v in body["send_on_timeout"].items()
|
||||
}
|
||||
except KeyError:
|
||||
raise SynapseError(400, "'send_on_timeout' is missing", Codes.MISSING_PARAM)
|
||||
except Exception:
|
||||
raise InvalidAPICallError("'send_on_timeout' is not valid JSON")
|
||||
|
||||
try:
|
||||
send_on_action: Dict[str, JsonDict] = {
|
||||
str(action_name): {str(k): v for k, v in action_item.items()}
|
||||
for action_name, action_item in body.get("send_on_action", {}).items()
|
||||
}
|
||||
except Exception:
|
||||
raise InvalidAPICallError(
|
||||
"'send_on_action' is not a JSON mapping of action names to event content"
|
||||
)
|
||||
|
||||
# TODO: Should action events be able to have a different ts value than the timeout event?
|
||||
origin_server_ts = (
|
||||
parse_integer(request, "ts") if requester.app_service else None
|
||||
)
|
||||
|
||||
# TODO: send_now. But first validate send_on_timeout & send_on_action
|
||||
|
||||
send_on_timeout_resp = await self.futures_handler.add_future(
|
||||
requester,
|
||||
room_id=room_id,
|
||||
event_type=str(timeout_item["type"]),
|
||||
state_key=(
|
||||
str(timeout_item["state_key"]) if "state_key" in timeout_item else None
|
||||
),
|
||||
origin_server_ts=origin_server_ts,
|
||||
content=timeout_item["content"],
|
||||
timeout=timeout,
|
||||
group_id=None,
|
||||
is_custom_endpoint=True,
|
||||
)
|
||||
group_id = str(send_on_timeout_resp.pop("future_group_id"))
|
||||
ret = {"send_on_timeout": send_on_timeout_resp}
|
||||
|
||||
if send_on_action:
|
||||
send_on_action_resp: Dict[str, Dict[str, JsonDict]] = {}
|
||||
for action_name, action_item in send_on_action.items():
|
||||
action_item_resp = await self.futures_handler.add_future(
|
||||
requester,
|
||||
room_id=room_id,
|
||||
event_type=str(action_item["type"]),
|
||||
state_key=(
|
||||
str(action_item["state_key"])
|
||||
if "state_key" in action_item
|
||||
else None
|
||||
),
|
||||
origin_server_ts=origin_server_ts,
|
||||
content=action_item["content"],
|
||||
timeout=None,
|
||||
group_id=group_id,
|
||||
is_custom_endpoint=True,
|
||||
)
|
||||
try:
|
||||
del action_item_resp["future_group_id"]
|
||||
except KeyError:
|
||||
pass
|
||||
send_on_action_resp[action_name] = action_item_resp
|
||||
|
||||
ret["send_on_action"] = send_on_action_resp
|
||||
|
||||
return 200, ret
|
||||
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
return await self._do(request, requester, room_id)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str, txn_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request,
|
||||
requester,
|
||||
self._do,
|
||||
request,
|
||||
requester,
|
||||
room_id,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Needs unit testing for room ID + alias joins
|
||||
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
|
||||
CATEGORY = "Event sending requests"
|
||||
|
@ -1344,6 +1607,9 @@ def register_txn_path(
|
|||
servlet: RestServlet,
|
||||
regex_string: str,
|
||||
http_server: HttpServer,
|
||||
releases: StrCollection = ("r0", "v3"),
|
||||
unstable: bool = True,
|
||||
v1: bool = True,
|
||||
) -> None:
|
||||
"""Registers a transaction-based path.
|
||||
|
||||
|
@ -1362,13 +1628,13 @@ def register_txn_path(
|
|||
raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path")
|
||||
http_server.register_paths(
|
||||
"POST",
|
||||
client_patterns(regex_string + "$", v1=True),
|
||||
client_patterns(regex_string + "$", releases, unstable, v1),
|
||||
on_POST,
|
||||
servlet.__class__.__name__,
|
||||
)
|
||||
http_server.register_paths(
|
||||
"PUT",
|
||||
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
|
||||
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", releases, unstable, v1),
|
||||
on_PUT,
|
||||
servlet.__class__.__name__,
|
||||
)
|
||||
|
@ -1503,12 +1769,14 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
|
|||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
RoomStateEventRestServlet(hs).register(http_server)
|
||||
RoomStateFutureRestServlet(hs).register(http_server)
|
||||
RoomMemberListRestServlet(hs).register(http_server)
|
||||
JoinedRoomMemberListRestServlet(hs).register(http_server)
|
||||
RoomMessageListRestServlet(hs).register(http_server)
|
||||
JoinRoomAliasServlet(hs).register(http_server)
|
||||
RoomMembershipRestServlet(hs).register(http_server)
|
||||
RoomSendEventRestServlet(hs).register(http_server)
|
||||
RoomSendFutureRestServlet(hs).register(http_server)
|
||||
PublicRoomListRestServlet(hs).register(http_server)
|
||||
RoomStateRestServlet(hs).register(http_server)
|
||||
RoomRedactEventRestServlet(hs).register(http_server)
|
||||
|
@ -1523,6 +1791,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
|||
SearchRestServlet(hs).register(http_server)
|
||||
RoomCreateRestServlet(hs).register(http_server)
|
||||
TimestampLookupRestServlet(hs).register(http_server)
|
||||
RoomFutureRestServlet(hs).register(http_server)
|
||||
|
||||
# Some servlets only get registered for the main process.
|
||||
if hs.config.worker.worker_app is None:
|
||||
|
|
|
@ -149,6 +149,8 @@ class VersionsRestServlet(RestServlet):
|
|||
is not None
|
||||
)
|
||||
),
|
||||
# MSC4140: Delayed events (Futures)
|
||||
"org.matrix.msc4140": True,
|
||||
# MSC4151: Report room API (Client-Server API)
|
||||
"org.matrix.msc4151": self.config.experimental.msc4151_enabled,
|
||||
},
|
||||
|
|
|
@ -76,6 +76,7 @@ from synapse.handlers.event_auth import EventAuthHandler
|
|||
from synapse.handlers.events import EventHandler, EventStreamHandler
|
||||
from synapse.handlers.federation import FederationHandler
|
||||
from synapse.handlers.federation_event import FederationEventHandler
|
||||
from synapse.handlers.futures import FuturesHandler
|
||||
from synapse.handlers.identity import IdentityHandler
|
||||
from synapse.handlers.initial_sync import InitialSyncHandler
|
||||
from synapse.handlers.message import EventCreationHandler, MessageHandler
|
||||
|
@ -248,6 +249,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
"account_validity",
|
||||
"auth",
|
||||
"deactivate_account",
|
||||
"futures",
|
||||
"message",
|
||||
"pagination",
|
||||
"profile",
|
||||
|
@ -936,3 +938,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
@cache_in_self
|
||||
def get_task_scheduler(self) -> TaskScheduler:
|
||||
return TaskScheduler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_futures_handler(self) -> FuturesHandler:
|
||||
return FuturesHandler(self)
|
||||
|
|
|
@ -54,6 +54,7 @@ from .events_bg_updates import EventsBackgroundUpdatesStore
|
|||
from .events_forward_extremities import EventForwardExtremitiesStore
|
||||
from .experimental_features import ExperimentalFeaturesStore
|
||||
from .filtering import FilteringWorkerStore
|
||||
from .futures import FuturesStore
|
||||
from .keys import KeyStore
|
||||
from .lock import LockStore
|
||||
from .media_repository import MediaRepositoryStore
|
||||
|
@ -156,6 +157,7 @@ class DataStore(
|
|||
LockStore,
|
||||
SessionStore,
|
||||
TaskSchedulerWorkerStore,
|
||||
FuturesStore,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
474
synapse/storage/databases/main/futures.py
Normal file
474
synapse/storage/databases/main/futures.py
Normal file
|
@ -0,0 +1,474 @@
|
|||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2024 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
#
|
||||
|
||||
from binascii import crc32
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NewType,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from synapse.api.errors import NotFoundError, StoreError, SynapseError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.types import JsonDict, RoomID, UserID
|
||||
from synapse.util import json_encoder, stringutils as stringutils
|
||||
from synapse.util.stringutils import base62_encode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class FutureTokenType(Enum):
|
||||
SEND = 1
|
||||
CANCEL = 2
|
||||
REFRESH = 4
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self == FutureTokenType.SEND:
|
||||
return "is_send"
|
||||
elif self == FutureTokenType.CANCEL:
|
||||
return "is_cancel"
|
||||
else:
|
||||
return "is_refresh"
|
||||
|
||||
|
||||
FutureToken = NewType("FutureToken", str)
|
||||
GroupID = NewType("GroupID", str)
|
||||
EventType = NewType("EventType", str)
|
||||
StateKey = NewType("StateKey", str)
|
||||
|
||||
FutureID = NewType("FutureID", int)
|
||||
Timeout = NewType("Timeout", int)
|
||||
Timestamp = NewType("Timestamp", int)
|
||||
|
||||
AddFutureReturn = Tuple[
|
||||
FutureID,
|
||||
GroupID,
|
||||
FutureToken,
|
||||
FutureToken,
|
||||
Optional[FutureToken],
|
||||
]
|
||||
|
||||
|
||||
# TODO: Try to support workers
|
||||
class FuturesStore(SQLBaseStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
async def add_future(
|
||||
self,
|
||||
*,
|
||||
user_id: UserID,
|
||||
group_id: Optional[str],
|
||||
room_id: RoomID,
|
||||
event_type: str,
|
||||
state_key: Optional[str],
|
||||
origin_server_ts: Optional[int],
|
||||
content: JsonDict,
|
||||
timeout: Optional[int],
|
||||
timestamp: Optional[int],
|
||||
) -> AddFutureReturn:
|
||||
"""Inserts a new future in the DB."""
|
||||
user_localpart = user_id.localpart
|
||||
|
||||
def add_future_txn(txn: LoggingTransaction) -> AddFutureReturn:
|
||||
T = TypeVar("T", bound=str)
|
||||
|
||||
def insert_and_get_unique_colval(
|
||||
table: str,
|
||||
values: Dict[str, Any],
|
||||
column: str,
|
||||
colval_generator: Callable[[], T],
|
||||
) -> T:
|
||||
attempts_remaining = 10
|
||||
while True:
|
||||
colval = colval_generator()
|
||||
try:
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table,
|
||||
values={
|
||||
**values,
|
||||
column: colval,
|
||||
},
|
||||
)
|
||||
return colval
|
||||
# TODO: Handle only the error type for DB key collisions
|
||||
except Exception:
|
||||
if attempts_remaining > 0:
|
||||
attempts_remaining -= 1
|
||||
else:
|
||||
raise SynapseError(
|
||||
500,
|
||||
f"Couldn't generate a unique value for column {column} in table {table}",
|
||||
)
|
||||
|
||||
if group_id is None:
|
||||
group_id_final = insert_and_get_unique_colval(
|
||||
"future_groups",
|
||||
{
|
||||
"user_localpart": user_localpart,
|
||||
},
|
||||
"group_id",
|
||||
_generate_group_id,
|
||||
)
|
||||
else:
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO future_groups (user_localpart, group_id)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT DO NOTHING
|
||||
""",
|
||||
(user_localpart, group_id),
|
||||
)
|
||||
group_id_final = GroupID(group_id)
|
||||
|
||||
txn.execute(
|
||||
"""
|
||||
INSERT INTO futures (
|
||||
user_localpart, group_id,
|
||||
room_id, event_type, state_key, origin_server_ts,
|
||||
content
|
||||
) VALUES (
|
||||
?, ?,
|
||||
?, ?, ?, ?,
|
||||
?
|
||||
)
|
||||
RETURNING future_id
|
||||
""",
|
||||
(
|
||||
user_localpart,
|
||||
group_id_final,
|
||||
room_id.to_string(),
|
||||
event_type,
|
||||
state_key,
|
||||
origin_server_ts,
|
||||
json_encoder.encode(content),
|
||||
),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
assert row is not None
|
||||
future_id = FutureID(row[0])
|
||||
|
||||
def insert_and_get_future_token(token_type: FutureTokenType) -> FutureToken:
|
||||
return insert_and_get_unique_colval(
|
||||
"future_tokens",
|
||||
{
|
||||
"future_id": future_id,
|
||||
str(token_type): True,
|
||||
},
|
||||
"token",
|
||||
_generate_future_token,
|
||||
)
|
||||
|
||||
send_token = insert_and_get_future_token(FutureTokenType.SEND)
|
||||
cancel_token = insert_and_get_future_token(FutureTokenType.CANCEL)
|
||||
if timeout is not None:
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="future_timeouts",
|
||||
values={
|
||||
"future_id": future_id,
|
||||
"timeout": timeout,
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
refresh_token = insert_and_get_future_token(FutureTokenType.REFRESH)
|
||||
else:
|
||||
refresh_token = None
|
||||
|
||||
return (
|
||||
future_id,
|
||||
group_id_final,
|
||||
send_token,
|
||||
cancel_token,
|
||||
refresh_token,
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction("add_future", add_future_txn)
|
||||
|
||||
async def update_future_timestamp(
|
||||
self,
|
||||
future_id: FutureID,
|
||||
current_ts: Timestamp,
|
||||
) -> Timeout:
|
||||
"""Updates the timestamp of the timeout future for the given future_id.
|
||||
|
||||
Params:
|
||||
future_id: The ID of the future timeout to update.
|
||||
current_ts: The current time to which the future's timestamp will be set relative to.
|
||||
|
||||
Returns: The timeout value for the future.
|
||||
|
||||
Raises:
|
||||
NotFoundError if there is no timeout future for the given future_id.
|
||||
"""
|
||||
rows = await self.db_pool.execute(
|
||||
"update_future_timestamp",
|
||||
"""
|
||||
UPDATE future_timeouts
|
||||
SET timestamp = ? + timeout
|
||||
WHERE future_id = ?
|
||||
RETURNING timeout
|
||||
""",
|
||||
current_ts,
|
||||
future_id,
|
||||
)
|
||||
if len(rows) == 0 or len(rows[0]) == 0:
|
||||
raise NotFoundError
|
||||
return Timeout(rows[0][0])
|
||||
|
||||
async def has_group_id(
|
||||
self,
|
||||
user_id: UserID,
|
||||
group_id: str,
|
||||
) -> bool:
|
||||
"""Returns whether a future group exists for the given group_id."""
|
||||
count: int = await self.db_pool.simple_select_one_onecol(
|
||||
table="future_groups",
|
||||
keyvalues={"user_localpart": user_id.localpart, "group_id": group_id},
|
||||
retcol="COUNT(1)",
|
||||
desc="has_group_id",
|
||||
)
|
||||
return count > 0
|
||||
|
||||
async def get_future_by_token(
|
||||
self,
|
||||
token: str,
|
||||
) -> Tuple[FutureID, FutureTokenType]:
|
||||
"""Returns the future ID for the given token, and what type of token it is.
|
||||
|
||||
Raises:
|
||||
NotFoundError if there is no future for the given token.
|
||||
"""
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="future_tokens",
|
||||
keyvalues={"token": token},
|
||||
retcols=("is_send", "is_cancel", "is_refresh", "future_id"),
|
||||
allow_none=True,
|
||||
desc="get_future_by_token",
|
||||
)
|
||||
if row is None:
|
||||
raise NotFoundError
|
||||
|
||||
return FutureID(row[3]), FutureTokenType(sum(row[i] * 2**i for i in range(3)))
|
||||
|
||||
async def get_all_futures_for_user(
|
||||
self,
|
||||
user_id: UserID,
|
||||
) -> List[JsonDict]:
|
||||
"""Returns all pending futures that were requested by the given user."""
|
||||
rows = await self.db_pool.execute(
|
||||
"_get_all_futures_for_user_txn",
|
||||
"""
|
||||
SELECT
|
||||
group_id, timeout,
|
||||
room_id, event_type, state_key, content,
|
||||
send_token, cancel_token, refresh_token
|
||||
FROM futures
|
||||
LEFT JOIN (
|
||||
SELECT future_id, token AS send_token
|
||||
FROM future_tokens
|
||||
WHERE is_send
|
||||
) USING (future_id)
|
||||
LEFT JOIN (
|
||||
SELECT future_id, token AS cancel_token
|
||||
FROM future_tokens
|
||||
WHERE is_cancel
|
||||
) USING (future_id)
|
||||
LEFT JOIN (
|
||||
SELECT future_id, token AS refresh_token
|
||||
FROM future_tokens
|
||||
WHERE is_refresh
|
||||
) USING (future_id)
|
||||
LEFT JOIN future_timeouts USING (future_id)
|
||||
WHERE user_localpart = ?
|
||||
""",
|
||||
user_id.localpart,
|
||||
)
|
||||
return [
|
||||
{
|
||||
"future_group_id": str(row[0]),
|
||||
**({"future_timeout": int(row[1])} if row[1] is not None else {}),
|
||||
"room_id": str(row[2]),
|
||||
"type": str(row[3]),
|
||||
**({"state_key": str(row[4])} if row[4] is not None else {}),
|
||||
# TODO: Verify contents?
|
||||
"content": db_to_json(row[5]),
|
||||
# TODO: If suppressing send/cancel tokens is allowed, omit them if None
|
||||
"send_token": str(row[6]),
|
||||
"cancel_token": str(row[7]),
|
||||
**({"refresh_token": str(row[8])} if row[8] is not None else {}),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_all_future_timestamps(
|
||||
self,
|
||||
) -> List[Tuple[FutureID, Timestamp]]:
|
||||
"""Returns all timeout futures' IDs and when they will be sent if not refreshed."""
|
||||
return await self.db_pool.simple_select_list(
|
||||
table="future_timeouts",
|
||||
keyvalues=None,
|
||||
retcols=("future_id", "timestamp"),
|
||||
desc="get_all_future_timestamps",
|
||||
)
|
||||
|
||||
async def pop_future_event(
|
||||
self,
|
||||
future_id: FutureID,
|
||||
) -> Tuple[
|
||||
str, RoomID, EventType, Optional[StateKey], Optional[Timestamp], JsonDict
|
||||
]:
|
||||
"""Get the partial event of the future with the specified future_id,
|
||||
and remove all futures in its group from the DB.
|
||||
"""
|
||||
|
||||
def pop_future_event_txn(txn: LoggingTransaction) -> Tuple[Any, ...]:
|
||||
try:
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="futures",
|
||||
keyvalues={"future_id": future_id},
|
||||
retcols=(
|
||||
"user_localpart",
|
||||
"group_id",
|
||||
"room_id",
|
||||
"event_type",
|
||||
"state_key",
|
||||
"origin_server_ts",
|
||||
"content",
|
||||
),
|
||||
)
|
||||
assert row is not None
|
||||
except StoreError:
|
||||
raise NotFoundError
|
||||
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="future_groups",
|
||||
keyvalues={
|
||||
"user_localpart": str(row[0]),
|
||||
"group_id": GroupID(row[1]),
|
||||
},
|
||||
)
|
||||
return (row[0], *row[2:])
|
||||
|
||||
row = await self.db_pool.runInteraction(
|
||||
"pop_future_event", pop_future_event_txn
|
||||
)
|
||||
room_id = RoomID.from_string(row[1])
|
||||
# TODO: verify contents?
|
||||
content: JsonDict = db_to_json(row[5])
|
||||
return (row[0], room_id, *row[2:5], content)
|
||||
|
||||
async def remove_future(
|
||||
self,
|
||||
future_id: FutureID,
|
||||
) -> None:
|
||||
"""Removes the future for the given future_id from the DB.
|
||||
If the future is the only timeout future in its group, removes the whole group.
|
||||
"""
|
||||
|
||||
def remove_future_txn(txn: LoggingTransaction) -> None:
|
||||
txn.execute(
|
||||
"""
|
||||
WITH futures_with_timeout AS (
|
||||
SELECT future_id, user_localpart, group_id
|
||||
FROM futures
|
||||
JOIN future_timeouts USING (future_id)
|
||||
), timeout_group AS (
|
||||
SELECT user_localpart, group_id
|
||||
FROM futures_with_timeout
|
||||
WHERE future_id = ?
|
||||
)
|
||||
SELECT DISTINCT COUNT(1), user_localpart, group_id
|
||||
FROM futures_with_timeout
|
||||
JOIN timeout_group USING (user_localpart, group_id)
|
||||
GROUP BY user_localpart, group_id
|
||||
""",
|
||||
(future_id,),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
|
||||
if row is not None and row[0] == 1:
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn,
|
||||
"future_groups",
|
||||
keyvalues={
|
||||
"user_localpart": row[1],
|
||||
"group_id": row[2],
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn,
|
||||
"futures",
|
||||
keyvalues={
|
||||
"future_id": future_id,
|
||||
},
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("remove_future", remove_future_txn)
|
||||
|
||||
|
||||
def _generate_future_token() -> FutureToken:
|
||||
"""Generates an opaque string, for use as a future token"""
|
||||
|
||||
# We use the following format for future tokens:
|
||||
# syf_<random string>_<base62 crc check>
|
||||
# They are NOT scoped to user localparts so that any delegate given the token
|
||||
# won't necessarily know which user created the future.
|
||||
|
||||
random_string = stringutils.random_string(20)
|
||||
base = f"syf_{random_string}"
|
||||
|
||||
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||
return FutureToken(f"{base}_{crc}")
|
||||
|
||||
|
||||
def _generate_group_id() -> GroupID:
|
||||
"""Generates an opaque string, for use as a future group ID"""
|
||||
|
||||
# We use the following format for future tokens:
|
||||
# syf_<random string>_<base62 crc check>
|
||||
# They are scoped to user localparts, but that CANNOT be relied on
|
||||
# to keep them globally unique, as users may set their own group_id.
|
||||
|
||||
random_string = stringutils.random_string(20)
|
||||
base = f"syg_{random_string}"
|
||||
|
||||
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
|
||||
return GroupID(f"{base}_{crc}")
|
|
@ -19,7 +19,7 @@
|
|||
#
|
||||
#
|
||||
|
||||
SCHEMA_VERSION = 85 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 86 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
@ -139,6 +139,10 @@ Changes in SCHEMA_VERSION = 84
|
|||
|
||||
Changes in SCHEMA_VERSION = 85
|
||||
- Add a column `suspended` to the `users` table
|
||||
|
||||
Changes in SCHEMA_VERSION = 86:
|
||||
- MSC4140: Add `futures` table that keeps track of events that are to
|
||||
be posted in response to a refreshable timeout or an on-demand action.
|
||||
"""
|
||||
|
||||
|
||||
|
|
58
synapse/storage/schema/main/delta/86/01_add_futures.sql
Normal file
58
synapse/storage/schema/main/delta/86/01_add_futures.sql
Normal file
|
@ -0,0 +1,58 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
CREATE TABLE future_groups (
|
||||
user_localpart TEXT NOT NULL,
|
||||
group_id TEXT NOT NULL,
|
||||
PRIMARY KEY (user_localpart, group_id)
|
||||
);
|
||||
|
||||
CREATE TABLE futures (
|
||||
future_id INTEGER PRIMARY KEY, -- An alias of rowid in SQLite
|
||||
user_localpart TEXT NOT NULL,
|
||||
group_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
state_key TEXT,
|
||||
origin_server_ts BIGINT,
|
||||
content bytea NOT NULL,
|
||||
FOREIGN KEY (user_localpart, group_id)
|
||||
REFERENCES future_groups (user_localpart, group_id) ON DELETE CASCADE
|
||||
);
|
||||
-- TODO: Consider a trigger/constraint to disallow adding an action future to an empty group
|
||||
|
||||
CREATE INDEX future_group_idx ON futures (user_localpart, group_id);
|
||||
CREATE INDEX room_state_event_idx ON futures (room_id, event_type, state_key) WHERE state_key IS NOT NULL;
|
||||
|
||||
CREATE TABLE future_timeouts (
|
||||
future_id INTEGER PRIMARY KEY
|
||||
REFERENCES futures (future_id) ON DELETE CASCADE,
|
||||
timeout BIGINT NOT NULL,
|
||||
timestamp BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE future_tokens (
|
||||
token TEXT PRIMARY KEY,
|
||||
future_id INTEGER NOT NULL
|
||||
REFERENCES futures (future_id) ON DELETE CASCADE,
|
||||
is_send BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
is_cancel BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
is_refresh BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
CHECK (
|
||||
+ CAST(is_send AS INTEGER)
|
||||
+ CAST(is_cancel AS INTEGER)
|
||||
+ CAST(is_refresh AS INTEGER)
|
||||
= 1),
|
||||
UNIQUE (future_id, is_send, is_cancel, is_refresh)
|
||||
);
|
||||
-- TODO: Consider a trigger/constraint to disallow refresh tokens for futures without a timeout
|
|
@ -0,0 +1,14 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
ALTER TABLE futures ALTER COLUMN future_id ADD GENERATED ALWAYS AS IDENTITY;
|
Loading…
Reference in a new issue