Merge branch 'develop' into madlittlemods/msc3575-sliding-sync-0.0.1

Eric Eastwood 2024-05-30 09:21:58 -05:00
commit 49998e053e
27 changed files with 539 additions and 436 deletions

changelog.d/17164.bugfix
View file

@ -0,0 +1 @@
Fix deduplicating of membership events to not create unused state groups.

changelog.d/17215.bugfix
View file

@ -0,0 +1 @@
Fix bug where duplicate events could be sent down sync when using workers that are overloaded.

changelog.d/17229.misc
View file

@ -0,0 +1 @@
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`.

changelog.d/17242.misc
View file

@ -0,0 +1 @@
Clean out invalid destinations from `device_federation_outbox` table.

changelog.d/17246.misc
View file

@ -0,0 +1 @@
Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.

View file

@ -200,10 +200,8 @@ netaddr = ">=0.7.18"
# add a lower bound to the Jinja2 dependency.
Jinja2 = ">=3.0"
bleach = ">=1.4.3"
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
# Additionally we need https://github.com/python/typing/pull/817 to allow types to be
# generic over ParamSpecs.
typing-extensions = ">=3.10.0.1"
# We use `Self`, which was added in `typing-extensions` 4.0.
typing-extensions = ">=4.0"
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
cryptography = ">=3.4.7"
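The `Self` type referenced in the comment above is what the new `is_before_or_eq` methods later in this diff rely on. A small, self-contained sketch (the classes here are made up for illustration, not Synapse's) of what `Self` buys over annotating with the base class:

from typing_extensions import Self


class Token:
    def __init__(self, stream: int) -> None:
        self.stream = stream

    def copy_and_advance(self, new_stream: int) -> Self:
        # With `Self`, type checkers know the return type is whatever
        # subclass this was called on, not just `Token`.
        return type(self)(max(self.stream, new_stream))

    def is_before_or_eq(self, other: Self) -> bool:
        # Accepts another token of the same (sub)class.
        return self.stream <= other.stream


class RoomToken(Token):
    pass


advanced: RoomToken = RoomToken(3).copy_and_advance(5)
assert RoomToken(3).is_before_or_eq(advanced)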

View file

@ -777,22 +777,74 @@ class Porter:
await self._setup_events_stream_seqs()
await self._setup_sequence(
"un_partial_stated_event_stream_sequence",
("un_partial_stated_event_stream",),
[("un_partial_stated_event_stream", "stream_id")],
)
await self._setup_sequence(
"device_inbox_sequence", ("device_inbox", "device_federation_outbox")
"device_inbox_sequence",
[
("device_inbox", "stream_id"),
("device_federation_outbox", "stream_id"),
],
)
await self._setup_sequence(
"account_data_sequence",
("room_account_data", "room_tags_revisions", "account_data"),
[
("room_account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
("account_data", "stream_id"),
],
)
await self._setup_sequence(
"receipts_sequence",
[
("receipts_linearized", "stream_id"),
],
)
await self._setup_sequence(
"presence_stream_sequence",
[
("presence_stream", "stream_id"),
],
)
await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
await self._setup_auth_chain_sequence()
await self._setup_sequence(
"application_services_txn_id_seq",
("application_services_txns",),
"txn_id",
[
(
"application_services_txns",
"txn_id",
)
],
)
await self._setup_sequence(
"device_lists_sequence",
[
("device_lists_stream", "stream_id"),
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
("device_lists_remote_pending", "stream_id"),
("device_lists_changes_converted_stream_position", "stream_id"),
],
)
await self._setup_sequence(
"e2e_cross_signing_keys_sequence",
[
("e2e_cross_signing_keys", "stream_id"),
],
)
await self._setup_sequence(
"push_rules_stream_sequence",
[
("push_rules_stream", "stream_id"),
],
)
await self._setup_sequence(
"pushers_sequence",
[
("pushers", "id"),
("deleted_pushers", "stream_id"),
],
)
# Step 3. Get tables.
@ -1101,12 +1153,11 @@ class Porter:
async def _setup_sequence(
self,
sequence_name: str,
stream_id_tables: Iterable[str],
column_name: str = "stream_id",
stream_id_tables: Iterable[Tuple[str, str]],
) -> None:
"""Set a sequence to the correct value."""
current_stream_ids = []
for stream_id_table in stream_id_tables:
for stream_id_table, column_name in stream_id_tables:
max_stream_id = cast(
int,
await self.sqlite_store.db_pool.simple_select_one_onecol(
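The (table, column) pairs above tell the porter which columns feed each sequence. A self-contained sketch (the helper name is hypothetical; the porter sets the sequence through its database pool rather than by printing SQL) of the equivalent statement, matching the setval deltas added later in this diff:

from typing import Iterable, Tuple


def build_setval_sql(
    sequence_name: str, stream_id_tables: Iterable[Tuple[str, str]]
) -> str:
    # One COALESCE(MAX(...), 1) per (table, column) pair, combined with
    # GREATEST so the sequence starts past every table that shares it.
    maxes = ",\n        ".join(
        f"(SELECT COALESCE(MAX({column}), 1) FROM {table})"
        for table, column in stream_id_tables
    )
    return (
        f"SELECT setval('{sequence_name}', (\n"
        f"    SELECT GREATEST(\n        {maxes}\n    )\n"
        f"));"
    )


print(
    build_setval_sql(
        "pushers_sequence", [("pushers", "id"), ("deleted_pushers", "stream_id")]
    )
)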

View file

@ -496,13 +496,6 @@ class EventCreationHandler:
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
self.membership_types_to_include_profile_data_in = {
Membership.JOIN,
Membership.KNOCK,
}
if self.hs.config.server.include_profile_data_on_invite:
self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
self.send_events = ReplicationSendEventsRestServlet.make_client(hs)
@ -594,8 +587,6 @@ class EventCreationHandler:
Creates a FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Args:
requester
event_dict: An entire event
@ -672,29 +663,6 @@ class EventCreationHandler:
self.validator.validate_builder(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in self.membership_types_to_include_profile_data_in:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
displayname = await profile.get_displayname(target)
if displayname is not None:
content["displayname"] = displayname
if "avatar_url" not in content:
avatar_url = await profile.get_avatar_url(target)
if avatar_url is not None:
content["avatar_url"] = avatar_url
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s", target, e
)
is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
if require_consent and not is_exempt:
await self.assert_accepted_privacy_policy(requester)

View file

@ -106,6 +106,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.event_auth_handler = hs.get_event_auth_handler()
self._worker_lock_handler = hs.get_worker_locks_handler()
self._membership_types_to_include_profile_data_in = {
Membership.JOIN,
Membership.KNOCK,
}
if self.hs.config.server.include_profile_data_on_invite:
self._membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.member_linearizer: Linearizer = Linearizer(name="member")
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
@ -785,9 +792,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if (
not self.allow_per_room_profiles and not is_requester_server_notices_user
) or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.
# Strip profile data, knowing that new profile data will be added to
# the event's content below using the target's global profile.
content.pop("displayname", None)
content.pop("avatar_url", None)
@ -823,6 +829,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if action in ["kick", "unban"]:
effective_membership_state = "leave"
if effective_membership_state not in Membership.LIST:
raise SynapseError(400, "Invalid membership key")
# Add profile data for joins etc, if no per-room profile.
if (
effective_membership_state
in self._membership_types_to_include_profile_data_in
):
# If event doesn't include a display name, add one.
profile = self.profile_handler
try:
if "displayname" not in content:
displayname = await profile.get_displayname(target)
if displayname is not None:
content["displayname"] = displayname
if "avatar_url" not in content:
avatar_url = await profile.get_avatar_url(target)
if avatar_url is not None:
content["avatar_url"] = avatar_url
except Exception as e:
logger.info("Failed to get profile information for %r: %s", target, e)
# if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join.
if third_party_signed is not None:

View file

@ -284,6 +284,23 @@ class SyncResult:
or self.device_lists
)
@staticmethod
def empty(next_batch: StreamToken) -> "SyncResult":
"Return a new empty result"
return SyncResult(
next_batch=next_batch,
presence=[],
account_data=[],
joined=[],
invited=[],
knocked=[],
archived=[],
to_device=[],
device_lists=DeviceListUpdates(),
device_one_time_keys_count={},
device_unused_fallback_key_types=[],
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class E2eeSyncResult:
@ -497,6 +514,24 @@ class SyncHandler:
if context:
context.tag = sync_label
if since_token is not None:
# We need to make sure this worker has caught up with the token. If
# this returns false it means we timed out waiting, and we should
# just return an empty response.
start = self.clock.time_msec()
if not await self.notifier.wait_for_stream_token(since_token):
logger.warning(
"Timed out waiting for worker to catch up. Returning empty response"
)
return SyncResult.empty(since_token)
# If we've spent significant time waiting to catch up, take it off
# the timeout.
now = self.clock.time_msec()
if now - start > 1_000:
timeout -= now - start
timeout = max(timeout, 0)
# if we have a since token, delete any to-device messages before that token
# (since we now know that the device has received them)
if since_token is not None:
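A self-contained sketch of the timeout adjustment above, with made-up numbers: catch-up time beyond one second is charged against the client's long-poll timeout, and the remainder never goes negative:

def remaining_timeout(timeout_ms: int, catch_up_ms: int) -> int:
    # Mirrors the logic above: only charge the catch-up time if it was
    # significant (over one second), and clamp at zero.
    if catch_up_ms > 1_000:
        timeout_ms -= catch_up_ms
        timeout_ms = max(timeout_ms, 0)
    return timeout_ms


assert remaining_timeout(30_000, 2_500) == 27_500
assert remaining_timeout(30_000, 500) == 30_000
assert remaining_timeout(30_000, 45_000) == 0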

View file

@ -1049,6 +1049,11 @@ class MediaRepository:
finally:
t_byte_source.close()
# We flush and close the file to ensure that the bytes have
# been written before getting the size.
f.flush()
f.close()
t_len = os.path.getsize(fname)
# Write to database

View file

@ -137,42 +137,37 @@ class MediaStorage:
dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True)
main_media_repo_write_trace_scope = start_active_span(
"writing to main media repo"
)
main_media_repo_write_trace_scope.__enter__()
with main_media_repo_write_trace_scope:
try:
try:
with start_active_span("writing to main media repo"):
with open(fname, "wb") as f:
yield f, fname
except Exception as e:
try:
os.remove(fname)
except Exception:
pass
raise e from None
with start_active_span("writing to other storage providers"):
spam_check = (
await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
with start_active_span("writing to other storage providers"):
spam_check = (
await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
)
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in
# the DB.
# We currently ignore any additional field returned by
# the spam-check API.
raise SpamMediaException(errcode=spam_check[0])
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in
# the DB.
# We currently ignore any additional field returned by
# the spam-check API.
raise SpamMediaException(errcode=spam_check[0])
for provider in self.storage_providers:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)
for provider in self.storage_providers:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)
except Exception as e:
try:
os.remove(fname)
except Exception:
pass
raise e from None
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache

View file

@ -763,6 +763,29 @@ class Notifier:
return result
async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
"""Wait for this worker to catch up with the given stream token."""
start = self.clock.time_msec()
while True:
current_token = self.event_sources.get_current_token()
if stream_token.is_before_or_eq(current_token):
return True
now = self.clock.time_msec()
if now - start > 10_000:
return False
logger.info(
"Waiting for current token to reach %s; currently at %s",
stream_token,
current_token,
)
# TODO: be better
await self.clock.sleep(0.5)
async def _get_room_ids(
self, user: UserID, explicit_room_id: Optional[str]
) -> Tuple[StrCollection, bool]:

View file

@ -58,6 +58,7 @@ from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -964,6 +965,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
CLEANUP_DEVICE_FEDERATION_OUTBOX = "cleanup_device_federation_outbox"
def __init__(
self,
@ -989,6 +991,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self._remove_dead_devices_from_device_inbox,
)
self.db_pool.updates.register_background_update_handler(
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
self._cleanup_device_federation_outbox,
)
async def _background_drop_index_device_inbox(
self, progress: JsonDict, batch_size: int
) -> int:
@ -1080,6 +1087,75 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _cleanup_device_federation_outbox(
self,
progress: JsonDict,
batch_size: int,
) -> int:
def _cleanup_device_federation_outbox_txn(
txn: LoggingTransaction,
) -> bool:
if "max_stream_id" in progress:
max_stream_id = progress["max_stream_id"]
else:
txn.execute("SELECT max(stream_id) FROM device_federation_outbox")
res = cast(Tuple[Optional[int]], txn.fetchone())
if res[0] is None:
# this can only happen if the `device_federation_outbox` table is empty, in which
# case we have no work to do.
return True
else:
max_stream_id = res[0]
start = progress.get("stream_id", 0)
stop = start + batch_size
sql = """
SELECT destination FROM device_federation_outbox
WHERE ? < stream_id AND stream_id <= ?
"""
txn.execute(sql, (start, stop))
destinations = {d for d, in txn}
to_remove = set()
for d in destinations:
try:
parse_and_validate_server_name(d)
except ValueError:
to_remove.add(d)
self.db_pool.simple_delete_many_txn(
txn,
table="device_federation_outbox",
column="destination",
values=to_remove,
keyvalues={},
)
self.db_pool.updates._background_update_progress_txn(
txn,
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
{
"stream_id": stop,
"max_stream_id": max_stream_id,
},
)
return stop >= max_stream_id
finished = await self.db_pool.runInteraction(
"_cleanup_device_federation_outbox",
_cleanup_device_federation_outbox_txn,
)
if finished:
await self.db_pool.updates._end_background_update(
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
)
return batch_size
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
pass
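A self-contained sketch (a toy driver, not the real background-update machinery) of the batching pattern above: each run covers the next batch_size stream IDs, records its position as progress, and the update finishes once it passes the maximum stream ID captured on the first run:

def run_to_completion(max_stream_id: int, batch_size: int) -> int:
    progress: dict = {}
    runs = 0
    while True:
        start = progress.get("stream_id", 0)
        stop = start + batch_size
        # ... a real run would process rows with start < stream_id <= stop ...
        progress = {"stream_id": stop, "max_stream_id": max_stream_id}
        runs += 1
        if stop >= max_stream_id:
            return runs


# 250 rows of outbox in batches of 100 -> three runs.
assert run_to_completion(max_stream_id=250, batch_size=100) == 3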

View file

@ -57,10 +57,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
StreamIdGenerator,
)
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import (
JsonDict,
JsonMapping,
@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
self._device_list_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
("device_lists_remote_pending", "stream_id"),
("device_lists_changes_converted_stream_position", "stream_id"),
self._device_list_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="device_lists_stream",
instance_name=self._instance_name,
tables=[
("device_lists_stream", "instance_name", "stream_id"),
("user_signature_stream", "instance_name", "stream_id"),
("device_lists_outbound_pokes", "instance_name", "stream_id"),
("device_lists_changes_in_room", "instance_name", "stream_id"),
("device_lists_remote_pending", "instance_name", "stream_id"),
],
is_writer=hs.config.worker.worker_app is None,
sequence_name="device_lists_sequence",
writers=["master"],
)
device_list_max = self._device_list_id_gen.get_current_token()
@ -762,6 +761,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"stream_id": stream_id,
"from_user_id": from_user_id,
"user_ids": json_encoder.encode(user_ids),
"instance_name": self._instance_name,
},
)
@ -1582,6 +1582,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx",
index_name="device_lists_stream_user_id",
@ -1694,6 +1696,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
"device_lists_outbound_pokes",
{
"stream_id": stream_id,
"instance_name": self._instance_name,
"destination": destination,
"user_id": user_id,
"device_id": device_id,
@ -1730,10 +1733,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Because we have write access, this will be a StreamIdGenerator
# (see DeviceWorkerStore.__init__)
_device_list_id_gen: AbstractStreamIdGenerator
def __init__(
self,
database: DatabasePool,
@ -2092,9 +2091,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_stream",
keys=("stream_id", "user_id", "device_id"),
keys=("instance_name", "stream_id", "user_id", "device_id"),
values=[
(stream_id, user_id, device_id)
(self._instance_name, stream_id, user_id, device_id)
for stream_id, device_id in zip(stream_ids, device_ids)
],
)
@ -2124,6 +2123,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values = [
(
destination,
self._instance_name,
next(stream_id_iterator),
user_id,
device_id,
@ -2139,6 +2139,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_outbound_pokes",
keys=(
"destination",
"instance_name",
"stream_id",
"user_id",
"device_id",
@ -2157,7 +2158,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id,
{
stream_id: destination
for (destination, stream_id, _, _, _, _, _) in values
for (destination, _, stream_id, _, _, _, _, _) in values
},
)
@ -2210,6 +2211,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"device_id",
"room_id",
"stream_id",
"instance_name",
"converted_to_destinations",
"opentracing_context",
),
@ -2219,6 +2221,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id,
room_id,
stream_id,
self._instance_name,
# We only need to calculate outbound pokes for local users
not self.hs.is_mine_id(user_id),
encoded_context,
@ -2338,7 +2341,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"user_id": user_id,
"device_id": device_id,
},
values={"stream_id": stream_id},
values={
"stream_id": stream_id,
"instance_name": self._instance_name,
},
desc="add_remote_device_list_to_pending",
)

View file

@ -58,7 +58,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
):
super().__init__(database, db_conn, hs)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"e2e_cross_signing_keys",
"stream_id",
self._cross_signing_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="e2e_cross_signing_keys",
instance_name=self._instance_name,
tables=[
("e2e_cross_signing_keys", "instance_name", "stream_id"),
],
sequence_name="e2e_cross_signing_keys_sequence",
writers=["master"],
)
async def set_e2e_device_keys(
@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"keytype": key_type,
"keydata": json_encoder.encode(key),
"stream_id": stream_id,
"instance_name": self._instance_name,
},
)

View file

@ -95,6 +95,10 @@ class DeltaState:
to_insert: StateMap[str]
no_longer_in_room: bool = False
def is_noop(self) -> bool:
"""Whether this state delta is actually empty"""
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
class PersistEventsStore:
"""Contains all the functions for writing events to the database.
@ -1017,6 +1021,9 @@ class PersistEventsStore:
) -> None:
"""Update the current state stored in the datatabase for the given room"""
if state_delta.is_noop():
return
async with self._stream_id_gen.get_next() as stream_ordering:
await self.db_pool.runInteraction(
"update_current_state",

View file

@ -200,7 +200,11 @@ class EventsWorkerStore(SQLBaseStore):
notifier=hs.get_replication_notifier(),
stream_name="events",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
tables=[
("events", "instance_name", "stream_ordering"),
("current_state_delta_stream", "instance_name", "stream_id"),
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
],
sequence_name="events_stream_seq",
writers=hs.config.worker.writers.events,
)
@ -210,7 +214,10 @@ class EventsWorkerStore(SQLBaseStore):
notifier=hs.get_replication_notifier(),
stream_name="backfill",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
tables=[
("events", "instance_name", "stream_ordering"),
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
],
sequence_name="events_backfill_stream_seq",
positive=False,
writers=hs.config.worker.writers.events,

View file

@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
from synapse.util import json_encoder, unwrapFirstError
@ -126,7 +126,7 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
_push_rules_stream_id_gen: StreamIdGenerator
_push_rules_stream_id_gen: MultiWriterIdGenerator
def __init__(
self,
@ -140,14 +140,17 @@ class PushRulesWorkerStore(
hs.get_instance_name() in hs.config.worker.writers.push_rules
)
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
self._push_rules_stream_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"push_rules_stream",
"stream_id",
is_writer=self._is_push_writer,
self._push_rules_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="push_rules_stream",
instance_name=self._instance_name,
tables=[
("push_rules_stream", "instance_name", "stream_id"),
],
sequence_name="push_rules_stream_sequence",
writers=hs.config.worker.writers.push_rules,
)
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@ -880,6 +883,7 @@ class PushRulesWorkerStore(
raise Exception("Not a push writer")
values = {
"instance_name": self._instance_name,
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
"user_id": user_id,

View file

@ -40,10 +40,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
StreamIdGenerator,
)
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
self._pushers_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
is_writer=hs.config.worker.worker_app is None,
self._instance_name = hs.get_instance_name()
self._pushers_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="pushers",
instance_name=self._instance_name,
tables=[
("pushers", "instance_name", "id"),
("deleted_pushers", "instance_name", "stream_id"),
],
sequence_name="pushers_sequence",
writers=["master"],
)
self.db_pool.updates.register_background_update_handler(
@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
# Because we have write access, this will be a StreamIdGenerator
# (see PusherWorkerStore.__init__)
_pushers_id_gen: AbstractStreamIdGenerator
_pushers_id_gen: MultiWriterIdGenerator
async def add_pusher(
self,
@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
"instance_name": self._instance_name,
"enabled": enabled,
"device_id": device_id,
# XXX(quenting): We're only really persisting the access token ID
@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
table="deleted_pushers",
values={
"stream_id": stream_id,
"instance_name": self._instance_name,
"app_id": app_id,
"pushkey": pushkey,
"user_id": user_id,
@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
self.db_pool.simple_insert_many_txn(
txn,
table="deleted_pushers",
keys=("stream_id", "app_id", "pushkey", "user_id"),
keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
values=[
(stream_id, pusher.app_id, pusher.pushkey, user_id)
(
stream_id,
self._instance_name,
pusher.app_id,
pusher.pushkey,
user_id,
)
for stream_id, pusher in zip(stream_ids, pushers)
],
)

View file

@ -0,0 +1,27 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Add `instance_name` columns to stream tables to allow them to be used with
-- `MultiWriterIdGenerator`
ALTER TABLE device_lists_stream ADD COLUMN instance_name TEXT;
ALTER TABLE user_signature_stream ADD COLUMN instance_name TEXT;
ALTER TABLE device_lists_outbound_pokes ADD COLUMN instance_name TEXT;
ALTER TABLE device_lists_changes_in_room ADD COLUMN instance_name TEXT;
ALTER TABLE device_lists_remote_pending ADD COLUMN instance_name TEXT;
ALTER TABLE e2e_cross_signing_keys ADD COLUMN instance_name TEXT;
ALTER TABLE push_rules_stream ADD COLUMN instance_name TEXT;
ALTER TABLE pushers ADD COLUMN instance_name TEXT;
ALTER TABLE deleted_pushers ADD COLUMN instance_name TEXT;

View file

@ -0,0 +1,54 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Add sequences for stream tables to allow them to be used with
-- `MultiWriterIdGenerator`
CREATE SEQUENCE IF NOT EXISTS device_lists_sequence;
-- We need to take the max across all the device lists tables as they share the
-- ID generator
SELECT setval('device_lists_sequence', (
SELECT GREATEST(
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_stream),
(SELECT COALESCE(MAX(stream_id), 1) FROM user_signature_stream),
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_outbound_pokes),
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room),
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_remote_pending),
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_converted_stream_position)
)
));
CREATE SEQUENCE IF NOT EXISTS e2e_cross_signing_keys_sequence;
SELECT setval('e2e_cross_signing_keys_sequence', (
SELECT COALESCE(MAX(stream_id), 1) FROM e2e_cross_signing_keys
));
CREATE SEQUENCE IF NOT EXISTS push_rules_stream_sequence;
SELECT setval('push_rules_stream_sequence', (
SELECT COALESCE(MAX(stream_id), 1) FROM push_rules_stream
));
CREATE SEQUENCE IF NOT EXISTS pushers_sequence;
-- We need to take the max across all the pusher tables as they share the
-- ID generator
SELECT setval('pushers_sequence', (
SELECT GREATEST(
(SELECT COALESCE(MAX(id), 1) FROM pushers),
(SELECT COALESCE(MAX(stream_id), 1) FROM deleted_pushers)
)
));

View file

@ -0,0 +1,15 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(8504, 'cleanup_device_federation_outbox', '{}');

View file

@ -23,15 +23,12 @@ import abc
import heapq
import logging
import threading
from collections import OrderedDict
from contextlib import contextmanager
from types import TracebackType
from typing import (
TYPE_CHECKING,
AsyncContextManager,
ContextManager,
Dict,
Generator,
Generic,
Iterable,
List,
@ -179,161 +176,6 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with a single writer.
This class must only be used when the current Synapse process is the sole
writer for a stream.
Args:
db_conn(connection): A database connection to use to fetch the
initial value of the generator from.
table(str): A database table to read the initial value of the id
generator from.
column(str): The column of the database table to read the initial
value from the id generator from.
extra_tables(list): List of pairs of database tables and columns to
use to source the initial value of the generator from. The value
with the largest magnitude is used.
step(int): which direction the stream ids grow in. +1 to grow
upwards, -1 to grow downwards.
Usage:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
def __init__(
self,
db_conn: LoggingDatabaseConnection,
notifier: "ReplicationNotifier",
table: str,
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
step: int = 1,
is_writer: bool = True,
) -> None:
assert step != 0
self._lock = threading.Lock()
self._step: int = step
self._current: int = _load_current_id(db_conn, table, column, step)
self._is_writer = is_writer
for table, column in extra_tables:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
# We use this as an ordered set, as we want to efficiently append items,
# remove items and get the first item. Since we insert IDs in order, the
# insertion ordering will ensure its in the correct ordering.
#
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
self._notifier = notifier
def advance(self, instance_name: str, new_id: int) -> None:
# Advance should never be called on a writer instance, only over replication
if self._is_writer:
raise Exception("Replication is not supported by writer StreamIdGenerator")
self._current = (max if self._step > 0 else min)(self._current, new_id)
def get_next(self) -> AsyncContextManager[int]:
with self._lock:
self._current += self._step
next_id = self._current
self._unfinished_ids[next_id] = next_id
@contextmanager
def manager() -> Generator[int, None, None]:
try:
yield next_id
finally:
with self._lock:
self._unfinished_ids.pop(next_id)
self._notifier.notify_replication()
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
with self._lock:
next_ids = range(
self._current + self._step,
self._current + self._step * (n + 1),
self._step,
)
self._current += n * self._step
for next_id in next_ids:
self._unfinished_ids[next_id] = next_id
@contextmanager
def manager() -> Generator[Sequence[int], None, None]:
try:
yield next_ids
finally:
with self._lock:
for next_id in next_ids:
self._unfinished_ids.pop(next_id)
self._notifier.notify_replication()
return _AsyncCtxManagerWrapper(manager())
def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
Retrieve the next stream ID from within a database transaction.
Clean-up functions will be called when the transaction finishes.
Args:
txn: The database transaction object.
Returns:
The next stream ID.
"""
if not self._is_writer:
raise Exception("Tried to allocate stream ID on non-writer")
# Get the next stream ID.
with self._lock:
self._current += self._step
next_id = self._current
self._unfinished_ids[next_id] = next_id
def clear_unfinished_id(id_to_clear: int) -> None:
"""A function to mark processing this ID as finished"""
with self._lock:
self._unfinished_ids.pop(id_to_clear)
# Mark this ID as finished once the database transaction itself finishes.
txn.call_after(clear_unfinished_id, next_id)
txn.call_on_exception(clear_unfinished_id, next_id)
# Return the new ID.
return next_id
def get_current_token(self) -> int:
if not self._is_writer:
return self._current
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
return self.get_current_token()
def get_minimal_local_current_token(self) -> int:
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with multiple writers.

View file

@ -48,7 +48,7 @@ import attr
from immutabledict import immutabledict
from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey
from typing_extensions import TypedDict
from typing_extensions import Self, TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
@ -515,6 +515,27 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
# at `self.stream`.
return self.instance_map.get(instance_name, self.stream)
def is_before_or_eq(self, other_token: Self) -> bool:
"""Wether this token is before the other token, i.e. every constituent
part is before the other.
Essentially it is `self <= other`.
Note: if `self.is_before_or_eq(other_token) is False` then that does not
imply that the reverse is True.
"""
if self.stream > other_token.stream:
return False
instances = self.instance_map.keys() | other_token.instance_map.keys()
for instance in instances:
if self.instance_map.get(
instance, self.stream
) > other_token.instance_map.get(instance, other_token.stream):
return False
return True
@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken(AbstractMultiWriterStreamToken):
@ -1008,6 +1029,41 @@ class StreamToken:
"""Returns the stream ID for the given key."""
return getattr(self, key.value)
def is_before_or_eq(self, other_token: "StreamToken") -> bool:
"""Wether this token is before the other token, i.e. every constituent
part is before the other.
Essentially it is `self <= other`.
Note: if `self.is_before_or_eq(other_token) is False` then that does not
imply that the reverse is True.
"""
for _, key in StreamKeyType.__members__.items():
if key == StreamKeyType.TYPING:
# Typing stream is allowed to "reset", and so comparisons don't
# really make sense as is.
# TODO: Figure out a better way of tracking resets.
continue
self_value = self.get_field(key)
other_value = other_token.get_field(key)
if isinstance(self_value, RoomStreamToken):
assert isinstance(other_value, RoomStreamToken)
if not self_value.is_before_or_eq(other_value):
return False
elif isinstance(self_value, MultiWriterStreamToken):
assert isinstance(other_value, MultiWriterStreamToken)
if not self_value.is_before_or_eq(other_value):
return False
else:
assert isinstance(other_value, int)
if self_value > other_value:
return False
return True
StreamToken.START = StreamToken(
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
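A self-contained illustration of the partial ordering defined above, with plain dicts standing in for the instance map: when different writers are ahead in different tokens, neither token is before-or-equal to the other, which is why a False result says nothing about the reverse comparison:

def is_before_or_eq(stream_a: int, map_a: dict, stream_b: int, map_b: dict) -> bool:
    # Same shape as the method above: the overall stream position and every
    # per-writer position must be <= the other token's.
    if stream_a > stream_b:
        return False
    for instance in map_a.keys() | map_b.keys():
        if map_a.get(instance, stream_a) > map_b.get(instance, stream_b):
            return False
    return True


# Writer "w1" is ahead in token A, writer "w2" is ahead in token B, so the
# tokens are incomparable: neither is wholly before the other.
a = (5, {"w1": 9, "w2": 5})
b = (5, {"w1": 5, "w2": 9})
assert not is_before_or_eq(*a, *b)
assert not is_before_or_eq(*b, *a)
assert is_before_or_eq(*a, 9, {"w1": 9, "w2": 9})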

View file

@ -407,3 +407,24 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
self.assertFalse(
self.get_success(self.store.did_forget(self.alice, self.room_id))
)
def test_deduplicate_joins(self) -> None:
"""
Test that calling /join multiple times does not store a new state group.
"""
self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
sql = "SELECT COUNT(*) FROM state_groups WHERE room_id = ?"
rows = self.get_success(
self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
)
initial_count = rows[0][0]
self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
rows = self.get_success(
self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
)
new_count = rows[0][0]
self.assertEqual(initial_count, new_count)

View file

@ -30,7 +30,7 @@ from synapse.storage.database import (
)
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.sequence import (
LocalSequenceGenerator,
PostgresSequenceGenerator,
@ -42,144 +42,6 @@ from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
class StreamIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute(
"""
CREATE TABLE foobar (
stream_id BIGINT NOT NULL,
data TEXT
);
"""
)
txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
def _create_id_generator(self) -> StreamIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
return StreamIdGenerator(
db_conn=conn,
notifier=self.hs.get_replication_notifier(),
table="foobar",
column="stream_id",
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def test_initial_value(self) -> None:
"""Check that we read the current token from the DB."""
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_current_token(), 123)
def test_single_gen_next(self) -> None:
"""Check that we correctly increment the current token from the DB."""
id_gen = self._create_id_generator()
async def test_gen_next() -> None:
async with id_gen.get_next() as next_id:
# We haven't persisted `next_id` yet; current token is still 123
self.assertEqual(id_gen.get_current_token(), 123)
# But we did learn what the next value is
self.assertEqual(next_id, 124)
# Once the context manager closes we assume that the `next_id` has been
# written to the DB.
self.assertEqual(id_gen.get_current_token(), 124)
self.get_success(test_gen_next())
def test_multiple_gen_nexts(self) -> None:
"""Check that we handle overlapping calls to gen_next sensibly."""
id_gen = self._create_id_generator()
async def test_gen_next() -> None:
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
ctx3 = id_gen.get_next()
# Request three new stream IDs.
self.assertEqual(await ctx1.__aenter__(), 124)
self.assertEqual(await ctx2.__aenter__(), 125)
self.assertEqual(await ctx3.__aenter__(), 126)
# None are persisted: current token unchanged.
self.assertEqual(id_gen.get_current_token(), 123)
# Persist each in turn.
await ctx1.__aexit__(None, None, None)
self.assertEqual(id_gen.get_current_token(), 124)
await ctx2.__aexit__(None, None, None)
self.assertEqual(id_gen.get_current_token(), 125)
await ctx3.__aexit__(None, None, None)
self.assertEqual(id_gen.get_current_token(), 126)
self.get_success(test_gen_next())
def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
"""Check that we handle overlapping calls to gen_next, even when their IDs
are created and persisted in different orders."""
id_gen = self._create_id_generator()
async def test_gen_next() -> None:
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
ctx3 = id_gen.get_next()
# Request three new stream IDs.
self.assertEqual(await ctx1.__aenter__(), 124)
self.assertEqual(await ctx2.__aenter__(), 125)
self.assertEqual(await ctx3.__aenter__(), 126)
# None are persisted: current token unchanged.
self.assertEqual(id_gen.get_current_token(), 123)
# Persist them in a different order, starting with 126 from ctx3.
await ctx3.__aexit__(None, None, None)
# We haven't persisted 124 from ctx1 yet---current token is still 123.
self.assertEqual(id_gen.get_current_token(), 123)
# Now persist 124 from ctx1.
await ctx1.__aexit__(None, None, None)
# Current token is then 124, waiting for 125 to be persisted.
self.assertEqual(id_gen.get_current_token(), 124)
# Finally persist 125 from ctx2.
await ctx2.__aexit__(None, None, None)
# Current token is then 126 (skipping over 125).
self.assertEqual(id_gen.get_current_token(), 126)
self.get_success(test_gen_next())
def test_gen_next_while_still_waiting_for_persistence(self) -> None:
"""Check that we handle overlapping calls to gen_next."""
id_gen = self._create_id_generator()
async def test_gen_next() -> None:
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
ctx3 = id_gen.get_next()
# Request two new stream IDs.
self.assertEqual(await ctx1.__aenter__(), 124)
self.assertEqual(await ctx2.__aenter__(), 125)
# Persist ctx2 first.
await ctx2.__aexit__(None, None, None)
# Still waiting on ctx1's ID to be persisted.
self.assertEqual(id_gen.get_current_token(), 123)
# Now request a third stream ID. It should be 126 (the smallest ID that
# we've not yet handed out.)
self.assertEqual(await ctx3.__aenter__(), 126)
self.get_success(test_gen_next())
class MultiWriterIdGeneratorBase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main