diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index c26860e0d6..73bcc5e613 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, Union from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -90,8 +90,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): state_filter: Optional[StateFilter] = None, ) -> Mapping[int, StateMap[str]]: """ - We can sort from smallest to largest state_group and re-use the work from the - small state_group for a larger one if we see that the edge chain links up. + TODO """ state_filter = state_filter or StateFilter.all() @@ -111,11 +110,22 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): # This may return multiple rows per (type, state_key), but last_value # should be the same. sql = """ - WITH RECURSIVE sgs(state_group) AS ( - VALUES(?::bigint) + WITH RECURSIVE sgs(state_group, state_group_reached) AS ( + VALUES(?::bigint, NULL::bigint) UNION ALL - SELECT prev_state_group FROM state_group_edges e, sgs s - WHERE s.state_group = e.state_group + SELECT + prev_state_group, + CASE + /* Specify state_groups we have already done the work for */ + WHEN @prev_state_group IN (%s) THEN prev_state_group + ELSE NULL + END AS state_group_reached + FROM + state_group_edges e, sgs s + WHERE + s.state_group = e.state_group + /* Stop when we connect up to another state_group that we already did the work for */ + AND s.state_group_reached IS NULL ) %s """ @@ -159,7 +169,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): f""" ( SELECT DISTINCT ON (type, state_key) - type, state_key, event_id + type, state_key, event_id, state_group FROM state_groups_state INNER JOIN sgs USING (state_group) WHERE {where_clause} @@ -180,7 +190,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): overall_select_clause = f""" SELECT DISTINCT ON (type, state_key) - type, state_key, event_id + type, state_key, event_id, state_group FROM state_groups_state WHERE state_group IN ( SELECT state_group FROM sgs @@ -188,15 +198,57 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): ORDER BY type, state_key, state_group DESC """ - for group in groups: + # We can sort from smallest to largest state_group and re-use the work from + # the small state_group for a larger one if we see that the edge chain links + # up. + sorted_groups = sorted(groups) + state_groups_we_have_already_fetched: Set[int] = set() + for group in sorted_groups: args: List[Union[int, str]] = [group] args.extend(overall_select_query_args) - txn.execute(sql % (overall_select_clause,), args) + state_groups_we_have_already_fetched_string = [ + f"{state_group}::bigint" + for state_group in state_groups_we_have_already_fetched + ].join(", ") + + txn.execute( + sql + % ( + state_groups_we_have_already_fetched_string, + overall_select_clause, + ), + args, + ) + + min_state_group: Optional[int] = None + partial_state_map_for_state_group: MutableStateMap[str] = {} for row in txn: - typ, state_key, event_id = row + typ, state_key, event_id, state_group = row key = (intern_string(typ), intern_string(state_key)) - results[group][key] = event_id + partial_state_map_for_state_group[key] = event_id + + if state_group < min_state_group or min_state_group is None: + min_state_group = state_group + + # If we see a state group edge link to a previous state_group that we + # already fetched from the database, link up the base state to the + # partial state we retrieved from the database to build on top of. + if results[min_state_group] is not None: + base_state_map = results[min_state_group].copy() + + results[group] = base_state_map.update( + partial_state_map_for_state_group + ) + else: + # It's also completely normal for us not to have a previous + # state_group to build on top of if this is the first group being + # processes or we are processing a bunch of groups from different + # rooms which of course will never link together. + results[group] = partial_state_map_for_state_group + + state_groups_we_have_already_fetched.add(group) + else: max_entries_returned = state_filter.max_entries_returned()