Re-use work from previous state_groups

This commit is contained in:
Eric Eastwood 2023-05-17 20:03:28 -05:00
parent 4676e53e65
commit 6a19afcdad

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, Union
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@ -90,8 +90,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
state_filter: Optional[StateFilter] = None,
) -> Mapping[int, StateMap[str]]:
"""
We can sort from smallest to largest state_group and re-use the work from the
small state_group for a larger one if we see that the edge chain links up.
TODO
"""
state_filter = state_filter or StateFilter.all()
@ -111,11 +110,22 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE sgs(state_group) AS (
VALUES(?::bigint)
WITH RECURSIVE sgs(state_group, state_group_reached) AS (
VALUES(?::bigint, NULL::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, sgs s
WHERE s.state_group = e.state_group
SELECT
prev_state_group,
CASE
/* Specify state_groups we have already done the work for */
WHEN @prev_state_group IN (%s) THEN prev_state_group
ELSE NULL
END AS state_group_reached
FROM
state_group_edges e, sgs s
WHERE
s.state_group = e.state_group
/* Stop when we connect up to another state_group that we already did the work for */
AND s.state_group_reached IS NULL
)
%s
"""
@ -159,7 +169,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
f"""
(
SELECT DISTINCT ON (type, state_key)
type, state_key, event_id
type, state_key, event_id, state_group
FROM state_groups_state
INNER JOIN sgs USING (state_group)
WHERE {where_clause}
@ -180,7 +190,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
overall_select_clause = f"""
SELECT DISTINCT ON (type, state_key)
type, state_key, event_id
type, state_key, event_id, state_group
FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM sgs
@ -188,15 +198,57 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
ORDER BY type, state_key, state_group DESC
"""
for group in groups:
# We can sort from smallest to largest state_group and re-use the work from
# the small state_group for a larger one if we see that the edge chain links
# up.
sorted_groups = sorted(groups)
state_groups_we_have_already_fetched: Set[int] = set()
for group in sorted_groups:
args: List[Union[int, str]] = [group]
args.extend(overall_select_query_args)
txn.execute(sql % (overall_select_clause,), args)
state_groups_we_have_already_fetched_string = [
f"{state_group}::bigint"
for state_group in state_groups_we_have_already_fetched
].join(", ")
txn.execute(
sql
% (
state_groups_we_have_already_fetched_string,
overall_select_clause,
),
args,
)
min_state_group: Optional[int] = None
partial_state_map_for_state_group: MutableStateMap[str] = {}
for row in txn:
typ, state_key, event_id = row
typ, state_key, event_id, state_group = row
key = (intern_string(typ), intern_string(state_key))
results[group][key] = event_id
partial_state_map_for_state_group[key] = event_id
if state_group < min_state_group or min_state_group is None:
min_state_group = state_group
# If we see a state group edge link to a previous state_group that we
# already fetched from the database, link up the base state to the
# partial state we retrieved from the database to build on top of.
if results[min_state_group] is not None:
base_state_map = results[min_state_group].copy()
results[group] = base_state_map.update(
partial_state_map_for_state_group
)
else:
# It's also completely normal for us not to have a previous
# state_group to build on top of if this is the first group being
# processes or we are processing a bunch of groups from different
# rooms which of course will never link together.
results[group] = partial_state_map_for_state_group
state_groups_we_have_already_fetched.add(group)
else:
max_entries_returned = state_filter.max_entries_returned()