From ca79b4d87df814ae69dd093253d500108d48e461 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 8 May 2024 16:31:59 +0100
Subject: [PATCH] Use a SortedSet instead

---
 .../databases/main/event_federation.py | 28 +++----
 tests/storage/test_purge.py            | 76 ++++++++++++++++++++
 2 files changed, 86 insertions(+), 18 deletions(-)

diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 68f30d893c..3dd53f2038 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -39,6 +39,7 @@ from typing import (
 
 import attr
 from prometheus_client import Counter, Gauge
+from sortedcontainers import SortedSet
 
 from synapse.api.constants import MAX_DEPTH
 from synapse.api.errors import StoreError
@@ -373,24 +374,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # We fetch the links in batches. Separate batches will likely fetch the
         # same set of links (e.g. they'll always pull in the links to create
-        # event). To try and minimize the amount of redundant links, we sort the
-        # chain IDs in reverse, as there will be a correlation between the order
-        # of chain IDs and links (i.e., higher chain IDs are more likely to
-        # depend on lower chain IDs than vice versa).
+        # event). To try and minimize the number of redundant links, we query
+        # the chain IDs in reverse order, as there will be a correlation between
+        # the order of chain IDs and links (i.e., higher chain IDs are more
+        # likely to depend on lower chain IDs than vice versa).
         BATCH_SIZE = 1000
 
-        chains_to_fetch_list = list(chains_to_fetch)
-        chains_to_fetch_list.sort(reverse=True)
+        chains_to_fetch_sorted = SortedSet(chains_to_fetch)
 
-        seen_chains: Set[int] = set()
-        while chains_to_fetch_list:
-            batch2 = [
-                c for c in chains_to_fetch_list[-BATCH_SIZE:] if c not in seen_chains
-            ]
-            chains_to_fetch_list = chains_to_fetch_list[:-BATCH_SIZE]
-            while len(batch2) < BATCH_SIZE and chains_to_fetch_list:
-                chain_id = chains_to_fetch_list.pop()
-                if chain_id not in seen_chains:
-                    batch2.append(chain_id)
+        while chains_to_fetch_sorted:
+            batch2 = list(chains_to_fetch_sorted.islice(-BATCH_SIZE))
+            chains_to_fetch_sorted.difference_update(batch2)
 
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "origin_chain_id", batch2
@@ -409,8 +402,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                     (origin_sequence_number, target_chain_id, target_sequence_number)
                 )
 
-            seen_chains.update(links)
-            seen_chains.update(batch2)
+            chains_to_fetch_sorted.difference_update(links)
 
             yield links
 
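The rewritten loop hinges on SortedSet.islice with a negative start index,
which iterates the largest elements in ascending order. A rough standalone
sketch of the batching pattern, not the store code itself: `fetch_links` is a
hypothetical stand-in for the SQL query against event_auth_chain_links, and is
assumed to return links for the batch plus, possibly, further chains (which is
what makes the final difference_update worthwhile, as the comment about
separate batches fetching the same links suggests):

    from sortedcontainers import SortedSet

    def fetch_all_links(chain_ids, fetch_links, batch_size=1000):
        # fetch_links(batch) is assumed to return a dict of
        # {origin_chain_id: [(origin_seq, target_chain, target_seq)]}
        # covering the batch and any chains it pulls in along the way.
        remaining = SortedSet(chain_ids)
        while remaining:
            # Peel off the batch_size largest remaining chain IDs;
            # islice(-batch_size) yields them in ascending order.
            batch = list(remaining.islice(-batch_size))
            remaining.difference_update(batch)
            links = fetch_links(batch)
            # Chains already covered as link origins need not be queried
            # again, so drop them from the remaining set too.
            remaining.difference_update(links)
            yield links
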
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 080d5640a5..9fa69f6581 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -25,6 +25,7 @@ from synapse.rest.client import room
 from synapse.server import HomeServer
 from synapse.util import Clock
 
+from tests.test_utils.event_injection import inject_event
 from tests.unittest import HomeserverTestCase
 
 
@@ -128,3 +129,78 @@ class PurgeTests(HomeserverTestCase):
         self.store._invalidate_local_get_event_cache(create_event.event_id)
         self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
         self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+
+    def test_state_groups_state_decreases(self) -> None:
+        response = self.helper.send(self.room_id, body="first")
+        first_event_id = response["event_id"]
+
+        batches = []
+
+        previous_event_id = first_event_id
+        for i in range(50):
+            # Inject two conflicting state events at the same depth, forcing
+            # state resolution and hence extra rows in state_groups_state.
+            state_event1 = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="test.state",
+                    sender=self.user_id,
+                    state_key="",
+                    room_id=self.room_id,
+                    content={"key": i, "e": 1},
+                    prev_event_ids=[previous_event_id],
+                    origin_server_ts=1,
+                )
+            )
+
+            state_event2 = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="test.state",
+                    sender=self.user_id,
+                    state_key="",
+                    room_id=self.room_id,
+                    content={"key": i, "e": 2},
+                    prev_event_ids=[previous_event_id],
+                    origin_server_ts=2,
+                )
+            )
+
+            message_event = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="dummy_event",
+                    sender=self.user_id,
+                    room_id=self.room_id,
+                    content={},
+                    prev_event_ids=[state_event1.event_id, state_event2.event_id],
+                )
+            )
+
+            token = self.get_success(
+                self.store.get_topological_token_for_event(state_event1.event_id)
+            )
+            batches.append(token)
+
+            previous_event_id = message_event.event_id
+
+        self.helper.send(self.room_id, body="last event")
+
+        def count_state_groups() -> int:
+            sql = "SELECT COUNT(*) FROM state_groups_state WHERE room_id = ?"
+            rows = self.get_success(
+                self.store.db_pool.execute("count_state_groups", sql, self.room_id)
+            )
+            return rows[0][0]
+
+        count_before = count_state_groups()
+        for token in batches:
+            token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
+            self.get_success(
+                self._storage_controllers.purge_events.purge_history(
+                    self.room_id, token_str, False
+                )
+            )
+
+        # Purging the history should have deleted rows from state_groups_state.
+        self.assertLess(count_state_groups(), count_before)
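
For reference, the islice semantics the new loop relies on can be
sanity-checked with sortedcontainers alone, no Synapse required; a single
negative start index selects from the top of the set:

    >>> from sortedcontainers import SortedSet
    >>> s = SortedSet([3, 1, 4, 1, 5, 9, 2, 6])
    >>> list(s.islice(-3))  # the three largest values, ascending
    [5, 6, 9]
    >>> s.difference_update(list(s.islice(-3)))
    >>> list(s)
    [1, 2, 3, 4]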