Speed up state res in rare case we don't have all events (#16116)

If we don't have all the auth events in a room then not all state events will have a chain cover index. Even so, we can still use the chain cover index on the events that do have it, rather than bailing and using the slower functions.

This situation should not arise for newly persisted rooms, as we check we have the full auth chain for each event, but can happen for existing rooms.

c.f. #15245
This commit is contained in:
Erik Johnston 2023-08-18 15:32:06 +01:00 committed by GitHub
parent 2d15e39684
commit bd558a6dc3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 355 additions and 71 deletions

1
changelog.d/16116.bugfix Normal file
View file

@ -0,0 +1 @@
Fix performance of state resolutions for large, old rooms that did not have the full auth chain persisted.

View file

@ -452,33 +452,56 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# sets.
seen_chains: Set[int] = set()
sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(initial_events, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)
# Fetch the chain cover index for the initial set of events we're
# considering.
def fetch_chain_info(events_to_fetch: Collection[str]) -> None:
sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(events_to_fetch, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)
for event_id, chain_id, sequence_number in txn:
chain_info[event_id] = (chain_id, sequence_number)
seen_chains.add(chain_id)
chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
for event_id, chain_id, sequence_number in txn:
chain_info[event_id] = (chain_id, sequence_number)
seen_chains.add(chain_id)
chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
fetch_chain_info(initial_events)
# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(chain_info)
# The result set to return, i.e. the auth chain difference.
result: Set[str] = set()
if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
# For some reason we have events we haven't calculated the chain
# index for, so we need to handle those separately. This should only
# happen for older rooms where the server doesn't have all the auth
# events.
result = self._fixup_auth_chain_difference_sets(
txn,
room_id,
events_missing_chain_info,
state_sets=state_sets,
events_missing_chain_info=events_missing_chain_info,
events_that_have_chain_index=chain_info,
)
raise _NoChainCoverIndex(room_id)
# We now need to refetch any events that we have added to the state
# sets.
new_events_to_fetch = {
event_id
for state_set in state_sets
for event_id in state_set
if event_id not in initial_events
}
fetch_chain_info(new_events_to_fetch)
# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
@ -487,8 +510,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains: Dict[int, int] = {}
set_to_chain.append(chains)
for event_id in state_set:
chain_id, seq_no = chain_info[event_id]
for state_id in state_set:
chain_id, seq_no = chain_info[state_id]
chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
@ -532,7 +555,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# from *any* state set and the minimum sequence number reachable from
# *all* state sets. Events in that range are in the auth chain
# difference.
result = set()
# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
@ -588,6 +610,122 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
def _fixup_auth_chain_difference_sets(
self,
txn: LoggingTransaction,
room_id: str,
state_sets: List[Set[str]],
events_missing_chain_info: Set[str],
events_that_have_chain_index: Collection[str],
) -> Set[str]:
"""Helper for `_get_auth_chain_difference_using_cover_index_txn` to
handle the case where we haven't calculated the chain cover index for
all events.
This modifies `state_sets` so that they only include events that have a
chain cover index, and returns a set of event IDs that are part of the
auth difference.
"""
# This works similarly to the handling of unpersisted events in
# `synapse.state.v2_get_auth_chain_difference`. We uses the observation
# that if you can split the set of events into two classes X and Y,
# where no events in Y have events in X in their auth chain, then we can
# calculate the auth difference by considering X and Y separately.
#
# We do this in three steps:
# 1. Compute the set of events without chain cover index belonging to
# the auth difference.
# 2. Replacing the un-indexed events in the state_sets with their auth
# events, recursively, until the state_sets contain only indexed
# events. We can then calculate the auth difference of those state
# sets using the chain cover index.
# 3. Add the results of 1 and 2 together.
# By construction we know that all events that we haven't persisted the
# chain cover index for are contained in
# `event_auth_chain_to_calculate`, so we pull out the events from those
# rather than doing recursive queries to walk the auth chain.
#
# We pull out those events with their auth events, which gives us enough
# information to construct the auth chain of an event up to auth events
# that have the chain cover index.
sql = """
SELECT tc.event_id, ea.auth_id, eac.chain_id IS NOT NULL
FROM event_auth_chain_to_calculate AS tc
LEFT JOIN event_auth AS ea USING (event_id)
LEFT JOIN event_auth_chains AS eac ON (ea.auth_id = eac.event_id)
WHERE tc.room_id = ?
"""
txn.execute(sql, (room_id,))
event_to_auth_ids: Dict[str, Set[str]] = {}
events_that_have_chain_index = set(events_that_have_chain_index)
for event_id, auth_id, auth_id_has_chain in txn:
s = event_to_auth_ids.setdefault(event_id, set())
if auth_id is not None:
s.add(auth_id)
if auth_id_has_chain:
events_that_have_chain_index.add(auth_id)
if events_missing_chain_info - event_to_auth_ids.keys():
# Uh oh, we somehow haven't correctly done the chain cover index,
# bail and fall back to the old method.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info - event_to_auth_ids.keys(),
)
raise _NoChainCoverIndex(room_id)
# Create a map from event IDs we care about to their partial auth chain.
event_id_to_partial_auth_chain: Dict[str, Set[str]] = {}
for event_id, auth_ids in event_to_auth_ids.items():
if not any(event_id in state_set for state_set in state_sets):
continue
processing = set(auth_ids)
to_add = set()
while processing:
auth_id = processing.pop()
to_add.add(auth_id)
sub_auth_ids = event_to_auth_ids.get(auth_id)
if sub_auth_ids is None:
continue
processing.update(sub_auth_ids - to_add)
event_id_to_partial_auth_chain[event_id] = to_add
# Now we do two things:
# 1. Update the state sets to only include indexed events; and
# 2. Create a new list containing the auth chains of the un-indexed
# events
unindexed_state_sets: List[Set[str]] = []
for state_set in state_sets:
unindexed_state_set = set()
for event_id, auth_chain in event_id_to_partial_auth_chain.items():
if event_id not in state_set:
continue
unindexed_state_set.add(event_id)
state_set.discard(event_id)
state_set.difference_update(auth_chain)
for auth_id in auth_chain:
if auth_id in events_that_have_chain_index:
state_set.add(auth_id)
else:
unindexed_state_set.add(auth_id)
unindexed_state_sets.append(unindexed_state_set)
# Calculate and return the auth difference of the un-indexed events.
union = unindexed_state_sets[0].union(*unindexed_state_sets[1:])
intersection = unindexed_state_sets[0].intersection(*unindexed_state_sets[1:])
return union - intersection
def _get_auth_chain_difference_txn(
self, txn: LoggingTransaction, state_sets: List[Set[str]]
) -> Set[str]:

