From db25e30a256047454a9b09b579856d6cce0a6a7b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 8 May 2024 16:04:35 +0100 Subject: [PATCH] Perf improvement to getting auth chains --- .../databases/main/event_federation.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index fb132ef090..68f30d893c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -283,7 +283,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # A map from chain ID to max sequence number *reachable* from any event ID. chains: Dict[int, int] = {} - for links in self._get_chain_links(txn, set(event_chains.keys())): + for links in self._get_chain_links(txn, event_chains.keys()): for chain_id in links: if chain_id not in event_chains: continue @@ -335,7 +335,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @classmethod def _get_chain_links( - cls, txn: LoggingTransaction, chains_to_fetch: Set[int] + cls, txn: LoggingTransaction, chains_to_fetch: Collection[int] ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]: """Fetch all auth chain links from the given set of chains, and all links from those chains, recursively. @@ -371,9 +371,27 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id) """ - while chains_to_fetch: - batch2 = tuple(itertools.islice(chains_to_fetch, 1000)) - chains_to_fetch.difference_update(batch2) + # We fetch the links in batches. Separate batches will likely fetch the + # same set of links (e.g. they'll always pull in the links to create + # event). To try and minimize the amount of redundant links, we sort the + # chain IDs in reverse, as there will be a correlation between the order + # of chain IDs and links (i.e., higher chain IDs are more likely to + # depend on lower chain IDs than vice versa). + BATCH_SIZE = 1000 + chains_to_fetch_list = list(chains_to_fetch) + chains_to_fetch_list.sort(reverse=True) + + seen_chains: Set[int] = set() + while chains_to_fetch_list: + batch2 = [ + c for c in chains_to_fetch_list[-BATCH_SIZE:] if c not in seen_chains + ] + chains_to_fetch_list = chains_to_fetch_list[:-BATCH_SIZE] + while len(batch2) < BATCH_SIZE and chains_to_fetch_list: + chain_id = chains_to_fetch_list.pop() + if chain_id not in seen_chains: + batch2.append(chain_id) + clause, args = make_in_list_sql_clause( txn.database_engine, "origin_chain_id", batch2 ) @@ -391,7 +409,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas (origin_sequence_number, target_chain_id, target_sequence_number) ) - chains_to_fetch.difference_update(links) + seen_chains.update(links) + seen_chains.update(batch2) yield links @@ -581,7 +600,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # are reachable from any event. # (We need to take a copy of `seen_chains` as the function mutates it) - for links in self._get_chain_links(txn, set(seen_chains)): + for links in self._get_chain_links(txn, seen_chains): for chains in set_to_chain: for chain_id in links: if chain_id not in chains: