Performance improvement when fetching auth chains

This commit is contained in:
Erik Johnston 2024-05-08 16:04:35 +01:00
parent 34a8652366
commit db25e30a25

View file

@ -283,7 +283,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# A map from chain ID to max sequence number *reachable* from any event ID. # A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {} chains: Dict[int, int] = {}
for links in self._get_chain_links(txn, set(event_chains.keys())): for links in self._get_chain_links(txn, event_chains.keys()):
for chain_id in links: for chain_id in links:
if chain_id not in event_chains: if chain_id not in event_chains:
continue continue
@ -335,7 +335,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@classmethod @classmethod
def _get_chain_links( def _get_chain_links(
cls, txn: LoggingTransaction, chains_to_fetch: Set[int] cls, txn: LoggingTransaction, chains_to_fetch: Collection[int]
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]: ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
"""Fetch all auth chain links from the given set of chains, and all """Fetch all auth chain links from the given set of chains, and all
links from those chains, recursively. links from those chains, recursively.
@ -371,9 +371,27 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id) INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
""" """
while chains_to_fetch: # We fetch the links in batches. Separate batches will likely fetch the
batch2 = tuple(itertools.islice(chains_to_fetch, 1000)) # same set of links (e.g. they'll always pull in the links to the create
chains_to_fetch.difference_update(batch2) # event). To try and minimize the number of redundant links, we sort the
# chain IDs in reverse, as there will be a correlation between the order
# of chain IDs and links (i.e., higher chain IDs are more likely to
# depend on lower chain IDs than vice versa).
BATCH_SIZE = 1000
chains_to_fetch_list = list(chains_to_fetch)
chains_to_fetch_list.sort(reverse=True)
seen_chains: Set[int] = set()
while chains_to_fetch_list:
batch2 = [
c for c in chains_to_fetch_list[-BATCH_SIZE:] if c not in seen_chains
]
chains_to_fetch_list = chains_to_fetch_list[:-BATCH_SIZE]
while len(batch2) < BATCH_SIZE and chains_to_fetch_list:
chain_id = chains_to_fetch_list.pop()
if chain_id not in seen_chains:
batch2.append(chain_id)
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2 txn.database_engine, "origin_chain_id", batch2
) )
@ -391,7 +409,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
(origin_sequence_number, target_chain_id, target_sequence_number) (origin_sequence_number, target_chain_id, target_sequence_number)
) )
chains_to_fetch.difference_update(links) seen_chains.update(links)
seen_chains.update(batch2)
yield links yield links
@ -581,7 +600,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# are reachable from any event. # are reachable from any event.
# (We need to take a copy of `seen_chains` as the function mutates it) # (We need to take a copy of `seen_chains` as the function mutates it)
for links in self._get_chain_links(txn, set(seen_chains)): for links in self._get_chain_links(txn, seen_chains):
for chains in set_to_chain: for chains in set_to_chain:
for chain_id in links: for chain_id in links:
if chain_id not in chains: if chain_id not in chains: