Compare commits

...

4 commits

Author SHA1 Message Date
Erik Johnston 74785ebde1
Merge f237303c32 into 2c36a679ae 2024-06-14 09:55:27 +02:00
Erik Johnston f237303c32 Handle previously persisted events properly 2024-06-10 16:58:58 +01:00
Erik Johnston db67cd6893 Newsfile 2024-06-10 14:24:16 +01:00
Erik Johnston 5a287d9d14 Handle large chain calc better
We calculate the auth chain links outside of the main persist event
transaction to ensure that we do not block other event sending during
the calculation.
2024-06-10 14:22:58 +01:00
5 changed files with 223 additions and 88 deletions

1
changelog.d/17291.misc Normal file
View file

@ -0,0 +1 @@
Do not block event sending/receiving while calulating large event auth chains.

View file

@ -617,6 +617,17 @@ class EventsPersistenceStorageController:
room_id, chunk
)
with Measure(self._clock, "calculate_chain_cover_index_for_events"):
# We now calculate chain ID/sequence numbers for any state events we're
# persisting. We ignore out of band memberships as we're not in the room
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
room_id, [e for e, _ in chunk]
)
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
@ -624,6 +635,7 @@ class EventsPersistenceStorageController:
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
)
return replaced_events

View file

@ -34,7 +34,6 @@ from typing import (
Optional,
Set,
Tuple,
Union,
cast,
)
@ -100,6 +99,14 @@ class DeltaState:
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
@attr.s(slots=True, auto_attribs=True)
class NewEventChainLinks:
chain_id: int
sequence_number: int
links: List[Tuple[int, int]] = attr.Factory(list)
class PersistEventsStore:
"""Contains all the functions for writing events to the database.
@ -148,6 +155,7 @@ class PersistEventsStore:
*,
state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]],
new_event_links: Dict[str, NewEventChainLinks],
use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False,
) -> None:
@ -217,6 +225,7 @@ class PersistEventsStore:
inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
new_event_links=new_event_links,
)
persist_event_counter.inc(len(events_and_contexts))
@ -243,6 +252,87 @@ class PersistEventsStore:
(room_id,), frozenset(new_forward_extremities)
)
async def calculate_chain_cover_index_for_events(
self, room_id: str, events: Collection[EventBase]
) -> Dict[str, NewEventChainLinks]:
# Filter to state events, and ensure there are no duplicates.
state_events = []
seen_events = set()
for event in events:
if not event.is_state() or event.event_id in seen_events:
continue
state_events.append(event)
seen_events.add(event.event_id)
if not state_events:
return {}
return await self.db_pool.runInteraction(
"_calculate_chain_cover_index_for_events",
self.calculate_chain_cover_index_for_events_txn,
room_id,
state_events,
)
def calculate_chain_cover_index_for_events_txn(
self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
) -> Dict[str, NewEventChainLinks]:
# We now calculate chain ID/sequence numbers for any state events we're
# persisting. We ignore out of band memberships as we're not in the room
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
row = self.db_pool.simple_select_one_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "has_auth_chain_index"),
allow_none=True,
)
if row is None:
return {}
# Filter out already persisted events.
rows = self.db_pool.simple_select_many_txn(
txn,
table="events",
column="event_id",
iterable=[e.event_id for e in state_events],
keyvalues={},
retcols=("event_id",),
)
already_persisted_events = {event_id for event_id, in rows}
state_events = [
event
for event in state_events
if event.event_id in already_persisted_events
]
if not state_events:
return {}
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
# we don't need the overhead of fetching/parsing the full event JSON.
event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
event_to_room_id = {e.event_id: e.room_id for e in state_events}
return self._calculate_chain_cover_index(
txn,
self.db_pool,
self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
@ -358,6 +448,7 @@ class PersistEventsStore:
inhibit_local_membership_updates: bool,
state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]],
new_event_links: Dict[str, NewEventChainLinks],
) -> None:
"""Insert some number of room events into the necessary database tables.
@ -466,7 +557,9 @@ class PersistEventsStore:
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
self._persist_event_auth_chain_txn(
txn, [e for e, _ in events_and_contexts], new_event_links
)
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
@ -496,6 +589,7 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
events: List[EventBase],
new_event_links: Dict[str, NewEventChainLinks],
) -> None:
# We only care about state events, so this if there are no state events.
if not any(e.is_state() for e in events):
@ -519,59 +613,8 @@ class PersistEventsStore:
],
)
# We now calculate chain ID/sequence numbers for any state events we're
# persisting. We ignore out of band memberships as we're not in the room
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
rows = cast(
List[Tuple[str, Optional[Union[int, bool]]]],
self.db_pool.simple_select_many_txn(
txn,
table="rooms",
column="room_id",
iterable={event.room_id for event in events if event.is_state()},
keyvalues={},
retcols=("room_id", "has_auth_chain_index"),
),
)
rooms_using_chain_index = {
room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
}
state_events = {
event.event_id: event
for event in events
if event.is_state() and event.room_id in rooms_using_chain_index
}
if not state_events:
return
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
# we don't need the overhead of fetching/parsing the full event JSON.
event_to_types = {
e.event_id: (e.type, e.state_key) for e in state_events.values()
}
event_to_auth_chain = {
e.event_id: e.auth_event_ids() for e in state_events.values()
}
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
self._add_chain_cover_index(
txn,
self.db_pool,
self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
if new_event_links:
self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
@classmethod
def _add_chain_cover_index(
@ -583,6 +626,35 @@ class PersistEventsStore:
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, StrCollection],
) -> None:
"""Calculate and persist the chain cover index for the given events.
Args:
event_to_room_id: Event ID to the room ID of the event
event_to_types: Event ID to type and state_key of the event
event_to_auth_chain: Event ID to list of auth event IDs of the
event (events with no auth events can be excluded).
"""
new_event_links = cls._calculate_chain_cover_index(
txn,
db_pool,
event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
cls._persist_chain_cover_index(txn, db_pool, new_event_links)
@classmethod
def _calculate_chain_cover_index(
cls,
txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, StrCollection],
) -> Dict[str, NewEventChainLinks]:
"""Calculate the chain cover index for the given events.
Args:
@ -708,11 +780,11 @@ class PersistEventsStore:
room_id = event_to_room_id.get(event_id)
if room_id:
e_type, state_key = event_to_types[event_id]
db_pool.simple_insert_txn(
db_pool.simple_upsert_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={"event_id": event_id},
values={
"event_id": event_id,
"room_id": room_id,
"type": e_type,
"state_key": state_key,
@ -724,7 +796,7 @@ class PersistEventsStore:
break
if not events_to_calc_chain_id_for:
return
return {}
# Allocate chain ID/sequence numbers to each new event.
new_chain_tuples = cls._allocate_chain_ids(
@ -739,23 +811,10 @@ class PersistEventsStore:
)
chain_map.update(new_chain_tuples)
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
keys=("event_id", "chain_id", "sequence_number"),
values=[
(event_id, c_id, seq)
for event_id, (c_id, seq) in new_chain_tuples.items()
],
)
db_pool.simple_delete_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
values=new_chain_tuples,
)
to_return = {
event_id: NewEventChainLinks(chain_id, sequence_number)
for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
}
# Now we need to calculate any new links between chains caused by
# the new events.
@ -825,10 +884,38 @@ class PersistEventsStore:
auth_chain_id, auth_sequence_number = chain_map[auth_id]
# Step 2a, add link between the event and auth event
to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
chain_links.add_link(
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
)
return to_return
@classmethod
def _persist_chain_cover_index(
cls,
txn: LoggingTransaction,
db_pool: DatabasePool,
new_event_links: Dict[str, NewEventChainLinks],
) -> None:
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
keys=("event_id", "chain_id", "sequence_number"),
values=[
(event_id, new_links.chain_id, new_links.sequence_number)
for event_id, new_links in new_event_links.items()
],
)
db_pool.simple_delete_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
values=new_event_links,
)
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
@ -838,7 +925,16 @@ class PersistEventsStore:
"target_chain_id",
"target_sequence_number",
),
values=list(chain_links.get_additions()),
values=[
(
new_links.chain_id,
new_links.sequence_number,
target_chain_id,
target_sequence_number,
)
for new_links in new_event_links.values()
for (target_chain_id, target_sequence_number) in new_links.links
],
)
@staticmethod

View file

@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase):
)
# Actually call the function that calculates the auth chain stuff.
persist_events_store._persist_event_auth_chain_txn(txn, events)
new_event_links = (
persist_events_store.calculate_chain_cover_index_for_events_txn(
txn, events[0].room_id, [e for e in events if e.is_state()]
)
)
persist_events_store._persist_event_auth_chain_txn(
txn, events, new_event_links
)
self.get_success(
persist_events_store.db_pool.runInteraction(

View file

@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
events = [
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
for event_id in AUTH_GRAPH
]
new_event_links = (
self.persist_events.calculate_chain_cover_index_for_events_txn(
txn, room_id, [e for e in events if e.is_state()]
)
)
self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
for event_id in AUTH_GRAPH
],
events,
new_event_links,
)
self.get_success(
@ -628,13 +635,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
# Insert all events apart from 'B'
events = [
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
if event_id != "b"
]
new_event_links = (
self.persist_events.calculate_chain_cover_index_for_events_txn(
txn, room_id, [e for e in events if e.is_state()]
)
)
self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
if event_id != "b"
],
events,
new_event_links,
)
# Now we insert the event 'B' without a chain cover, by temporarily
@ -647,9 +661,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False},
)
events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
new_event_links = (
self.persist_events.calculate_chain_cover_index_for_events_txn(
txn, room_id, [e for e in events if e.is_state()]
)
)
self.persist_events._persist_event_auth_chain_txn(
txn,
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
txn, events, new_event_links
)
self.store.db_pool.simple_update_txn(