commit bd91b8c7c1
Erik Johnston, 2024-06-14 21:11:36 +02:00, committed by GitHub
3 changed files with 93 additions and 7 deletions

changelog.d/17169.misc (new file)

@@ -0,0 +1 @@
+Add database performance improvement when fetching auth chains.

synapse/storage/databases/main/event_federation.py
@@ -39,6 +39,7 @@ from typing import (
 import attr
 from prometheus_client import Counter, Gauge
+from sortedcontainers import SortedSet

 from synapse.api.constants import MAX_DEPTH
 from synapse.api.errors import StoreError
@@ -283,7 +284,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # A map from chain ID to max sequence number *reachable* from any event ID.
         chains: Dict[int, int] = {}
-        for links in self._get_chain_links(txn, set(event_chains.keys())):
+        for links in self._get_chain_links(txn, event_chains.keys()):
             for chain_id in links:
                 if chain_id not in event_chains:
                     continue
@@ -335,7 +336,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     @classmethod
     def _get_chain_links(
-        cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
+        cls, txn: LoggingTransaction, chains_to_fetch: Collection[int]
     ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
         """Fetch all auth chain links from the given set of chains, and all
         links from those chains, recursively.
@@ -371,9 +372,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
         """

-        while chains_to_fetch:
-            batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
-            chains_to_fetch.difference_update(batch2)
+        # We fetch the links in batches. Separate batches will likely fetch the
+        # same set of links (e.g. they'll always pull in the links to the
+        # create event). To try and minimize the amount of redundant links, we
+        # query the chain IDs in reverse order, as there will be a correlation
+        # between the order of chain IDs and links (i.e., higher chain IDs are
+        # more likely to depend on lower chain IDs than vice versa).
+        BATCH_SIZE = 1000
+        chains_to_fetch_sorted = SortedSet(chains_to_fetch)
+
+        while chains_to_fetch_sorted:
+            batch2 = list(chains_to_fetch_sorted.islice(-BATCH_SIZE))
+            chains_to_fetch_sorted.difference_update(batch2)

             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "origin_chain_id", batch2
             )
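A minimal, standalone sketch of the batching pattern the new code relies on (the IDs and the small BATCH_SIZE here are illustrative only; requires the sortedcontainers package):

    # Drain a work set in batches, highest IDs first.
    from sortedcontainers import SortedSet

    BATCH_SIZE = 3  # Synapse uses 1000; kept small so the output is readable.

    work = SortedSet([1, 4, 2, 9, 7, 5, 8])
    while work:
        # islice(-BATCH_SIZE) takes the *last* (largest) items, so higher
        # chain IDs -- the ones most likely to link down to lower ones --
        # are queried first.
        batch = list(work.islice(-BATCH_SIZE))
        work.difference_update(batch)
        print(batch)  # [7, 8, 9], then [2, 4, 5], then [1]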
@@ -391,7 +402,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                     (origin_sequence_number, target_chain_id, target_sequence_number)
                 )

-            chains_to_fetch.difference_update(links)
+            chains_to_fetch_sorted.difference_update(links)

             yield links
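The overall shape of `_get_chain_links` — drain a pending set in batches, let each batch's recursive query pull in every chain it reaches, and drop those from the pending set so later batches skip them — can be sketched with an in-memory stand-in for the SQL (the `LINKS` table and all IDs below are made up):

    from typing import Dict, Iterator, List
    from sortedcontainers import SortedSet

    # Hypothetical link table: origin chain ID -> target chain IDs.
    LINKS: Dict[int, List[int]] = {9: [4, 2], 7: [2], 4: [1]}

    def get_chain_links(
        chains: SortedSet, batch_size: int = 2
    ) -> Iterator[Dict[int, List[int]]]:
        while chains:
            # Highest chain IDs first, mirroring the real query's order.
            batch = list(chains.islice(-batch_size))
            chains.difference_update(batch)
            # Stand-in for the recursive SQL: walk links transitively.
            found: Dict[int, List[int]] = {}
            frontier = list(batch)
            while frontier:
                origin = frontier.pop()
                if origin in found:
                    continue
                found[origin] = LINKS.get(origin, [])
                frontier.extend(found[origin])
            # Every chain this batch reached needs no separate query later.
            chains.difference_update(found)
            yield found

    # One query covers 9, 7 and everything they reach; 4 is never re-queried.
    for links in get_chain_links(SortedSet([9, 7, 4])):
        print(links)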
@@ -581,7 +592,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # are reachable from any event.
         # (We need to take a copy of `seen_chains` as the function mutates it)
-        for links in self._get_chain_links(txn, set(seen_chains)):
+        for links in self._get_chain_links(txn, seen_chains):
             for chains in set_to_chain:
                 for chain_id in links:
                     if chain_id not in chains:
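Why the `set(seen_chains)` copy could go: the old `_get_chain_links` drained the set it was handed, while the new one builds its own SortedSet from the argument. A contrived illustration of the difference (nothing here is Synapse code):

    def drains(work: set) -> None:
        # Old style: consumes the caller's set, so callers must pass a copy.
        while work:
            work.pop()

    def copies(work) -> None:
        # New style: private copy; the caller's collection is left intact.
        own = set(work)
        while own:
            own.pop()

    seen = {1, 2, 3}
    copies(seen)
    assert seen == {1, 2, 3}
    drains(seen)
    assert not seen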

tests/storage/test_purge.py
@@ -25,6 +25,7 @@ from synapse.rest.client import room
 from synapse.server import HomeServer
 from synapse.util import Clock

+from tests.test_utils.event_injection import inject_event
 from tests.unittest import HomeserverTestCase
@@ -128,3 +129,76 @@ class PurgeTests(HomeserverTestCase):
         self.store._invalidate_local_get_event_cache(create_event.event_id)
         self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
         self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+
+    def test_state_groups_state_decreases(self) -> None:
+        response = self.helper.send(self.room_id, body="first")
+        first_event_id = response["event_id"]
+
+        batches = []
+
+        previous_event_id = first_event_id
+        for i in range(50):
+            state_event1 = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="test.state",
+                    sender=self.user_id,
+                    state_key="",
+                    room_id=self.room_id,
+                    content={"key": i, "e": 1},
+                    prev_event_ids=[previous_event_id],
+                    origin_server_ts=1,
+                )
+            )
+
+            state_event2 = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="test.state",
+                    sender=self.user_id,
+                    state_key="",
+                    room_id=self.room_id,
+                    content={"key": i, "e": 2},
+                    prev_event_ids=[previous_event_id],
+                    origin_server_ts=2,
+                )
+            )
+
+            message_event = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="dummy_event",
+                    sender=self.user_id,
+                    room_id=self.room_id,
+                    content={},
+                    prev_event_ids=[state_event1.event_id, state_event2.event_id],
+                )
+            )
+
+            token = self.get_success(
+                self.store.get_topological_token_for_event(state_event1.event_id)
+            )
+            batches.append(token)
+
+            previous_event_id = message_event.event_id
+
+        self.helper.send(self.room_id, body="last event")
+
+        def count_state_groups() -> int:
+            sql = "SELECT COUNT(*) FROM state_groups_state WHERE room_id = ?"
+            rows = self.get_success(
+                self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
+            )
+            return rows[0][0]
+
+        print(count_state_groups())
+        for token in batches:
+            token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
+            self.get_success(
+                self._storage_controllers.purge_events.purge_history(
+                    self.room_id, token_str, False
+                )
+            )
+            print(count_state_groups())
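The new test exercises the purge and state-group machinery end to end: fifty forks of two conflicting state events, each merged by a message event, with a topological token recorded per step; purging batch by batch should shrink the state_groups_state table. As written it only prints the counts. A small self-checking variant of the final loop (a sketch reusing only names from the test above; it is not part of the commit):

    # Assert the row count is non-increasing across purges instead of
    # printing it. Drop-in replacement for the final loop above.
    previous_count = count_state_groups()
    for token in batches:
        token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
        self.get_success(
            self._storage_controllers.purge_events.purge_history(
                self.room_id, token_str, False
            )
        )
        current_count = count_state_groups()
        self.assertLessEqual(current_count, previous_count)
        previous_count = current_count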