View file

@ -13,7 +13,19 @@
# limitations under the License.
import datetime
from typing import Dict, List, Tuple, Union, cast
from typing import (
Collection,
Dict,
FrozenSet,
Iterable,
List,
Mapping,
Set,
Tuple,
TypeVar,
Union,
cast,
)
import attr
from parameterized import parameterized
@ -38,6 +50,138 @@ from synapse.util import Clock, json_encoder
import tests.unittest
import tests.utils
# The silly auth graph we use to test the auth difference algorithm,
# where the top are the most recent events.
#
# A B
# \ /
# D E
# \ |
# ` F C
# | /|
# G ´ |
# | \ |
# H I
# | |
# K J
AUTH_GRAPH: Dict[str, List[str]] = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
"d": ["f"],
"e": ["f"],
"f": ["g"],
"g": ["h", "i"],
"h": ["k"],
"i": ["j"],
"k": [],
"j": [],
}
DEPTH_GRAPH = {
"a": 7,
"b": 7,
"c": 4,
"d": 6,
"e": 6,
"f": 5,
"g": 3,
"h": 2,
"i": 2,
"k": 1,
"j": 1,
}
T = TypeVar("T")
def get_all_topologically_sorted_orders(
nodes: Iterable[T],
graph: Mapping[T, Collection[T]],
) -> List[List[T]]:
"""Given a set of nodes and a graph, return all possible topological
orderings.
"""
# This is implemented by Kahn's algorithm, and forking execution each time
# we have a choice over which node to consider next.
degree_map = {node: 0 for node in nodes}
reverse_graph: Dict[T, Set[T]] = {}
for node, edges in graph.items():
if node not in degree_map:
continue
for edge in set(edges):
if edge in degree_map:
degree_map[node] += 1
reverse_graph.setdefault(edge, set()).add(node)
reverse_graph.setdefault(node, set())
zero_degree = [node for node, degree in degree_map.items() if degree == 0]
return _get_all_topologically_sorted_orders_inner(
reverse_graph, zero_degree, degree_map
)
def _get_all_topologically_sorted_orders_inner(
reverse_graph: Dict[T, Set[T]],
zero_degree: List[T],
degree_map: Dict[T, int],
) -> List[List[T]]:
new_paths = []
# Rather than only choosing *one* item from the list of nodes with zero
# degree, we "fork" execution and run the algorithm for each node in the
# zero degree.
for node in zero_degree:
new_degree_map = degree_map.copy()
new_zero_degree = zero_degree.copy()
new_zero_degree.remove(node)
for edge in reverse_graph.get(node, []):
if edge in new_degree_map:
new_degree_map[edge] -= 1
if new_degree_map[edge] == 0:
new_zero_degree.append(edge)
paths = _get_all_topologically_sorted_orders_inner(
reverse_graph, new_zero_degree, new_degree_map
)
for path in paths:
path.insert(0, node)
new_paths.extend(paths)
if not new_paths:
return [[]]
return new_paths
def get_all_topologically_consistent_subsets(
nodes: Iterable[T],
graph: Mapping[T, Collection[T]],
) -> Set[FrozenSet[T]]:
"""Get all subsets of the graph where if node N is in the subgraph, then all
nodes that can reach that node (i.e. for all X there exists a path X -> N)
are in the subgraph.
"""
all_topological_orderings = get_all_topologically_sorted_orders(nodes, graph)
graph_subsets = set()
for ordering in all_topological_orderings:
ordering.reverse()
for idx in range(len(ordering)):
graph_subsets.add(frozenset(ordering[:idx]))
return graph_subsets
@attr.s(auto_attribs=True, frozen=True, slots=True)
class _BackfillSetupInfo:
@ -172,49 +316,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
# where the top are the most recent events.
#
# A B
# \ /
# D E
# \ |
# ` F C
# | /|
# G ´ |
# | \ |
# H I
# | |
# K J
auth_graph: Dict[str, List[str]] = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
"d": ["f"],
"e": ["f"],
"f": ["g"],
"g": ["h", "i"],
"h": ["k"],
"i": ["j"],
"k": [],
"j": [],
}
depth_map = {
"a": 7,
"b": 7,
"c": 4,
"d": 6,
"e": 6,
"f": 5,
"g": 3,
"h": 2,
"i": 2,
"k": 1,
"j": 1,
}
# Mark the room as maybe having a cover index.
def store_room(txn: LoggingTransaction) -> None:
@ -238,9 +339,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def insert_event(txn: LoggingTransaction) -> None:
stream_ordering = 0
for event_id in auth_graph:
for event_id in AUTH_GRAPH:
stream_ordering += 1
depth = depth_map[event_id]
depth = DEPTH_GRAPH[event_id]
self.store.db_pool.simple_insert_txn(
txn,
@ -260,8 +361,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
for event_id in AUTH_GRAPH
],
)
@ -344,7 +445,51 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
room_id = self._setup_auth_chain(use_chain_cover_index)
# Now actually test that various combinations give the right result:
self.assert_auth_diff_is_expected(room_id)
@parameterized.expand(
[
[graph_subset]
for graph_subset in get_all_topologically_consistent_subsets(
AUTH_GRAPH, AUTH_GRAPH
)
]
)
def test_auth_difference_partial(self, graph_subset: Collection[str]) -> None:
"""Test that if we only have a chain cover index on a partial subset of
the room we still get the correct auth chain difference.
We do this by removing the chain cover index for every valid subset of the
graph.
"""
room_id = self._setup_auth_chain(True)
for event_id in graph_subset:
# Remove chain cover from that event.
self.get_success(
self.store.db_pool.simple_delete(
table="event_auth_chains",
keyvalues={"event_id": event_id},
desc="test_auth_difference_partial_remove",
)
)
self.get_success(
self.store.db_pool.simple_insert(
table="event_auth_chain_to_calculate",
values={
"event_id": event_id,
"room_id": room_id,
"type": "",
"state_key": "",
},
desc="test_auth_difference_partial_remove",
)
)
self.assert_auth_diff_is_expected(room_id)
def assert_auth_diff_is_expected(self, room_id: str) -> None:
"""Assert the auth chain difference returns the correct answers."""
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
)