diff --git a/changelog.d/9176.misc b/changelog.d/9176.misc new file mode 100644 index 0000000000..9c41d7b0f9 --- /dev/null +++ b/changelog.d/9176.misc @@ -0,0 +1 @@ +Speed up chain cover calculation when persisting a batch of state events at once. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5db7d7aaa8..ccda9f1caa 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -473,8 +473,9 @@ class PersistEventsStore: txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, ) - @staticmethod + @classmethod def _add_chain_cover_index( + cls, txn, db_pool: DatabasePool, event_to_room_id: Dict[str, str], @@ -614,60 +615,17 @@ class PersistEventsStore: if not events_to_calc_chain_id_for: return - # We now calculate the chain IDs/sequence numbers for the events. We - # do this by looking at the chain ID and sequence number of any auth - # event with the same type/state_key and incrementing the sequence - # number by one. If there was no match or the chain ID/sequence - # number is already taken we generate a new chain. - # - # We need to do this in a topologically sorted order as we want to - # generate chain IDs/sequence numbers of an event's auth events - # before the event itself. - chains_tuples_allocated = set() # type: Set[Tuple[int, int]] - new_chain_tuples = {} # type: Dict[str, Tuple[int, int]] - for event_id in sorted_topologically( - events_to_calc_chain_id_for, event_to_auth_chain - ): - existing_chain_id = None - for auth_id in event_to_auth_chain.get(event_id, []): - if event_to_types.get(event_id) == event_to_types.get(auth_id): - existing_chain_id = chain_map[auth_id] - break - - new_chain_tuple = None - if existing_chain_id: - # We found a chain ID/sequence number candidate, check its - # not already taken. - proposed_new_id = existing_chain_id[0] - proposed_new_seq = existing_chain_id[1] + 1 - if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated: - already_allocated = db_pool.simple_select_one_onecol_txn( - txn, - table="event_auth_chains", - keyvalues={ - "chain_id": proposed_new_id, - "sequence_number": proposed_new_seq, - }, - retcol="event_id", - allow_none=True, - ) - if already_allocated: - # Mark it as already allocated so we don't need to hit - # the DB again. - chains_tuples_allocated.add((proposed_new_id, proposed_new_seq)) - else: - new_chain_tuple = ( - proposed_new_id, - proposed_new_seq, - ) - - if not new_chain_tuple: - new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1) - - chains_tuples_allocated.add(new_chain_tuple) - - chain_map[event_id] = new_chain_tuple - new_chain_tuples[event_id] = new_chain_tuple + # Allocate chain ID/sequence numbers to each new event. + new_chain_tuples = cls._allocate_chain_ids( + txn, + db_pool, + event_to_room_id, + event_to_types, + event_to_auth_chain, + events_to_calc_chain_id_for, + chain_map, + ) + chain_map.update(new_chain_tuples) db_pool.simple_insert_many_txn( txn, @@ -794,6 +752,137 @@ class PersistEventsStore: ], ) + @staticmethod + def _allocate_chain_ids( + txn, + db_pool: DatabasePool, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, List[str]], + events_to_calc_chain_id_for: Set[str], + chain_map: Dict[str, Tuple[int, int]], + ) -> Dict[str, Tuple[int, int]]: + """Allocates, but does not persist, chain ID/sequence numbers for the + events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index + for info on args) + """ + + # We now calculate the chain IDs/sequence numbers for the events. We do + # this by looking at the chain ID and sequence number of any auth event + # with the same type/state_key and incrementing the sequence number by + # one. If there was no match or the chain ID/sequence number is already + # taken we generate a new chain. + # + # We try to reduce the number of times that we hit the database by + # batching up calls, to make this more efficient when persisting large + # numbers of state events (e.g. during joins). + # + # We do this by: + # 1. Calculating for each event which auth event will be used to + # inherit the chain ID, i.e. converting the auth chain graph to a + # tree that we can allocate chains on. We also keep track of which + # existing chain IDs have been referenced. + # 2. Fetching the max allocated sequence number for each referenced + # existing chain ID, generating a map from chain ID to the max + # allocated sequence number. + # 3. Iterating over the tree and allocating a chain ID/seq no. to the + # new event, by incrementing the sequence number from the + # referenced event's chain ID/seq no. and checking that the + # incremented sequence number hasn't already been allocated (by + # looking in the map generated in the previous step). We generate a + # new chain if the sequence number has already been allocated. + # + + existing_chains = set() # type: Set[int] + tree = [] # type: List[Tuple[str, Optional[str]]] + + # We need to do this in a topologically sorted order as we want to + # generate chain IDs/sequence numbers of an event's auth events before + # the event itself. + for event_id in sorted_topologically( + events_to_calc_chain_id_for, event_to_auth_chain + ): + for auth_id in event_to_auth_chain.get(event_id, []): + if event_to_types.get(event_id) == event_to_types.get(auth_id): + existing_chain_id = chain_map.get(auth_id) + if existing_chain_id: + existing_chains.add(existing_chain_id[0]) + + tree.append((event_id, auth_id)) + break + else: + tree.append((event_id, None)) + + # Fetch the current max sequence number for each existing referenced chain. + sql = """ + SELECT chain_id, MAX(sequence_number) FROM event_auth_chains + WHERE %s + GROUP BY chain_id + """ + clause, args = make_in_list_sql_clause( + db_pool.engine, "chain_id", existing_chains + ) + txn.execute(sql % (clause,), args) + + chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int] + + # Allocate the new events chain ID/sequence numbers. + # + # To reduce the number of calls to the database we don't allocate a + # chain ID number in the loop, instead we use a temporary `object()` for + # each new chain ID. Once we've done the loop we generate the necessary + # number of new chain IDs in one call, replacing all temporary + # objects with real allocated chain IDs. + + unallocated_chain_ids = set() # type: Set[object] + new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]] + for event_id, auth_event_id in tree: + # If we reference an auth_event_id we fetch the allocated chain ID, + # either from the existing `chain_map` or the newly generated + # `new_chain_tuples` map. + existing_chain_id = None + if auth_event_id: + existing_chain_id = new_chain_tuples.get(auth_event_id) + if not existing_chain_id: + existing_chain_id = chain_map[auth_event_id] + + new_chain_tuple = None # type: Optional[Tuple[Any, int]] + if existing_chain_id: + # We found a chain ID/sequence number candidate, check its + # not already taken. + proposed_new_id = existing_chain_id[0] + proposed_new_seq = existing_chain_id[1] + 1 + + if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq: + new_chain_tuple = ( + proposed_new_id, + proposed_new_seq, + ) + + # If we need to start a new chain we allocate a temporary chain ID. + if not new_chain_tuple: + new_chain_tuple = (object(), 1) + unallocated_chain_ids.add(new_chain_tuple[0]) + + new_chain_tuples[event_id] = new_chain_tuple + chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1] + + # Generate new chain IDs for all unallocated chain IDs. + newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn( + txn, len(unallocated_chain_ids) + ) + + # Map from potentially temporary chain ID to real chain ID + chain_id_to_allocated_map = dict( + zip(unallocated_chain_ids, newly_allocated_chain_ids) + ) # type: Dict[Any, int] + chain_id_to_allocated_map.update((c, c) for c in existing_chains) + + return { + event_id: (chain_id_to_allocated_map[chain_id], seq) + for event_id, (chain_id, seq) in new_chain_tuples.items() + } + def _persist_transaction_ids_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index c780ade077..0ec4dc2918 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -69,6 +69,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta): """Gets the next ID in the sequence""" ... + @abc.abstractmethod + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + """Get the next `n` IDs in the sequence""" + ... + @abc.abstractmethod def check_consistency( self, @@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + first_id = self._current_max_id + 1 + self._current_max_id += n + return [first_id + i for i in range(n)] + def check_consistency( self, db_conn: Connection,