Add a cache to auth links

This commit is contained in:
Erik Johnston 2024-05-17 10:34:39 +01:00
parent 4ffe5a4459
commit fe9fa90af4

View file

@ -120,6 +120,11 @@ class BackfillQueueNavigationItem:
type: str type: str
@attr.s(frozen=True, slots=True, auto_attribs=True)
class _ChainLinksCacheEntry:
links: List[Tuple[int, int, int, "_ChainLinksCacheEntry"]] = attr.Factory(list)
class _NoChainCoverIndex(Exception): class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str): def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@ -140,6 +145,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
self.hs = hs self.hs = hs
self._chain_links_cache: LruCache[int, _ChainLinksCacheEntry] = LruCache(
max_size=10000, cache_name="chain_links_cache"
)
if hs.config.worker.run_background_tasks: if hs.config.worker.run_background_tasks:
hs.get_clock().looping_call( hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
@ -285,7 +294,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# A map from chain ID to max sequence number *reachable* from any event ID. # A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {} chains: Dict[int, int] = {}
for links in self._get_chain_links(txn, event_chains.keys()): for links in self._get_chain_links(
txn, event_chains.keys(), self._chain_links_cache
):
for chain_id in links: for chain_id in links:
if chain_id not in event_chains: if chain_id not in event_chains:
continue continue
@ -337,7 +348,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@classmethod @classmethod
def _get_chain_links( def _get_chain_links(
cls, txn: LoggingTransaction, chains_to_fetch: Collection[int] cls,
txn: LoggingTransaction,
chains_to_fetch: Collection[int],
cache: Optional[LruCache[int, _ChainLinksCacheEntry]] = None,
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]: ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
"""Fetch all auth chain links from the given set of chains, and all """Fetch all auth chain links from the given set of chains, and all
links from those chains, recursively. links from those chains, recursively.
@ -349,6 +363,44 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
of origin sequence number, target chain ID and target sequence number. of origin sequence number, target chain ID and target sequence number.
""" """
found_cached_chains = set()
if cache:
entries: Dict[int, _ChainLinksCacheEntry] = {}
for chain_id in chains_to_fetch:
entry = cache.get(chain_id)
if entry:
entries[chain_id] = entry
cached_links: Dict[int, List[Tuple[int, int, int]]] = {}
while entries:
origin_chain_id, entry = entries.popitem()
for (
origin_sequence_number,
target_chain_id,
target_sequence_number,
target_entry,
) in entry.links:
if target_chain_id in found_cached_chains:
continue
found_cached_chains.add(target_chain_id)
cache.get(chain_id)
entries[chain_id] = target_entry
cached_links.setdefault(origin_chain_id, []).append(
(
origin_sequence_number,
target_chain_id,
target_sequence_number,
)
)
yield cached_links
logger.info("CHAINS: Found cached chain links %d", len(found_cached_chains))
# This query is structured to first get all chain IDs reachable, and # This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more # then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of # rows than is strictly necessary, however there isn't a way of
@ -385,6 +437,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# likely to depend on lower chain IDs than vice versa). # likely to depend on lower chain IDs than vice versa).
BATCH_SIZE = 5000 BATCH_SIZE = 5000
chains_to_fetch_sorted = SortedSet(chains_to_fetch) chains_to_fetch_sorted = SortedSet(chains_to_fetch)
chains_to_fetch_sorted.difference_update(found_cached_chains)
logger.info("CHAINS: Fetching chain links %d", len(chains_to_fetch_sorted)) logger.info("CHAINS: Fetching chain links %d", len(chains_to_fetch_sorted))
@ -406,6 +459,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
links: Dict[int, List[Tuple[int, int, int]]] = {} links: Dict[int, List[Tuple[int, int, int]]] = {}
cache_entries: Dict[int, _ChainLinksCacheEntry] = {}
for ( for (
origin_chain_id, origin_chain_id,
origin_sequence_number, origin_sequence_number,
@ -416,6 +471,27 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
(origin_sequence_number, target_chain_id, target_sequence_number) (origin_sequence_number, target_chain_id, target_sequence_number)
) )
if cache:
origin_entry = cache_entries.setdefault(
origin_chain_id, _ChainLinksCacheEntry()
)
target_entry = cache_entries.setdefault(
target_chain_id, _ChainLinksCacheEntry()
)
origin_entry.links.append(
(
origin_sequence_number,
target_chain_id,
target_sequence_number,
target_entry,
)
)
if cache:
for chain_id, entry in cache_entries.items():
if chain_id not in cache:
cache[chain_id] = entry
chains_to_fetch_sorted.difference_update(links) chains_to_fetch_sorted.difference_update(links)
logger.info("CHAINS: returned %d", len(links)) logger.info("CHAINS: returned %d", len(links))
@ -614,7 +690,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# (We need to take a copy of `seen_chains` as the function mutates it) # (We need to take a copy of `seen_chains` as the function mutates it)
logger.info("CHAINS: for room %s", room_id) logger.info("CHAINS: for room %s", room_id)
for links in self._get_chain_links(txn, seen_chains): for links in self._get_chain_links(txn, seen_chains, self._chain_links_cache):
for chains in set_to_chain: for chains in set_to_chain:
for chain_id in links: for chain_id in links:
if chain_id not in chains: if chain_id not in chains: