mirror of
https://github.com/element-hq/synapse
synced 2024-09-28 14:12:41 +00:00
Handle large chain calc better
We calculate the auth chain links outside of the main persist event transaction to ensure that we do not block other event sending during the calculation.
This commit is contained in:
parent
3f06bbc0ac
commit
5a287d9d14
4 changed files with 199 additions and 88 deletions
|
@ -617,6 +617,17 @@ class EventsPersistenceStorageController:
|
|||
room_id, chunk
|
||||
)
|
||||
|
||||
with Measure(self._clock, "calculate_chain_cover_index_for_events"):
|
||||
# We now calculate chain ID/sequence numbers for any state events we're
|
||||
# persisting. We ignore out of band memberships as we're not in the room
|
||||
# and won't have their auth chain (we'll fix it up later if we join the
|
||||
# room).
|
||||
#
|
||||
# See: docs/auth_chain_difference_algorithm.md
|
||||
new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
|
||||
room_id, [e for e, _ in chunk]
|
||||
)
|
||||
|
||||
await self.persist_events_store._persist_events_and_state_updates(
|
||||
room_id,
|
||||
chunk,
|
||||
|
@ -624,6 +635,7 @@ class EventsPersistenceStorageController:
|
|||
new_forward_extremities=new_forward_extremities,
|
||||
use_negative_stream_ordering=backfilled,
|
||||
inhibit_local_membership_updates=backfilled,
|
||||
new_event_links=new_event_links,
|
||||
)
|
||||
|
||||
return replaced_events
|
||||
|
|
|
@ -34,7 +34,6 @@ from typing import (
|
|||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
|
@ -100,6 +99,14 @@ class DeltaState:
|
|||
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class NewEventChainLinks:
|
||||
chain_id: int
|
||||
sequence_number: int
|
||||
|
||||
links: List[Tuple[int, int]] = attr.Factory(list)
|
||||
|
||||
|
||||
class PersistEventsStore:
|
||||
"""Contains all the functions for writing events to the database.
|
||||
|
||||
|
@ -148,6 +155,7 @@ class PersistEventsStore:
|
|||
*,
|
||||
state_delta_for_room: Optional[DeltaState],
|
||||
new_forward_extremities: Optional[Set[str]],
|
||||
new_event_links: Dict[str, NewEventChainLinks],
|
||||
use_negative_stream_ordering: bool = False,
|
||||
inhibit_local_membership_updates: bool = False,
|
||||
) -> None:
|
||||
|
@ -216,6 +224,7 @@ class PersistEventsStore:
|
|||
inhibit_local_membership_updates=inhibit_local_membership_updates,
|
||||
state_delta_for_room=state_delta_for_room,
|
||||
new_forward_extremities=new_forward_extremities,
|
||||
new_event_links=new_event_links,
|
||||
)
|
||||
persist_event_counter.inc(len(events_and_contexts))
|
||||
|
||||
|
@ -242,6 +251,64 @@ class PersistEventsStore:
|
|||
(room_id,), frozenset(new_forward_extremities)
|
||||
)
|
||||
|
||||
async def calculate_chain_cover_index_for_events(
|
||||
self, room_id: str, events: Collection[EventBase]
|
||||
) -> Dict[str, NewEventChainLinks]:
|
||||
|
||||
state_events = [event for event in events if event.is_state()]
|
||||
|
||||
if not state_events:
|
||||
return {}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"_calculate_chain_cover_index_for_events",
|
||||
self.calculate_chain_cover_index_for_events_txn,
|
||||
room_id,
|
||||
state_events,
|
||||
)
|
||||
|
||||
def calculate_chain_cover_index_for_events_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
|
||||
) -> Dict[str, NewEventChainLinks]:
|
||||
# We now calculate chain ID/sequence numbers for any state events we're
|
||||
# persisting. We ignore out of band memberships as we're not in the room
|
||||
# and won't have their auth chain (we'll fix it up later if we join the
|
||||
# room).
|
||||
#
|
||||
# See: docs/auth_chain_difference_algorithm.md
|
||||
|
||||
# We ignore legacy rooms that we aren't filling the chain cover index
|
||||
# for.
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("room_id", "has_auth_chain_index"),
|
||||
allow_none=True,
|
||||
)
|
||||
if row is None:
|
||||
return {}
|
||||
|
||||
if not state_events:
|
||||
return {}
|
||||
|
||||
# We need to know the type/state_key and auth events of the events we're
|
||||
# calculating chain IDs for. We don't rely on having the full Event
|
||||
# instances as we'll potentially be pulling more events from the DB and
|
||||
# we don't need the overhead of fetching/parsing the full event JSON.
|
||||
event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
|
||||
event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
|
||||
event_to_room_id = {e.event_id: e.room_id for e in state_events}
|
||||
|
||||
return self._calculate_chain_cover_index(
|
||||
txn,
|
||||
self.db_pool,
|
||||
self.store.event_chain_id_gen,
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
)
|
||||
|
||||
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
|
||||
"""Filter the supplied list of event_ids to get those which are prev_events of
|
||||
existing (non-outlier/rejected) events.
|
||||
|
@ -357,6 +424,7 @@ class PersistEventsStore:
|
|||
inhibit_local_membership_updates: bool,
|
||||
state_delta_for_room: Optional[DeltaState],
|
||||
new_forward_extremities: Optional[Set[str]],
|
||||
new_event_links: Dict[str, NewEventChainLinks],
|
||||
) -> None:
|
||||
"""Insert some number of room events into the necessary database tables.
|
||||
|
||||
|
@ -465,7 +533,9 @@ class PersistEventsStore:
|
|||
# Insert into event_to_state_groups.
|
||||
self._store_event_state_mappings_txn(txn, events_and_contexts)
|
||||
|
||||
self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
|
||||
self._persist_event_auth_chain_txn(
|
||||
txn, [e for e, _ in events_and_contexts], new_event_links
|
||||
)
|
||||
|
||||
# _store_rejected_events_txn filters out any events which were
|
||||
# rejected, and returns the filtered list.
|
||||
|
@ -495,6 +565,7 @@ class PersistEventsStore:
|
|||
self,
|
||||
txn: LoggingTransaction,
|
||||
events: List[EventBase],
|
||||
new_event_links: Dict[str, NewEventChainLinks],
|
||||
) -> None:
|
||||
# We only care about state events, so this if there are no state events.
|
||||
if not any(e.is_state() for e in events):
|
||||
|
@ -518,59 +589,8 @@ class PersistEventsStore:
|
|||
],
|
||||
)
|
||||
|
||||
# We now calculate chain ID/sequence numbers for any state events we're
|
||||
# persisting. We ignore out of band memberships as we're not in the room
|
||||
# and won't have their auth chain (we'll fix it up later if we join the
|
||||
# room).
|
||||
#
|
||||
# See: docs/auth_chain_difference_algorithm.md
|
||||
|
||||
# We ignore legacy rooms that we aren't filling the chain cover index
|
||||
# for.
|
||||
rows = cast(
|
||||
List[Tuple[str, Optional[Union[int, bool]]]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="rooms",
|
||||
column="room_id",
|
||||
iterable={event.room_id for event in events if event.is_state()},
|
||||
keyvalues={},
|
||||
retcols=("room_id", "has_auth_chain_index"),
|
||||
),
|
||||
)
|
||||
rooms_using_chain_index = {
|
||||
room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
|
||||
}
|
||||
|
||||
state_events = {
|
||||
event.event_id: event
|
||||
for event in events
|
||||
if event.is_state() and event.room_id in rooms_using_chain_index
|
||||
}
|
||||
|
||||
if not state_events:
|
||||
return
|
||||
|
||||
# We need to know the type/state_key and auth events of the events we're
|
||||
# calculating chain IDs for. We don't rely on having the full Event
|
||||
# instances as we'll potentially be pulling more events from the DB and
|
||||
# we don't need the overhead of fetching/parsing the full event JSON.
|
||||
event_to_types = {
|
||||
e.event_id: (e.type, e.state_key) for e in state_events.values()
|
||||
}
|
||||
event_to_auth_chain = {
|
||||
e.event_id: e.auth_event_ids() for e in state_events.values()
|
||||
}
|
||||
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
|
||||
|
||||
self._add_chain_cover_index(
|
||||
txn,
|
||||
self.db_pool,
|
||||
self.store.event_chain_id_gen,
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
)
|
||||
if new_event_links:
|
||||
self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
|
||||
|
||||
@classmethod
|
||||
def _add_chain_cover_index(
|
||||
|
@ -582,6 +602,35 @@ class PersistEventsStore:
|
|||
event_to_types: Dict[str, Tuple[str, str]],
|
||||
event_to_auth_chain: Dict[str, StrCollection],
|
||||
) -> None:
|
||||
"""Calculate and persist the chain cover index for the given events.
|
||||
|
||||
Args:
|
||||
event_to_room_id: Event ID to the room ID of the event
|
||||
event_to_types: Event ID to type and state_key of the event
|
||||
event_to_auth_chain: Event ID to list of auth event IDs of the
|
||||
event (events with no auth events can be excluded).
|
||||
"""
|
||||
|
||||
new_event_links = cls._calculate_chain_cover_index(
|
||||
txn,
|
||||
db_pool,
|
||||
event_chain_id_gen,
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
)
|
||||
cls._persist_chain_cover_index(txn, db_pool, new_event_links)
|
||||
|
||||
@classmethod
|
||||
def _calculate_chain_cover_index(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
db_pool: DatabasePool,
|
||||
event_chain_id_gen: SequenceGenerator,
|
||||
event_to_room_id: Dict[str, str],
|
||||
event_to_types: Dict[str, Tuple[str, str]],
|
||||
event_to_auth_chain: Dict[str, StrCollection],
|
||||
) -> Dict[str, NewEventChainLinks]:
|
||||
"""Calculate the chain cover index for the given events.
|
||||
|
||||
Args:
|
||||
|
@ -707,11 +756,11 @@ class PersistEventsStore:
|
|||
room_id = event_to_room_id.get(event_id)
|
||||
if room_id:
|
||||
e_type, state_key = event_to_types[event_id]
|
||||
db_pool.simple_insert_txn(
|
||||
db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="event_auth_chain_to_calculate",
|
||||
keyvalues={"event_id": event_id},
|
||||
values={
|
||||
"event_id": event_id,
|
||||
"room_id": room_id,
|
||||
"type": e_type,
|
||||
"state_key": state_key,
|
||||
|
@ -723,7 +772,7 @@ class PersistEventsStore:
|
|||
break
|
||||
|
||||
if not events_to_calc_chain_id_for:
|
||||
return
|
||||
return {}
|
||||
|
||||
# Allocate chain ID/sequence numbers to each new event.
|
||||
new_chain_tuples = cls._allocate_chain_ids(
|
||||
|
@ -738,23 +787,10 @@ class PersistEventsStore:
|
|||
)
|
||||
chain_map.update(new_chain_tuples)
|
||||
|
||||
db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_auth_chains",
|
||||
keys=("event_id", "chain_id", "sequence_number"),
|
||||
values=[
|
||||
(event_id, c_id, seq)
|
||||
for event_id, (c_id, seq) in new_chain_tuples.items()
|
||||
],
|
||||
)
|
||||
|
||||
db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="event_auth_chain_to_calculate",
|
||||
keyvalues={},
|
||||
column="event_id",
|
||||
values=new_chain_tuples,
|
||||
)
|
||||
to_return = {
|
||||
event_id: NewEventChainLinks(chain_id, sequence_number)
|
||||
for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
|
||||
}
|
||||
|
||||
# Now we need to calculate any new links between chains caused by
|
||||
# the new events.
|
||||
|
@ -824,10 +860,38 @@ class PersistEventsStore:
|
|||
auth_chain_id, auth_sequence_number = chain_map[auth_id]
|
||||
|
||||
# Step 2a, add link between the event and auth event
|
||||
to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
|
||||
chain_links.add_link(
|
||||
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
|
||||
)
|
||||
|
||||
return to_return
|
||||
|
||||
@classmethod
|
||||
def _persist_chain_cover_index(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
db_pool: DatabasePool,
|
||||
new_event_links: Dict[str, NewEventChainLinks],
|
||||
) -> None:
|
||||
db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_auth_chains",
|
||||
keys=("event_id", "chain_id", "sequence_number"),
|
||||
values=[
|
||||
(event_id, new_links.chain_id, new_links.sequence_number)
|
||||
for event_id, new_links in new_event_links.items()
|
||||
],
|
||||
)
|
||||
|
||||
db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="event_auth_chain_to_calculate",
|
||||
keyvalues={},
|
||||
column="event_id",
|
||||
values=new_event_links,
|
||||
)
|
||||
|
||||
db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_auth_chain_links",
|
||||
|
@ -837,7 +901,16 @@ class PersistEventsStore:
|
|||
"target_chain_id",
|
||||
"target_sequence_number",
|
||||
),
|
||||
values=list(chain_links.get_additions()),
|
||||
values=[
|
||||
(
|
||||
new_links.chain_id,
|
||||
new_links.sequence_number,
|
||||
target_chain_id,
|
||||
target_sequence_number,
|
||||
)
|
||||
for new_links in new_event_links.values()
|
||||
for (target_chain_id, target_sequence_number) in new_links.links
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -446,7 +446,14 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
# Actually call the function that calculates the auth chain stuff.
|
||||
persist_events_store._persist_event_auth_chain_txn(txn, events)
|
||||
new_event_links = (
|
||||
persist_events_store.calculate_chain_cover_index_for_events_txn(
|
||||
txn, events[0].room_id, [e for e in events if e.is_state()]
|
||||
)
|
||||
)
|
||||
persist_events_store._persist_event_auth_chain_txn(
|
||||
txn, events, new_event_links
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
persist_events_store.db_pool.runInteraction(
|
||||
|
|
|
@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
events = [
|
||||
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
|
||||
for event_id in AUTH_GRAPH
|
||||
]
|
||||
new_event_links = (
|
||||
self.persist_events.calculate_chain_cover_index_for_events_txn(
|
||||
txn, room_id, [e for e in events if e.is_state()]
|
||||
)
|
||||
)
|
||||
self.persist_events._persist_event_auth_chain_txn(
|
||||
txn,
|
||||
[
|
||||
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
|
||||
for event_id in AUTH_GRAPH
|
||||
],
|
||||
events,
|
||||
new_event_links,
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
|
@ -628,13 +635,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
# Insert all events apart from 'B'
|
||||
events = [
|
||||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
||||
for event_id in auth_graph
|
||||
if event_id != "b"
|
||||
]
|
||||
new_event_links = (
|
||||
self.persist_events.calculate_chain_cover_index_for_events_txn(
|
||||
txn, room_id, [e for e in events if e.is_state()]
|
||||
)
|
||||
)
|
||||
self.persist_events._persist_event_auth_chain_txn(
|
||||
txn,
|
||||
[
|
||||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
||||
for event_id in auth_graph
|
||||
if event_id != "b"
|
||||
],
|
||||
events,
|
||||
new_event_links,
|
||||
)
|
||||
|
||||
# Now we insert the event 'B' without a chain cover, by temporarily
|
||||
|
@ -647,9 +661,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
updatevalues={"has_auth_chain_index": False},
|
||||
)
|
||||
|
||||
events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
|
||||
new_event_links = (
|
||||
self.persist_events.calculate_chain_cover_index_for_events_txn(
|
||||
txn, room_id, [e for e in events if e.is_state()]
|
||||
)
|
||||
)
|
||||
self.persist_events._persist_event_auth_chain_txn(
|
||||
txn,
|
||||
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
|
||||
txn, events, new_event_links
|
||||
)
|
||||
|
||||
self.store.db_pool.simple_update_txn(
|
||||
|
|
Loading…
Reference in a new issue