Compare commits

...

5 commits

Author SHA1 Message Date
Andrew Ferrazzutti a2cfcecfaf
Merge b1e98eec0e into 27756c9fdf 2024-06-27 17:05:02 +00:00
Andrew Ferrazzutti b1e98eec0e Rename GET /futures response keys to match request 2024-06-27 07:56:38 -04:00
Andrew Ferrazzutti 7c23b77f3e Rename /futures to /future 2024-06-27 07:56:38 -04:00
Andrew Ferrazzutti 77e1414b08 Add changelog 2024-06-27 07:56:38 -04:00
Andrew Ferrazzutti 222f5b120b Support MSC4140: Delayed events (Futures) 2024-06-27 07:56:38 -04:00
15 changed files with 1295 additions and 5 deletions

View file

@ -0,0 +1,2 @@
Add initial implementation of delayed events (Futures) as proposed by [MSC4140](https://github.com/matrix-org/matrix-spec-proposals/pull/4140).

View file

@ -1662,6 +1662,19 @@ rc_registration_token_validity:
burst_count: 6
```
---
### `rc_future_token_validity`
This option ratelimits requests that check the validity of future tokens, based on
the client's IP address.
Defaults to `per_second: 0.1`, `burst_count: 5`.
Example configuration:
```yaml
rc_future_token_validity:
per_second: 0.3
burst_count: 6
```
---
### `rc_login`
This option specifies several limits for login:

View file

@ -441,6 +441,24 @@ class ExperimentalConfig(Config):
"msc3916_authenticated_media_enabled", False
)
# MSC4140: Delayed events (Futures)
# The maximum allowed delay for timeout futures.
try:
self.msc4140_max_future_timeout_duration = int(
experimental["msc4140_max_future_timeout_duration"]
)
if self.msc4140_max_future_timeout_duration < 0:
raise ValueError
except ValueError:
raise ConfigError(
"Timeout duration must be a positive integer",
("experimental", "msc4140_max_future_timeout_duration"),
)
except KeyError:
self.msc4140_max_future_timeout_duration = (
10 * 365 * 24 * 60 * 60 * 1000
) # 10 years
# MSC4151: Report room API (Client-Server API)
self.msc4151_enabled: bool = experimental.get("msc4151_enabled", False)

View file

@ -124,6 +124,13 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.1, "burst_count": 5},
)
# Ratelimit requests to send/cancel/refresh futures (MSC4140):
self.rc_future_token_validity = RatelimitSettings.parse(
config,
"rc_future_token_validity",
defaults={"per_second": 0.1, "burst_count": 5},
)
# It is reasonable to login with a bunch of devices at once (i.e. when
# setting up an account), but it is *not* valid to continually be
# logging into new devices.

329
synapse/handlers/futures.py Normal file
View file

@ -0,0 +1,329 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
#
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional
from twisted.internet.interfaces import IDelayedCall
from synapse.api.constants import EventTypes
from synapse.api.errors import Codes, NotFoundError, ShadowBanError, SynapseError
from synapse.logging.opentracing import set_tag
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.futures import (
EventType,
FutureID,
FutureTokenType,
StateKey,
Timeout,
Timestamp,
)
from synapse.types import JsonDict, Requester, RoomID, UserID, create_requester
from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class FuturesHandler:
    """Handles MSC4140 "futures" (delayed events): persists them, schedules
    automatic sends for timeout futures, and acts on send/cancel/refresh tokens.
    """

    def __init__(self, hs: "HomeServer"):
        self.store = hs.get_datastores().main
        self.config = hs.config
        self.clock = hs.get_clock()
        self.request_ratelimiter = hs.get_request_ratelimiter()
        self.event_creation_handler = hs.get_event_creation_handler()
        self.room_member_handler = hs.get_room_member_handler()
        self._hostname = hs.hostname

        # In-memory map of pending timeout calls, keyed by future ID.
        self._futures: Dict[FutureID, IDelayedCall] = {}

        async def _schedule_db_futures() -> None:
            # Re-schedule (or immediately fire) every timeout future found in
            # the DB, e.g. after a restart.
            all_future_timestamps = await self.store.get_all_future_timestamps()
            # Get the time after awaiting to increase accuracy.
            # For even more accuracy, could get the time on each loop iteration.
            current_ts = self._get_current_ts()
            for future_id, ts in all_future_timestamps:
                timeout = _get_timeout_between(current_ts, ts)
                if timeout > 0:
                    self._schedule_future(future_id, timeout)
                else:
                    # The deadline already passed while we were down: send now.
                    logger.info("Scheduling timeout for future_id %s now", future_id)
                    run_as_background_process(
                        "_send_future",
                        self._send_future,
                        future_id,
                    )

        # Awaited by every entry point below so requests don't race the initial
        # DB scan.
        # NOTE(review): this deferred is awaited multiple times — confirm that
        # is safe under Synapse's logcontext/deferred conventions.
        self._initialized_from_db = run_as_background_process(
            "_schedule_db_futures", _schedule_db_futures
        )

    async def add_future(
        self,
        requester: Requester,
        *,
        room_id: str,
        event_type: str,
        state_key: Optional[str],
        origin_server_ts: Optional[int],
        content: JsonDict,
        timeout: Optional[int],
        group_id: Optional[str],
        is_custom_endpoint: bool = False,  # TODO: Remove this eventually
    ) -> JsonDict:
        """Creates a new future, and if it is a timeout future, schedules it to be sent.

        Params:
            requester: The initial requester of the future.
            room_id: The room where the future event should be sent.
            event_type: The event type of the future event.
            state_key: The state key of the future event, or None if it is not a state event.
            origin_server_ts: The custom timestamp to send the future event with.
                If None, the timestamp will be the actual time when the future is sent.
            content: The content of the future event.
            timeout: How long (in milliseconds) to wait before automatically sending the future.
                If None, the future will never be automatically sent.
            group_id: The future group this future belongs to.
                If None, the future will be added to a new group with an auto-generated ID.

        Returns:
            A mapping of tokens that can be used in requests to send, cancel, or refresh the future.

        Raises:
            SynapseError: if the timeout exceeds the configured maximum, if
                neither timeout nor group_id is given, or if the given
                group_id does not already exist for an action future.
        """
        if timeout is not None:
            # Cap the delay at the configured maximum.
            max_timeout = self.config.experimental.msc4140_max_future_timeout_duration
            if timeout > max_timeout:
                raise SynapseError(
                    HTTPStatus.BAD_REQUEST,
                    f"'{'future_timeout' if not is_custom_endpoint else 'timeout'}' is too large",
                    Codes.INVALID_PARAM,
                    {
                        "max_timeout_duration": max_timeout,
                    },
                )
        else:
            # An "action" future (no timeout) must name the group it joins.
            if group_id is None:
                raise SynapseError(
                    HTTPStatus.BAD_REQUEST,
                    "At least one of 'future_timeout' and 'future_group_id' is required",
                    Codes.MISSING_PARAM,
                )

        await self.request_ratelimiter.ratelimit(requester)
        await self._initialized_from_db

        # Action futures may only join a group that already exists.
        if (
            timeout is None
            and group_id is not None
            and not await self.store.has_group_id(requester.user, group_id)
        ):
            raise SynapseError(
                HTTPStatus.CONFLICT,
                # TODO: Think of a better wording for this
                "The given 'future_group_id' does not exist and cannot be created with an action future",
                Codes.INVALID_PARAM,
            )

        (
            future_id,
            group_id,
            send_token,
            cancel_token,
            refresh_token,
        ) = await self.store.add_future(
            user_id=requester.user,
            group_id=group_id,
            room_id=RoomID.from_string(room_id),
            event_type=event_type,
            state_key=state_key,
            origin_server_ts=origin_server_ts,
            content=content,
            timeout=timeout,
            # Absolute deadline for timeout futures; None for action futures.
            timestamp=timeout + self.clock.time_msec() if timeout is not None else None,
        )
        if timeout is not None:
            self._schedule_future(future_id, Timeout(timeout))

        ret = {
            "future_group_id": group_id,
            "send_token": send_token,
            "cancel_token": cancel_token,
        }
        # TODO: type hint for non-None refresh_token return value when timeout argument is non-None
        if refresh_token is not None:
            ret["refresh_token"] = refresh_token
        return ret

    async def use_future_token(self, token: str) -> None:
        """Executes the appropriate action for the given token.

        Params:
            token: The token of the future to act on.

        Raises:
            NotFoundError: if the store has no future for the token.
        """
        await self._initialized_from_db
        future_id, future_token_type = await self.store.get_future_by_token(token)
        if future_token_type == FutureTokenType.SEND:
            await self._send_future(future_id)
        elif future_token_type == FutureTokenType.CANCEL:
            await self._cancel_future(future_id)
        elif future_token_type == FutureTokenType.REFRESH:
            await self._refresh_future(future_id)

    async def _send_future(self, future_id: FutureID) -> None:
        # Explicit send: pop the event from the DB, cancel any pending
        # timeout call, then send the event.
        await self._initialized_from_db
        args = await self.store.pop_future_event(future_id)
        self._unschedule_future(future_id)
        await self._send_event(*args)

    async def _send_future_on_timeout(self, future_id: FutureID) -> None:
        # Timer-driven send: tolerate the future having been removed in the
        # window between the timer firing and the DB pop.
        await self._initialized_from_db
        try:
            args = await self.store.pop_future_event(future_id)
        except NotFoundError:
            logger.warning(
                "future_id %s missing from DB on timeout, so it must have been cancelled/sent",
                future_id,
            )
        else:
            await self._send_event(*args)

    async def _cancel_future(self, future_id: FutureID) -> None:
        # Remove the future from the DB and drop any pending timeout call.
        await self._initialized_from_db
        await self.store.remove_future(future_id)
        self._unschedule_future(future_id)

    async def _refresh_future(self, future_id: FutureID) -> None:
        # Push the deadline back to now + timeout, then restart the timer.
        await self._initialized_from_db
        timeout = await self.store.update_future_timestamp(
            future_id,
            self._get_current_ts(),
        )
        self._unschedule_future(future_id)
        self._schedule_future(future_id, timeout)

    def _schedule_future(
        self,
        future_id: FutureID,
        timeout: Timeout,
    ) -> None:
        """NOTE: Should not be called with a future_id that isn't in the DB, or with a negative timeout."""
        # `timeout` is in milliseconds; call_later takes seconds.
        delay_sec = timeout / 1000
        logger.info(
            "Scheduling timeout for future_id %d in %.3fs", future_id, delay_sec
        )
        self._futures[future_id] = self.clock.call_later(
            delay_sec,
            run_as_background_process,
            "_send_future_on_timeout",
            self._send_future_on_timeout,
            future_id,
        )

    def _unschedule_future(self, future_id: FutureID) -> None:
        # Cancel the pending timeout call, if any (action futures have none).
        try:
            delayed_call = self._futures.pop(future_id)
            self.clock.cancel_call_later(delayed_call)
        except KeyError:
            logger.debug("future_id %s was not mapped to a delayed call", future_id)

    async def get_all_futures_for_user(self, requester: Requester) -> List[JsonDict]:
        """Return all pending futures requested by the given user."""
        await self.request_ratelimiter.ratelimit(requester)
        await self._initialized_from_db
        return await self.store.get_all_futures_for_user(requester.user)

    # TODO: Consider getting a list of all timeout futures that were in this one's group, and cancel their delayed calls
    async def _send_event(
        self,
        user_localpart: str,
        room_id: RoomID,
        event_type: EventType,
        state_key: Optional[StateKey],
        origin_server_ts: Optional[Timestamp],
        content: JsonDict,
        # TODO: support guests
        # is_guest: bool,
        txn_id: Optional[str] = None,
    ) -> None:
        """Sends the stored future event on behalf of its original requester."""
        user_id = UserID(user_localpart, self._hostname)
        requester = create_requester(
            user_id,
            # is_guest=is_guest,
        )
        try:
            if state_key is not None and event_type == EventTypes.Member:
                # Membership changes must go through the room member handler.
                membership = content.get("membership", None)
                event_id, _ = await self.room_member_handler.update_membership(
                    requester,
                    target=UserID.from_string(state_key),
                    room_id=str(room_id),
                    action=membership,
                    content=content,
                    origin_server_ts=origin_server_ts,
                )
            else:
                event_dict: JsonDict = {
                    "type": event_type,
                    "content": content,
                    "room_id": room_id.to_string(),
                    "sender": user_id.to_string(),
                }
                if state_key is not None:
                    event_dict["state_key"] = state_key
                if origin_server_ts is not None:
                    event_dict["origin_server_ts"] = origin_server_ts
                (
                    event,
                    _,
                ) = await self.event_creation_handler.create_and_send_nonmember_event(
                    requester,
                    event_dict,
                    txn_id=txn_id,
                )
                event_id = event.event_id
        except ShadowBanError:
            # Pretend the event was sent by reporting a fake event ID.
            event_id = "$" + random_string(43)
        set_tag("event_id", event_id)

    def _get_current_ts(self) -> Timestamp:
        # Current wall-clock time in milliseconds.
        return Timestamp(self.clock.time_msec())
def _get_timeout_between(from_ts: Timestamp, to_ts: Timestamp) -> Timeout:
    """Return the milliseconds remaining from `from_ts` until `to_ts`.

    The result may be zero or negative when `to_ts` is not in the future.
    """
    remaining_ms = to_ts - from_ts
    return Timeout(remaining_ms)

View file

@ -34,6 +34,7 @@ from synapse.rest.client import (
directory,
events,
filter,
futures,
initial_sync,
keys,
knock,
@ -103,6 +104,7 @@ class ClientRestResource(JsonResource):
events.register_servlets(hs, client_resource)
room.register_servlets(hs, client_resource)
futures.register_servlets(hs, client_resource)
login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)

View file

@ -0,0 +1,90 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
#
""" This module contains REST servlets to do with futures: /future/<paths> """
import logging
from typing import TYPE_CHECKING, List, Tuple
from synapse.api.errors import NotFoundError
from synapse.api.ratelimiting import Ratelimiter
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# TODO: Needs unit testing
class FuturesTokenServlet(RestServlet):
    """POST /org.matrix.msc4140/future/{token}: act on a future via one of its
    send/cancel/refresh tokens (MSC4140).

    The request is unauthenticated, so it is ratelimited by client IP.
    """

    PATTERNS = client_patterns(
        r"/org\.matrix\.msc4140/future/(?P<token>[^/]*)$", releases=(), v1=False
    )
    CATEGORY = "Future management requests"

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        self.futures_handler = hs.get_futures_handler()
        # Dedicated limiter for token-validity probing (see rc_future_token_validity).
        self._ratelimiter = Ratelimiter(
            store=hs.get_datastores().main,
            clock=hs.get_clock(),
            cfg=hs.config.ratelimiting.rc_future_token_validity,
        )

    async def on_POST(
        self, request: SynapseRequest, token: str
    ) -> Tuple[int, JsonDict]:
        # Ratelimit by address, since this is an unauthenticated request
        ratelimit_key = (request.getClientAddress().host,)
        # Check if we should be ratelimited due to too many previous failed attempts
        await self._ratelimiter.ratelimit(None, ratelimit_key, update=False)
        try:
            await self.futures_handler.use_future_token(token)
            return 200, {}
        except NotFoundError as e:
            # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
            await self._ratelimiter.can_do_action(None, ratelimit_key)
            # TODO: Decide if the error code should be left at 404, instead of 410 as per the MSC
            e.code = 410
            raise
# TODO: Needs unit testing
class FuturesServlet(RestServlet):
    """GET /org.matrix.msc4140/future: list the requester's pending futures."""

    PATTERNS = client_patterns(r"/org\.matrix\.msc4140/future$", releases=(), v1=False)
    CATEGORY = "Future management requests"

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        self.auth = hs.get_auth()
        self.futures_handler = hs.get_futures_handler()

    async def on_GET(self, request: SynapseRequest) -> Tuple[int, List[JsonDict]]:
        # Authenticate first, then fetch everything pending for that user.
        requester = await self.auth.get_user_by_req(request)
        futures = await self.futures_handler.get_all_futures_for_user(requester)
        return 200, futures
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    """Register the MSC4140 future-management servlets on the given server."""
    for servlet in (FuturesTokenServlet(hs), FuturesServlet(hs)):
        servlet.register(http_server)

View file

@ -2,7 +2,7 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
# Copyright (C) 2023-2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@ -36,6 +36,7 @@ from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
InvalidAPICallError,
InvalidClientCredentialsError,
MissingClientTokenError,
ShadowBanError,
@ -64,7 +65,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types import (
JsonDict,
Requester,
StrCollection,
StreamToken,
ThirdPartyInstanceID,
UserID,
)
from synapse.types.state import StateFilter
from synapse.util.cancellation import cancellable
from synapse.util.stringutils import parse_and_validate_server_name, random_string
@ -409,6 +417,261 @@ class RoomSendEventRestServlet(TransactionRestServlet):
)
# TODO: Needs unit testing
class RoomStateFutureRestServlet(RestServlet):
    """PUT /org.matrix.msc4140/rooms/{roomId}/state_future/{eventType}[/{stateKey}]:
    create a future for a state event (MSC4140)."""

    CATEGORY = "Future sending requests"

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        self.auth = hs.get_auth()
        self.futures_handler = hs.get_futures_handler()

    def register(self, http_server: HttpServer) -> None:
        # /org.matrix.msc4140/rooms/$roomid/state_future/$eventtype
        no_state_key = r"/org\.matrix\.msc4140/rooms/(?P<room_id>[^/]*)/state_future/(?P<event_type>[^/]*)$"
        # /org.matrix.msc4140/rooms/$roomid/state_future/$eventtype/$statekey
        state_key = (
            r"/org\.matrix\.msc4140/rooms/(?P<room_id>[^/]*)/state_future/"
            "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)"
        )
        http_server.register_paths(
            "PUT",
            client_patterns(state_key, releases=(), v1=False),
            self.on_PUT,
            self.__class__.__name__,
        )
        http_server.register_paths(
            "PUT",
            client_patterns(no_state_key, releases=(), v1=False),
            self.on_PUT_no_state_key,
            self.__class__.__name__,
        )

    def on_PUT_no_state_key(
        self, request: SynapseRequest, room_id: str, event_type: str
    ) -> Awaitable[Tuple[int, JsonDict]]:
        # No state key in the path: treat it as the empty state key.
        return self.on_PUT(request, room_id, event_type, "")

    async def on_PUT(
        self,
        request: SynapseRequest,
        room_id: str,
        event_type: str,
        state_key: str,
        txn_id: Optional[str] = None,
    ) -> Tuple[int, JsonDict]:
        requester = await self.auth.get_user_by_req(request, allow_guest=True)

        if txn_id:
            set_tag("txn_id", txn_id)

        ret = await self.futures_handler.add_future(
            requester,
            room_id=room_id,
            event_type=event_type,
            state_key=state_key,
            # Only appservices may override the event timestamp.
            origin_server_ts=(
                parse_integer(request, "ts") if requester.app_service else None
            ),
            content=parse_json_object_from_request(request),
            timeout=parse_integer(request, "future_timeout"),
            group_id=parse_string(request, "future_group_id"),
        )
        # Expose the returned tokens/IDs on the tracing span.
        for k, v in ret.items():
            set_tag(k, v)
        return 200, ret
# TODO: Needs unit testing
class RoomSendFutureRestServlet(TransactionRestServlet):
    """(POST|PUT) /org.matrix.msc4140/rooms/{roomId}/send_future/{eventType}[/{txnId}]:
    create a future for a non-state event (MSC4140)."""

    CATEGORY = "Future sending requests"

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        self.auth = hs.get_auth()
        self.futures_handler = hs.get_futures_handler()

    def register(self, http_server: HttpServer) -> None:
        # /org.matrix.msc4140/rooms/$roomid/send_future/$event_type[/$txn_id]
        PATTERNS = r"/org\.matrix\.msc4140/rooms/(?P<room_id>[^/]*)/send_future/(?P<event_type>[^/]*)"
        register_txn_path(self, PATTERNS, http_server, releases=(), v1=False)

    async def _do(
        self,
        request: SynapseRequest,
        requester: Requester,
        room_id: str,
        event_type: str,
    ) -> Tuple[int, JsonDict]:
        # Shared implementation for both POST and (transactional) PUT.
        ret = await self.futures_handler.add_future(
            requester,
            room_id=room_id,
            event_type=event_type,
            state_key=None,
            # Only appservices may override the event timestamp.
            origin_server_ts=(
                parse_integer(request, "ts") if requester.app_service else None
            ),
            content=parse_json_object_from_request(request),
            timeout=parse_integer(request, "future_timeout"),
            group_id=parse_string(request, "future_group_id"),
        )
        # Expose the returned tokens/IDs on the tracing span.
        for k, v in ret.items():
            set_tag(k, v)
        return 200, ret

    async def on_POST(
        self,
        request: SynapseRequest,
        room_id: str,
        event_type: str,
    ) -> Tuple[int, JsonDict]:
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        return await self._do(request, requester, room_id, event_type)

    async def on_PUT(
        self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
    ) -> Tuple[int, JsonDict]:
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        set_tag("txn_id", txn_id)
        # Deduplicate retries carrying the same transaction ID.
        return await self.txns.fetch_or_execute_request(
            request,
            requester,
            self._do,
            request,
            requester,
            room_id,
            event_type,
        )
# TODO: Remove in favour of the Room{Send,State}FutureRestServlets. Otherwise, this needs unit testing
class RoomFutureRestServlet(TransactionRestServlet):
    """(POST|PUT) /org.matrix.msc4140/rooms/{roomId}/future[/{txnId}]: create a
    timeout future plus optional action futures in a single request."""

    CATEGORY = "Future sending requests"

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        self.auth = hs.get_auth()
        self.futures_handler = hs.get_futures_handler()

    def register(self, http_server: HttpServer) -> None:
        # NOTE: Difference from MSC = remove /send part of path, to avoid ambiguity with regular event sending
        # /org.matrix.msc4140/rooms/$roomid/future[/$txn_id]
        PATTERNS = r"/org\.matrix\.msc4140/rooms/(?P<room_id>[^/]*)/future"
        register_txn_path(self, PATTERNS, http_server, releases=(), v1=False)

    async def _do(
        self,
        request: SynapseRequest,
        requester: Requester,
        room_id: str,
    ) -> Tuple[int, JsonDict]:
        """Shared implementation for both POST and (transactional) PUT.

        Raises:
            SynapseError / InvalidAPICallError: on a missing or malformed body.
        """
        body = parse_json_object_from_request(request)
        try:
            timeout = int(body["timeout"])
        except KeyError:
            raise SynapseError(400, "'timeout' is missing", Codes.MISSING_PARAM)
        except Exception:
            raise SynapseError(400, "'timeout' is not an integer", Codes.INVALID_PARAM)
        try:
            # Shallow-copy with stringified keys, validating it is a mapping.
            timeout_item: JsonDict = {
                str(k): v for k, v in body["send_on_timeout"].items()
            }
        except KeyError:
            raise SynapseError(400, "'send_on_timeout' is missing", Codes.MISSING_PARAM)
        except Exception:
            raise InvalidAPICallError("'send_on_timeout' is not valid JSON")
        try:
            send_on_action: Dict[str, JsonDict] = {
                str(action_name): {str(k): v for k, v in action_item.items()}
                for action_name, action_item in body.get("send_on_action", {}).items()
            }
        except Exception:
            raise InvalidAPICallError(
                "'send_on_action' is not a JSON mapping of action names to event content"
            )

        # TODO: Should action events be able to have a different ts value than the timeout event?
        # Only appservices may override the event timestamp.
        origin_server_ts = (
            parse_integer(request, "ts") if requester.app_service else None
        )

        # TODO: send_now. But first validate send_on_timeout & send_on_action
        send_on_timeout_resp = await self.futures_handler.add_future(
            requester,
            room_id=room_id,
            event_type=str(timeout_item["type"]),
            state_key=(
                str(timeout_item["state_key"]) if "state_key" in timeout_item else None
            ),
            origin_server_ts=origin_server_ts,
            content=timeout_item["content"],
            timeout=timeout,
            group_id=None,
            is_custom_endpoint=True,
        )
        # The action futures below join the group created for the timeout future.
        group_id = str(send_on_timeout_resp.pop("future_group_id"))
        ret = {"send_on_timeout": send_on_timeout_resp}
        if send_on_action:
            send_on_action_resp: Dict[str, Dict[str, JsonDict]] = {}
            for action_name, action_item in send_on_action.items():
                action_item_resp = await self.futures_handler.add_future(
                    requester,
                    room_id=room_id,
                    event_type=str(action_item["type"]),
                    state_key=(
                        str(action_item["state_key"])
                        if "state_key" in action_item
                        else None
                    ),
                    origin_server_ts=origin_server_ts,
                    content=action_item["content"],
                    timeout=None,
                    group_id=group_id,
                    is_custom_endpoint=True,
                )
                try:
                    # The group ID is shared, so it is reported only once above.
                    del action_item_resp["future_group_id"]
                except KeyError:
                    pass
                send_on_action_resp[action_name] = action_item_resp
            ret["send_on_action"] = send_on_action_resp
        return 200, ret

    async def on_POST(
        self,
        request: SynapseRequest,
        room_id: str,
    ) -> Tuple[int, JsonDict]:
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        return await self._do(request, requester, room_id)

    async def on_PUT(
        self, request: SynapseRequest, room_id: str, txn_id: str
    ) -> Tuple[int, JsonDict]:
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        set_tag("txn_id", txn_id)
        # Deduplicate retries carrying the same transaction ID.
        return await self.txns.fetch_or_execute_request(
            request,
            requester,
            self._do,
            request,
            requester,
            room_id,
        )
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
CATEGORY = "Event sending requests"
@ -1344,6 +1607,9 @@ def register_txn_path(
servlet: RestServlet,
regex_string: str,
http_server: HttpServer,
releases: StrCollection = ("r0", "v3"),
unstable: bool = True,
v1: bool = True,
) -> None:
"""Registers a transaction-based path.
@ -1362,13 +1628,13 @@ def register_txn_path(
raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path")
http_server.register_paths(
"POST",
client_patterns(regex_string + "$", v1=True),
client_patterns(regex_string + "$", releases, unstable, v1),
on_POST,
servlet.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", releases, unstable, v1),
on_PUT,
servlet.__class__.__name__,
)
@ -1503,12 +1769,14 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomStateEventRestServlet(hs).register(http_server)
RoomStateFutureRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)
RoomSendEventRestServlet(hs).register(http_server)
RoomSendFutureRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server)
@ -1523,6 +1791,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SearchRestServlet(hs).register(http_server)
RoomCreateRestServlet(hs).register(http_server)
TimestampLookupRestServlet(hs).register(http_server)
RoomFutureRestServlet(hs).register(http_server)
# Some servlets only get registered for the main process.
if hs.config.worker.worker_app is None:

View file

@ -149,6 +149,8 @@ class VersionsRestServlet(RestServlet):
is not None
)
),
# MSC4140: Delayed events (Futures)
"org.matrix.msc4140": True,
# MSC4151: Report room API (Client-Server API)
"org.matrix.msc4151": self.config.experimental.msc4151_enabled,
},

View file

@ -76,6 +76,7 @@ from synapse.handlers.event_auth import EventAuthHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.federation import FederationHandler
from synapse.handlers.federation_event import FederationEventHandler
from synapse.handlers.futures import FuturesHandler
from synapse.handlers.identity import IdentityHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.message import EventCreationHandler, MessageHandler
@ -248,6 +249,7 @@ class HomeServer(metaclass=abc.ABCMeta):
"account_validity",
"auth",
"deactivate_account",
"futures",
"message",
"pagination",
"profile",
@ -936,3 +938,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_task_scheduler(self) -> TaskScheduler:
return TaskScheduler(self)
@cache_in_self
def get_futures_handler(self) -> FuturesHandler:
    """Returns the handler for MSC4140 delayed events ("futures").

    Cached on the homeserver via @cache_in_self, like the other handlers.
    """
    return FuturesHandler(self)

View file

@ -54,6 +54,7 @@ from .events_bg_updates import EventsBackgroundUpdatesStore
from .events_forward_extremities import EventForwardExtremitiesStore
from .experimental_features import ExperimentalFeaturesStore
from .filtering import FilteringWorkerStore
from .futures import FuturesStore
from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
@ -156,6 +157,7 @@ class DataStore(
LockStore,
SessionStore,
TaskSchedulerWorkerStore,
FuturesStore,
):
def __init__(
self,

View file

@ -0,0 +1,474 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
#
from binascii import crc32
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
NewType,
Optional,
Tuple,
TypeVar,
)
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict, RoomID, UserID
from synapse.util import json_encoder, stringutils as stringutils
from synapse.util.stringutils import base62_encode
if TYPE_CHECKING:
from synapse.server import HomeServer
class FutureTokenType(Enum):
    """The kind of action a future token authorizes.

    Values are distinct powers of two so that an (is_send, is_cancel,
    is_refresh) flag triple can be folded into a single enum value.
    """

    SEND = 1
    CANCEL = 2
    REFRESH = 4

    def __str__(self) -> str:
        # Render as the name of the DB column that flags this token type.
        if self is FutureTokenType.CANCEL:
            return "is_cancel"
        if self is FutureTokenType.REFRESH:
            return "is_refresh"
        return "is_send"
# Opaque client-facing token used to send/cancel/refresh a future.
FutureToken = NewType("FutureToken", str)
# Identifier of a group of related futures.
GroupID = NewType("GroupID", str)
# Matrix event type of a future's event.
EventType = NewType("EventType", str)
# State key of a future's (state) event.
StateKey = NewType("StateKey", str)
# DB row ID of a future.
FutureID = NewType("FutureID", int)
# Relative delay in milliseconds.
Timeout = NewType("Timeout", int)
# Absolute time in milliseconds.
Timestamp = NewType("Timestamp", int)

# Return value of FuturesStore.add_future:
# (future_id, group_id, send_token, cancel_token, optional refresh_token).
AddFutureReturn = Tuple[
    FutureID,
    GroupID,
    FutureToken,
    FutureToken,
    Optional[FutureToken],
]
# TODO: Try to support workers
class FuturesStore(SQLBaseStore):
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    # No extra state to set up yet; defer entirely to SQLBaseStore.
    super().__init__(database, db_conn, hs)
async def add_future(
    self,
    *,
    user_id: UserID,
    group_id: Optional[str],
    room_id: RoomID,
    event_type: str,
    state_key: Optional[str],
    origin_server_ts: Optional[int],
    content: JsonDict,
    timeout: Optional[int],
    timestamp: Optional[int],
) -> AddFutureReturn:
    """Inserts a new future in the DB.

    Creates (or reuses) the future group, inserts the future row, then mints
    its send/cancel tokens — plus a timeout row and a refresh token when
    `timeout` is given.

    Returns:
        (future_id, group_id, send_token, cancel_token, refresh_token);
        refresh_token is None for futures without a timeout.
    """
    user_localpart = user_id.localpart

    def add_future_txn(txn: LoggingTransaction) -> AddFutureReturn:
        T = TypeVar("T", bound=str)

        def insert_and_get_unique_colval(
            table: str,
            values: Dict[str, Any],
            column: str,
            colval_generator: Callable[[], T],
        ) -> T:
            # Insert a row whose `column` is a freshly generated value,
            # retrying with a new value on insert failure up to 10 times.
            attempts_remaining = 10
            while True:
                colval = colval_generator()
                try:
                    self.db_pool.simple_insert_txn(
                        txn,
                        table,
                        values={
                            **values,
                            column: colval,
                        },
                    )
                    return colval
                # TODO: Handle only the error type for DB key collisions
                except Exception:
                    if attempts_remaining > 0:
                        attempts_remaining -= 1
                    else:
                        raise SynapseError(
                            500,
                            f"Couldn't generate a unique value for column {column} in table {table}",
                        )

        if group_id is None:
            # No group given: create a new group with a generated ID.
            group_id_final = insert_and_get_unique_colval(
                "future_groups",
                {
                    "user_localpart": user_localpart,
                },
                "group_id",
                _generate_group_id,
            )
        else:
            # Ensure the named group exists for this user (no-op if it does).
            txn.execute(
                """
                INSERT INTO future_groups (user_localpart, group_id)
                VALUES (?, ?)
                ON CONFLICT DO NOTHING
                """,
                (user_localpart, group_id),
            )
            group_id_final = GroupID(group_id)

        txn.execute(
            """
            INSERT INTO futures (
                user_localpart, group_id,
                room_id, event_type, state_key, origin_server_ts,
                content
            ) VALUES (
                ?, ?,
                ?, ?, ?, ?,
                ?
            )
            RETURNING future_id
            """,
            (
                user_localpart,
                group_id_final,
                room_id.to_string(),
                event_type,
                state_key,
                origin_server_ts,
                json_encoder.encode(content),
            ),
        )
        row = txn.fetchone()
        assert row is not None
        future_id = FutureID(row[0])

        def insert_and_get_future_token(token_type: FutureTokenType) -> FutureToken:
            # str(token_type) is the boolean column flagging this token kind.
            return insert_and_get_unique_colval(
                "future_tokens",
                {
                    "future_id": future_id,
                    str(token_type): True,
                },
                "token",
                _generate_future_token,
            )

        send_token = insert_and_get_future_token(FutureTokenType.SEND)
        cancel_token = insert_and_get_future_token(FutureTokenType.CANCEL)
        if timeout is not None:
            self.db_pool.simple_insert_txn(
                txn,
                table="future_timeouts",
                values={
                    "future_id": future_id,
                    "timeout": timeout,
                    "timestamp": timestamp,
                },
            )
            # Only timeout futures get a refresh token.
            refresh_token = insert_and_get_future_token(FutureTokenType.REFRESH)
        else:
            refresh_token = None
        return (
            future_id,
            group_id_final,
            send_token,
            cancel_token,
            refresh_token,
        )

    return await self.db_pool.runInteraction("add_future", add_future_txn)
async def update_future_timestamp(
    self,
    future_id: FutureID,
    current_ts: Timestamp,
) -> Timeout:
    """Reset when the timeout future for the given future_id will fire.

    The future's timestamp is set to `current_ts` plus its configured
    timeout, i.e. the timeout window restarts now.

    Params:
        future_id: The ID of the future timeout to update.
        current_ts: The current time to which the future's timestamp will be set relative to.

    Returns: The timeout value for the future.

    Raises:
        NotFoundError if there is no timeout future for the given future_id.
    """
    rows = await self.db_pool.execute(
        "update_future_timestamp",
        """
        UPDATE future_timeouts
        SET timestamp = ? + timeout
        WHERE future_id = ?
        RETURNING timeout
        """,
        current_ts,
        future_id,
    )
    # No returned row means no timeout future exists for this ID.
    if not rows or not rows[0]:
        raise NotFoundError
    return Timeout(rows[0][0])
async def has_group_id(
    self,
    user_id: UserID,
    group_id: str,
) -> bool:
    """Returns whether a future group exists for the given user and group_id."""
    matches: int = await self.db_pool.simple_select_one_onecol(
        table="future_groups",
        keyvalues={
            "user_localpart": user_id.localpart,
            "group_id": group_id,
        },
        retcol="COUNT(1)",
        desc="has_group_id",
    )
    # COUNT(1) is never negative, so truthiness is equivalent to "> 0".
    return bool(matches)
async def get_future_by_token(
    self,
    token: str,
) -> Tuple[FutureID, FutureTokenType]:
    """Returns the future ID for the given token, and what type of token it is.

    Raises:
        NotFoundError if there is no future for the given token.
    """
    row = await self.db_pool.simple_select_one(
        table="future_tokens",
        keyvalues={"token": token},
        retcols=("is_send", "is_cancel", "is_refresh", "future_id"),
        allow_none=True,
        desc="get_future_by_token",
    )
    if row is None:
        raise NotFoundError
    is_send, is_cancel, is_refresh, future_id = row
    # Exactly one of the type flags is set per token (enforced by a DB
    # CHECK constraint); fold them into the bit value the enum expects.
    type_bits = is_send * 1 + is_cancel * 2 + is_refresh * 4
    return FutureID(future_id), FutureTokenType(type_bits)
async def get_all_futures_for_user(
    self,
    user_id: UserID,
) -> List[JsonDict]:
    """Returns all pending futures that were requested by the given user.

    Each entry contains the future's group ID, its partial event fields,
    and its action tokens; timeout, state_key and refresh_token are
    included only when present.
    """
    rows = await self.db_pool.execute(
        "_get_all_futures_for_user_txn",
        """
        SELECT
            group_id, timeout,
            room_id, event_type, state_key, content,
            send_token, cancel_token, refresh_token
        FROM futures
        LEFT JOIN (
            SELECT future_id, token AS send_token
            FROM future_tokens
            WHERE is_send
        ) USING (future_id)
        LEFT JOIN (
            SELECT future_id, token AS cancel_token
            FROM future_tokens
            WHERE is_cancel
        ) USING (future_id)
        LEFT JOIN (
            SELECT future_id, token AS refresh_token
            FROM future_tokens
            WHERE is_refresh
        ) USING (future_id)
        LEFT JOIN future_timeouts USING (future_id)
        WHERE user_localpart = ?
        """,
        user_id.localpart,
    )
    futures: List[JsonDict] = []
    for (
        group_id,
        timeout,
        room_id,
        event_type,
        state_key,
        content,
        send_token,
        cancel_token,
        refresh_token,
    ) in rows:
        entry: JsonDict = {"future_group_id": str(group_id)}
        if timeout is not None:
            entry["future_timeout"] = int(timeout)
        entry["room_id"] = str(room_id)
        entry["type"] = str(event_type)
        if state_key is not None:
            entry["state_key"] = str(state_key)
        # TODO: Verify contents?
        entry["content"] = db_to_json(content)
        # TODO: If suppressing send/cancel tokens is allowed, omit them if None
        entry["send_token"] = str(send_token)
        entry["cancel_token"] = str(cancel_token)
        if refresh_token is not None:
            entry["refresh_token"] = str(refresh_token)
        futures.append(entry)
    return futures
async def get_all_future_timestamps(
    self,
) -> List[Tuple[FutureID, Timestamp]]:
    """Returns every timeout future's ID together with the time at which
    it will be sent, unless refreshed before then.
    """
    timeouts: List[Tuple[FutureID, Timestamp]] = await self.db_pool.simple_select_list(
        desc="get_all_future_timestamps",
        table="future_timeouts",
        keyvalues=None,
        retcols=("future_id", "timestamp"),
    )
    return timeouts
async def pop_future_event(
    self,
    future_id: FutureID,
) -> Tuple[
    str, RoomID, EventType, Optional[StateKey], Optional[Timestamp], JsonDict
]:
    """Get the partial event of the future with the specified future_id,
    and remove all futures in its group from the DB.

    Args:
        future_id: The ID of the future whose partial event to fetch.

    Returns:
        A (user_localpart, room_id, event_type, state_key,
        origin_server_ts, content) tuple describing the event to send.

    Raises:
        NotFoundError: if there is no future with the given ID.
    """

    def pop_future_event_txn(txn: LoggingTransaction) -> Tuple[Any, ...]:
        # simple_select_one_txn raises StoreError when no row matches;
        # surface that to callers as a NotFoundError.
        try:
            row = self.db_pool.simple_select_one_txn(
                txn,
                table="futures",
                keyvalues={"future_id": future_id},
                retcols=(
                    "user_localpart",
                    "group_id",
                    "room_id",
                    "event_type",
                    "state_key",
                    "origin_server_ts",
                    "content",
                ),
            )
            assert row is not None
        except StoreError:
            raise NotFoundError
        # Deleting the group row removes every future in the group via the
        # ON DELETE CASCADE foreign key on the `futures` table.
        self.db_pool.simple_delete_txn(
            txn,
            table="future_groups",
            keyvalues={
                "user_localpart": str(row[0]),
                "group_id": GroupID(row[1]),
            },
        )
        # Drop group_id (row[1]) from the returned columns.
        return (row[0], *row[2:])

    row = await self.db_pool.runInteraction(
        "pop_future_event", pop_future_event_txn
    )
    # After the txn, row = (user_localpart, room_id, event_type,
    # state_key, origin_server_ts, content).
    room_id = RoomID.from_string(row[1])
    # TODO: verify contents?
    content: JsonDict = db_to_json(row[5])
    return (row[0], room_id, *row[2:5], content)
async def remove_future(
    self,
    future_id: FutureID,
) -> None:
    """Removes the future for the given future_id from the DB.

    If the future is the only timeout future in its group, removes the whole group.

    Args:
        future_id: The ID of the future to remove.
    """

    def remove_future_txn(txn: LoggingTransaction) -> None:
        # Count how many timeout futures share the target future's group.
        # The query yields no rows when the target future has no timeout
        # (or doesn't exist), in which case only the single future row is
        # deleted in the else-branch below.
        txn.execute(
            """
            WITH futures_with_timeout AS (
                SELECT future_id, user_localpart, group_id
                FROM futures
                JOIN future_timeouts USING (future_id)
            ), timeout_group AS (
                SELECT user_localpart, group_id
                FROM futures_with_timeout
                WHERE future_id = ?
            )
            SELECT DISTINCT COUNT(1), user_localpart, group_id
            FROM futures_with_timeout
            JOIN timeout_group USING (user_localpart, group_id)
            GROUP BY user_localpart, group_id
            """,
            (future_id,),
        )
        row = txn.fetchone()
        if row is not None and row[0] == 1:
            # The target is the only timeout future in its group: delete
            # the whole group (which cascades to all of its futures).
            self.db_pool.simple_delete_one_txn(
                txn,
                "future_groups",
                keyvalues={
                    "user_localpart": row[1],
                    "group_id": row[2],
                },
            )
        else:
            # Other timeout futures remain (or the target has no timeout):
            # delete just this one future.
            self.db_pool.simple_delete_one_txn(
                txn,
                "futures",
                keyvalues={
                    "future_id": future_id,
                },
            )

    await self.db_pool.runInteraction("remove_future", remove_future_txn)
def _generate_future_token() -> FutureToken:
    """Generates an opaque string, for use as a future token.

    Tokens use the format:
        syf_<random string>_<base62 crc check>
    They are NOT scoped to user localparts so that any delegate given the
    token won't necessarily know which user created the future.
    """
    body = f"syf_{stringutils.random_string(20)}"
    checksum = base62_encode(crc32(body.encode("ascii")), minwidth=6)
    return FutureToken(f"{body}_{checksum}")
def _generate_group_id() -> GroupID:
    """Generates an opaque string, for use as a future group ID"""
    # We use the following format for future group IDs:
    #     syg_<random string>_<base62 crc check>
    # (The "syg_" prefix distinguishes group IDs from "syf_" future tokens.)
    # They are scoped to user localparts, but that CANNOT be relied on
    # to keep them globally unique, as users may set their own group_id.
    random_string = stringutils.random_string(20)
    base = f"syg_{random_string}"
    crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
    return GroupID(f"{base}_{crc}")

View file

@ -19,7 +19,7 @@
#
#
SCHEMA_VERSION = 85 # remember to update the list below when updating
SCHEMA_VERSION = 86 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@ -139,6 +139,10 @@ Changes in SCHEMA_VERSION = 84
Changes in SCHEMA_VERSION = 85
- Add a column `suspended` to the `users` table
Changes in SCHEMA_VERSION = 86:
- MSC4140: Add `futures` table that keeps track of events that are to
be posted in response to a refreshable timeout or an on-demand action.
"""

View file

@ -0,0 +1,58 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Groups of futures, keyed per user localpart. Futures in a group are
-- removed together (see the futures foreign key below).
CREATE TABLE future_groups (
    user_localpart TEXT NOT NULL,
    group_id TEXT NOT NULL,
    PRIMARY KEY (user_localpart, group_id)
);

-- One row per pending future: the partial event to be sent later.
CREATE TABLE futures (
    future_id INTEGER PRIMARY KEY, -- An alias of rowid in SQLite
    user_localpart TEXT NOT NULL,
    group_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    event_type TEXT NOT NULL,
    state_key TEXT, -- NULL for non-state events
    origin_server_ts BIGINT,
    content bytea NOT NULL,
    -- Deleting a group row deletes all of the group's futures.
    FOREIGN KEY (user_localpart, group_id)
        REFERENCES future_groups (user_localpart, group_id) ON DELETE CASCADE
);
-- TODO: Consider a trigger/constraint to disallow adding an action future to an empty group
CREATE INDEX future_group_idx ON futures (user_localpart, group_id);
CREATE INDEX room_state_event_idx ON futures (room_id, event_type, state_key) WHERE state_key IS NOT NULL;

-- Futures that fire automatically after a (refreshable) timeout.
-- NOTE(review): timeout/timestamp are presumably milliseconds, matching
-- the ms-based durations used elsewhere in the codebase -- confirm.
CREATE TABLE future_timeouts (
    future_id INTEGER PRIMARY KEY
        REFERENCES futures (future_id) ON DELETE CASCADE,
    timeout BIGINT NOT NULL, -- the delay before sending
    timestamp BIGINT NOT NULL -- when the future fires if not refreshed
);

-- Opaque tokens used to act on a future; each token has exactly one type.
CREATE TABLE future_tokens (
    token TEXT PRIMARY KEY,
    future_id INTEGER NOT NULL
        REFERENCES futures (future_id) ON DELETE CASCADE,
    is_send BOOLEAN NOT NULL DEFAULT FALSE,
    is_cancel BOOLEAN NOT NULL DEFAULT FALSE,
    is_refresh BOOLEAN NOT NULL DEFAULT FALSE,
    -- Exactly one type flag must be set per token.
    CHECK (
        + CAST(is_send AS INTEGER)
        + CAST(is_cancel AS INTEGER)
        + CAST(is_refresh AS INTEGER)
        = 1),
    -- Combined with the CHECK above: at most one token of each type per future.
    UNIQUE (future_id, is_send, is_cancel, is_refresh)
);
-- TODO: Consider a trigger/constraint to disallow refresh tokens for futures without a timeout

View file

@ -0,0 +1,14 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- NOTE(review): this looks like the Postgres-only delta (GENERATED ... AS
-- IDENTITY is Postgres syntax); it makes future_id auto-assigned, matching
-- SQLite where INTEGER PRIMARY KEY aliases rowid -- confirm file placement.
ALTER TABLE futures ALTER COLUMN future_id ADD GENERATED ALWAYS AS IDENTITY;