Convert simple_select_many_batch, simple_select_many_txn to tuples. (#16444)

Patrick Cloke, 2023-10-11 13:24:56 -04:00 (committed by GitHub)
commit a4904dcb04 (parent d6b7d49a61)
23 changed files with 641 additions and 443 deletions

changelog.d/16444.misc (new file, 1 line added)

@@ -0,0 +1 @@
+Reduce memory allocations.
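The memory saving comes from the two DatabasePool helpers now returning plain tuples (one per row, in `retcols` order) instead of a dict per row, with callers casting the result to the precise tuple shape. A minimal sketch of the calling-pattern change, using a hypothetical `get_event_depths` helper rather than code from this commit:

from typing import Collection, Dict, List, Tuple, cast

from synapse.storage.database import DatabasePool


async def get_event_depths(
    db_pool: DatabasePool, event_ids: Collection[str]
) -> Dict[str, int]:
    # Old style (before this commit): rows were dicts, accessed by column
    # name, e.g. {row["event_id"]: row["depth"] for row in rows}.
    # New style: rows are tuples in retcols order, so cast to the expected
    # shape and unpack positionally.
    rows = cast(
        List[Tuple[str, int]],
        await db_pool.simple_select_many_batch(
            table="events",
            column="event_id",
            iterable=event_ids,
            retcols=("event_id", "depth"),
            keyvalues={},
            desc="get_event_depths",
        ),
    )
    return {event_id: depth for event_id, depth in rows}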


@@ -1874,9 +1874,9 @@ class DatabasePool:
         keyvalues: Optional[Dict[str, Any]] = None,
         desc: str = "simple_select_many_batch",
         batch_size: int = 100,
-    ) -> List[Dict[str, Any]]:
+    ) -> List[Tuple[Any, ...]]:
         """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
+        more rows.
 
         Filters rows by whether the value of `column` is in `iterable`.
 
@@ -1888,10 +1888,13 @@ class DatabasePool:
             keyvalues: dict of column names and values to select the rows with
             desc: description of the transaction, for logging and metrics
             batch_size: the number of rows for each select query
+
+        Returns:
+            The results as a list of tuples.
         """
         keyvalues = keyvalues or {}
 
-        results: List[Dict[str, Any]] = []
+        results: List[Tuple[Any, ...]] = []
 
         for chunk in batch_iter(iterable, batch_size):
             rows = await self.runInteraction(
@@ -1918,9 +1921,9 @@ class DatabasePool:
         iterable: Collection[Any],
         keyvalues: Dict[str, Any],
         retcols: Iterable[str],
-    ) -> List[Dict[str, Any]]:
+    ) -> List[Tuple[Any, ...]]:
         """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
+        more rows.
 
         Filters rows by whether the value of `column` is in `iterable`.
 
@@ -1931,6 +1934,9 @@ class DatabasePool:
             iterable: list
             keyvalues: dict of column names and values to select the rows with
             retcols: list of strings giving the names of the columns to return
+
+        Returns:
+            The results as a list of tuples.
         """
         if not iterable:
             return []
@@ -1949,7 +1955,7 @@ class DatabasePool:
         )
 
         txn.execute(sql, values)
-        return cls.cursor_to_dict(txn)
+        return txn.fetchall()
 
     async def simple_update(
         self,
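The final hunk above is where the allocation saving actually comes from: `cursor_to_dict` built a fresh dict per row (re-storing every column name as a key in every row), whereas `txn.fetchall()` just hands back whatever rows the DB-API driver already produced, typically plain tuples. A simplified sketch of the difference (not the exact Synapse implementation):

from typing import Any, Dict, List, Tuple


def rows_as_dicts(cursor: Any) -> List[Dict[str, Any]]:
    # Roughly what cls.cursor_to_dict() did: one dict allocated per row.
    col_headers = [column[0] for column in cursor.description]
    return [dict(zip(col_headers, row)) for row in cursor]


def rows_as_tuples(cursor: Any) -> List[Tuple[Any, ...]]:
    # Roughly what the new `return txn.fetchall()` does: the driver's own
    # row tuples are returned as-is, with no per-row dicts or key storage.
    return cursor.fetchall()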


@ -344,18 +344,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# Note that this is more efficient than just dropping `device_id` from the query, # Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)` # since device_inbox has an index on `(user_id, device_id, stream_id)`
if not device_ids_to_query: if not device_ids_to_query:
user_device_dicts = self.db_pool.simple_select_many_txn( user_device_dicts = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="devices", table="devices",
column="user_id", column="user_id",
iterable=user_ids_to_query, iterable=user_ids_to_query,
keyvalues={"hidden": False}, keyvalues={"hidden": False},
retcols=("device_id",), retcols=("device_id",),
),
) )
device_ids_to_query.update( device_ids_to_query.update({row[0] for row in user_device_dicts})
{row["device_id"] for row in user_device_dicts}
)
if not device_ids_to_query: if not device_ids_to_query:
# We've ended up with no devices to query. # We've ended up with no devices to query.
@ -845,20 +846,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# We exclude hidden devices (such as cross-signing keys) here as they are # We exclude hidden devices (such as cross-signing keys) here as they are
# not expected to receive to-device messages. # not expected to receive to-device messages.
rows = self.db_pool.simple_select_many_txn( rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="devices", table="devices",
keyvalues={"user_id": user_id, "hidden": False}, keyvalues={"user_id": user_id, "hidden": False},
column="device_id", column="device_id",
iterable=devices, iterable=devices,
retcols=("device_id",), retcols=("device_id",),
),
) )
for row in rows: for (device_id,) in rows:
# Only insert into the local inbox if the device exists on # Only insert into the local inbox if the device exists on
# this server # this server
device_id = row["device_id"]
with start_active_span("serialise_to_device_message"): with start_active_span("serialise_to_device_message"):
msg = messages_by_device[device_id] msg = messages_by_device[device_id]
set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"]) set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])


@ -1052,16 +1052,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_device_list_last_stream_id_for_remotes( async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str] self, user_ids: Iterable[str]
) -> Mapping[str, Optional[str]]: ) -> Mapping[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,
retcols=("user_id", "stream_id"), retcols=("user_id", "stream_id"),
desc="get_device_list_last_stream_id_for_remotes", desc="get_device_list_last_stream_id_for_remotes",
),
) )
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids} results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows}) results.update(rows)
return results return results
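The `results.update(rows)` line above leans on the fact that `dict.update()` accepts any iterable of two-element pairs, so the `(user_id, stream_id)` tuples can be applied directly without building an intermediate dict. A standalone illustration with made-up values:

from typing import Dict, List, Optional, Tuple

user_ids = ["@a:example.org", "@b:example.org", "@c:example.org"]
rows: List[Tuple[str, str]] = [("@a:example.org", "5"), ("@c:example.org", "9")]

results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
results.update(rows)  # pairs become key/value entries; missing users stay None
assert results == {"@a:example.org": "5", "@b:example.org": None, "@c:example.org": "9"}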
@ -1077,19 +1080,27 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync. The IDs of users whose device lists need resync.
""" """
if user_ids: if user_ids:
rows = await self.db_pool.simple_select_many_batch( row_tuples = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync", table="device_lists_remote_resync",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,
retcols=("user_id",), retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync_with_iterable", desc="get_user_ids_requiring_device_list_resync_with_iterable",
),
) )
return {row[0] for row in row_tuples}
else: else:
rows = await self.db_pool.simple_select_list( rows = cast(
List[Dict[str, str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_resync", table="device_lists_remote_resync",
keyvalues=None, keyvalues=None,
retcols=("user_id",), retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync", desc="get_user_ids_requiring_device_list_resync",
),
) )
return {row["user_id"] for row in rows} return {row["user_id"] for row in rows}


@ -493,15 +493,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
A map from (algorithm, key_id) to json string for key A map from (algorithm, key_id) to json string for key
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, str, str]],
await self.db_pool.simple_select_many_batch(
table="e2e_one_time_keys_json", table="e2e_one_time_keys_json",
column="key_id", column="key_id",
iterable=key_ids, iterable=key_ids,
retcols=("algorithm", "key_id", "key_json"), retcols=("algorithm", "key_id", "key_json"),
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
desc="add_e2e_one_time_keys_check", desc="add_e2e_one_time_keys_check",
),
) )
result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows}
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result return result


@ -1049,7 +1049,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Args: Args:
event_ids: The event IDs to calculate the max depth of. event_ids: The event IDs to calculate the max depth of.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, int]],
await self.db_pool.simple_select_many_batch(
table="events", table="events",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
@ -1058,6 +1060,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"depth", "depth",
), ),
desc="get_max_depth_of", desc="get_max_depth_of",
),
) )
if not rows: if not rows:
@ -1065,10 +1068,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
else: else:
max_depth_event_id = "" max_depth_event_id = ""
current_max_depth = 0 current_max_depth = 0
for row in rows: for event_id, depth in rows:
if row["depth"] > current_max_depth: if depth > current_max_depth:
max_depth_event_id = row["event_id"] max_depth_event_id = event_id
current_max_depth = row["depth"] current_max_depth = depth
return max_depth_event_id, current_max_depth return max_depth_event_id, current_max_depth
@ -1078,7 +1081,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Args: Args:
event_ids: The event IDs to calculate the max depth of. event_ids: The event IDs to calculate the max depth of.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, int]],
await self.db_pool.simple_select_many_batch(
table="events", table="events",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
@ -1087,6 +1092,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"depth", "depth",
), ),
desc="get_min_depth_of", desc="get_min_depth_of",
),
) )
if not rows: if not rows:
@ -1094,10 +1100,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
else: else:
min_depth_event_id = "" min_depth_event_id = ""
current_min_depth = MAX_DEPTH current_min_depth = MAX_DEPTH
for row in rows: for event_id, depth in rows:
if row["depth"] < current_min_depth: if depth < current_min_depth:
min_depth_event_id = row["event_id"] min_depth_event_id = event_id
current_min_depth = row["depth"] current_min_depth = depth
return min_depth_event_id, current_min_depth return min_depth_event_id, current_min_depth
@ -1553,19 +1559,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
A filtered down list of `event_ids` that have previous failed pull attempts. A filtered down list of `event_ids` that have previous failed pull attempts.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts", table="event_failed_pull_attempts",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
keyvalues={}, keyvalues={},
retcols=("event_id",), retcols=("event_id",),
desc="get_event_ids_with_failed_pull_attempts", desc="get_event_ids_with_failed_pull_attempts",
),
) )
event_ids_with_failed_pull_attempts: Set[str] = { return {row[0] for row in rows}
row["event_id"] for row in rows
}
return event_ids_with_failed_pull_attempts
@trace @trace
async def get_event_ids_to_not_pull_from_backoff( async def get_event_ids_to_not_pull_from_backoff(
@ -1585,7 +1590,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
A dictionary of event_ids that should not be attempted to be pulled and the A dictionary of event_ids that should not be attempted to be pulled and the
next timestamp at which we may try pulling them again. next timestamp at which we may try pulling them again.
""" """
event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( event_failed_pull_attempts = cast(
List[Tuple[str, int, int]],
await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts", table="event_failed_pull_attempts",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
@ -1596,21 +1603,21 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"num_attempts", "num_attempts",
), ),
desc="get_event_ids_to_not_pull_from_backoff", desc="get_event_ids_to_not_pull_from_backoff",
),
) )
current_time = self._clock.time_msec() current_time = self._clock.time_msec()
event_ids_with_backoff = {} event_ids_with_backoff = {}
for event_failed_pull_attempt in event_failed_pull_attempts: for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts:
event_id = event_failed_pull_attempt["event_id"]
# Exponential back-off (up to the upper bound) so we don't try to # Exponential back-off (up to the upper bound) so we don't try to
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
backoff_end_time = ( backoff_end_time = (
event_failed_pull_attempt["last_attempt_ts"] last_attempt_ts
+ ( + (
2 2
** min( ** min(
event_failed_pull_attempt["num_attempts"], num_attempts,
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
) )
) )


@ -27,6 +27,7 @@ from typing import (
Optional, Optional,
Set, Set,
Tuple, Tuple,
Union,
cast, cast,
) )
@ -501,16 +502,19 @@ class PersistEventsStore:
# We ignore legacy rooms that we aren't filling the chain cover index # We ignore legacy rooms that we aren't filling the chain cover index
# for. # for.
rows = self.db_pool.simple_select_many_txn( rows = cast(
List[Tuple[str, Optional[Union[int, bool]]]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="rooms", table="rooms",
column="room_id", column="room_id",
iterable={event.room_id for event in events if event.is_state()}, iterable={event.room_id for event in events if event.is_state()},
keyvalues={}, keyvalues={},
retcols=("room_id", "has_auth_chain_index"), retcols=("room_id", "has_auth_chain_index"),
),
) )
rooms_using_chain_index = { rooms_using_chain_index = {
row["room_id"] for row in rows if row["has_auth_chain_index"] room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
} }
state_events = { state_events = {
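The `Optional[Union[int, bool]]` in the cast above (and the plain truthiness test in the comprehension) reflects that the two supported database drivers disagree about boolean columns: SQLite hands back 0/1 integers while PostgreSQL returns real bools. That is an assumption about the drivers rather than something stated in the diff, but it matches the other `Union[int, bool]` casts in this commit. A standalone illustration:

from typing import List, Optional, Tuple, Union

rows: List[Tuple[str, Optional[Union[int, bool]]]] = [
    ("!a:example.org", 1),      # SQLite-style truthy value
    ("!b:example.org", True),   # PostgreSQL-style truthy value
    ("!c:example.org", 0),
    ("!d:example.org", None),
]
rooms_using_chain_index = {room_id for room_id, has_index in rows if has_index}
assert rooms_using_chain_index == {"!a:example.org", "!b:example.org"}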
@ -571,19 +575,18 @@ class PersistEventsStore:
# We check if there are any events that need to be handled in the rooms # We check if there are any events that need to be handled in the rooms
# we're looking at. These should just be out of band memberships, where # we're looking at. These should just be out of band memberships, where
# we didn't have the auth chain when we first persisted. # we didn't have the auth chain when we first persisted.
rows = db_pool.simple_select_many_txn( auth_chain_to_calc_rows = cast(
List[Tuple[str, str, str]],
db_pool.simple_select_many_txn(
txn, txn,
table="event_auth_chain_to_calculate", table="event_auth_chain_to_calculate",
keyvalues={}, keyvalues={},
column="room_id", column="room_id",
iterable=set(event_to_room_id.values()), iterable=set(event_to_room_id.values()),
retcols=("event_id", "type", "state_key"), retcols=("event_id", "type", "state_key"),
),
) )
for row in rows: for event_id, event_type, state_key in auth_chain_to_calc_rows:
event_id = row["event_id"]
event_type = row["type"]
state_key = row["state_key"]
# (We could pull out the auth events for all rows at once using # (We could pull out the auth events for all rows at once using
# simple_select_many, but this case happens rarely and almost always # simple_select_many, but this case happens rarely and almost always
# with a single row.) # with a single row.)
@ -753,7 +756,9 @@ class PersistEventsStore:
# Step 1, fetch all existing links from all the chains we've seen # Step 1, fetch all existing links from all the chains we've seen
# referenced. # referenced.
chain_links = _LinkMap() chain_links = _LinkMap()
rows = db_pool.simple_select_many_txn( auth_chain_rows = cast(
List[Tuple[int, int, int, int]],
db_pool.simple_select_many_txn(
txn, txn,
table="event_auth_chain_links", table="event_auth_chain_links",
column="origin_chain_id", column="origin_chain_id",
@ -765,11 +770,17 @@ class PersistEventsStore:
"target_chain_id", "target_chain_id",
"target_sequence_number", "target_sequence_number",
), ),
),
) )
for row in rows: for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in auth_chain_rows:
chain_links.add_link( chain_links.add_link(
(row["origin_chain_id"], row["origin_sequence_number"]), (origin_chain_id, origin_sequence_number),
(row["target_chain_id"], row["target_sequence_number"]), (target_chain_id, target_sequence_number),
new=False, new=False,
) )


@ -369,18 +369,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks: for chunk in chunks:
ev_rows = self.db_pool.simple_select_many_txn( ev_rows = cast(
List[Tuple[str, str]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="event_json", table="event_json",
column="event_id", column="event_id",
iterable=chunk, iterable=chunk,
retcols=["event_id", "json"], retcols=["event_id", "json"],
keyvalues={}, keyvalues={},
),
) )
for row in ev_rows: for event_id, json in ev_rows:
event_id = row["event_id"] event_json = db_to_json(json)
event_json = db_to_json(row["json"])
try: try:
origin_server_ts = event_json["origin_server_ts"] origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError): except (KeyError, AttributeError):
@ -563,15 +565,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if deleted: if deleted:
# We now need to invalidate the caches of these rooms # We now need to invalidate the caches of these rooms
rows = self.db_pool.simple_select_many_txn( rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="events", table="events",
column="event_id", column="event_id",
iterable=to_delete, iterable=to_delete,
keyvalues={}, keyvalues={},
retcols=("room_id",), retcols=("room_id",),
),
) )
room_ids = {row["room_id"] for row in rows} room_ids = {row[0] for row in rows}
for room_id in room_ids: for room_id in room_ids:
txn.call_after( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined] self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
@ -1038,18 +1043,21 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
count = len(rows) count = len(rows)
# We also need to fetch the auth events for them. # We also need to fetch the auth events for them.
auth_events = self.db_pool.simple_select_many_txn( auth_events = cast(
List[Tuple[str, str]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="event_auth", table="event_auth",
column="event_id", column="event_id",
iterable=event_to_room_id, iterable=event_to_room_id,
keyvalues={}, keyvalues={},
retcols=("event_id", "auth_id"), retcols=("event_id", "auth_id"),
),
) )
event_to_auth_chain: Dict[str, List[str]] = {} event_to_auth_chain: Dict[str, List[str]] = {}
for row in auth_events: for event_id, auth_id in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) event_to_auth_chain.setdefault(event_id, []).append(auth_id)
# Calculate and persist the chain cover index for this set of events. # Calculate and persist the chain cover index for this set of events.
# #


@ -1584,16 +1584,19 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and """Given a list of event ids, check if we have already processed and
stored them as non outliers. stored them as non outliers.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="events", table="events",
retcols=("event_id",), retcols=("event_id",),
column="event_id", column="event_id",
iterable=list(event_ids), iterable=list(event_ids),
keyvalues={"outlier": False}, keyvalues={"outlier": False},
desc="have_events_in_timeline", desc="have_events_in_timeline",
),
) )
return {r["event_id"] for r in rows} return {r[0] for r in rows}
@trace @trace
@tag_args @tag_args
@ -2336,15 +2339,18 @@ class EventsWorkerStore(SQLBaseStore):
a dict mapping from event id to partial-stateness. We return True for a dict mapping from event id to partial-stateness. We return True for
any of the events which are unknown (or are outliers). any of the events which are unknown (or are outliers).
""" """
result = await self.db_pool.simple_select_many_batch( result = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="partial_state_events", table="partial_state_events",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
retcols=["event_id"], retcols=["event_id"],
desc="get_partial_state_events", desc="get_partial_state_events",
),
) )
# convert the result to a dict, to make @cachedList work # convert the result to a dict, to make @cachedList work
partial = {r["event_id"] for r in result} partial = {r[0] for r in result}
return {e_id: e_id in partial for e_id in event_ids} return {e_id: e_id in partial for e_id in event_ids}
@cached() @cached()


@ -16,7 +16,7 @@
import itertools import itertools
import json import json
import logging import logging
from typing import Dict, Iterable, Mapping, Optional, Tuple from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -205,7 +205,9 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent. If we have multiple entries for a given key ID, returns the most recent.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
await self.db_pool.simple_select_many_batch(
table="server_keys_json", table="server_keys_json",
column="key_id", column="key_id",
iterable=key_ids, iterable=key_ids,
@ -218,22 +220,24 @@ class KeyStore(CacheInvalidationWorkerStore):
"key_json", "key_json",
), ),
desc="get_server_keys_json_for_remote", desc="get_server_keys_json_for_remote",
),
) )
if not rows: if not rows:
return {} return {}
# We sort the rows so that the most recently added entry is picked up. # We sort the rows by ts_added_ms so that the most recently added entry
rows.sort(key=lambda r: r["ts_added_ms"]) # will stomp over older entries in the dictionary.
rows.sort(key=lambda r: r[2])
return { return {
row["key_id"]: FetchKeyResultForRemote( key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview. # Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]), key_json=bytes(key_json),
valid_until_ts=row["ts_valid_until_ms"], valid_until_ts=ts_valid_until_ms,
added_ts=row["ts_added_ms"], added_ts=ts_added_ms,
) )
for row in rows for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
} }
async def get_all_server_keys_json_for_remote( async def get_all_server_keys_json_for_remote(
@ -260,6 +264,8 @@ class KeyStore(CacheInvalidationWorkerStore):
if not rows: if not rows:
return {} return {}
# We sort the rows by ts_added_ms so that the most recently added entry
# will stomp over older entries in the dictionary.
rows.sort(key=lambda r: r["ts_added_ms"]) rows.sort(key=lambda r: r["ts_added_ms"])
return { return {
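The reworked sort above keeps the existing trick: rows are sorted ascending by `ts_added_ms` (now tuple index 2) so that, when the dict keyed on `key_id` is built, the most recently added entry overwrites any older ones. A standalone illustration of that "last write wins" behaviour with made-up rows in `(key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json)` order:

rows = [
    ("ed25519:abc", "server1", 300, 400, b"newest"),
    ("ed25519:abc", "server1", 100, 200, b"oldest"),
]
rows.sort(key=lambda r: r[2])  # ascending ts_added_ms, oldest first
latest = {key_id: key_json for key_id, _server, _added, _valid, key_json in rows}
assert latest == {"ed25519:abc": b"newest"}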


@ -261,7 +261,11 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
async def get_presence_for_users( async def get_presence_for_users(
self, user_ids: Iterable[str] self, user_ids: Iterable[str]
) -> Mapping[str, UserPresenceState]: ) -> Mapping[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch( # TODO All these columns are nullable, but we don't expect that:
# https://github.com/matrix-org/synapse/issues/16467
rows = cast(
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
await self.db_pool.simple_select_many_batch(
table="presence_stream", table="presence_stream",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,
@ -276,12 +280,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
"currently_active", "currently_active",
), ),
desc="get_presence_for_users", desc="get_presence_for_users",
),
) )
for row in rows: return {
row["currently_active"] = bool(row["currently_active"]) user_id: UserPresenceState(
user_id=user_id,
return {row["user_id"]: UserPresenceState(**row) for row in rows} state=state,
last_active_ts=last_active_ts,
last_federation_update_ts=last_federation_update_ts,
last_user_sync_ts=last_user_sync_ts,
status_msg=status_msg,
currently_active=bool(currently_active),
)
for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
}
async def should_user_receive_full_presence_with_token( async def should_user_receive_full_presence_with_token(
self, self,
@ -386,6 +399,8 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
limit = 100 limit = 100
offset = 0 offset = 0
while True: while True:
# TODO All these columns are nullable, but we don't expect that:
# https://github.com/matrix-org/synapse/issues/16467
rows = cast( rows = cast(
List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]], List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
await self.db_pool.runInteraction( await self.db_pool.runInteraction(


@ -62,20 +62,34 @@ logger = logging.getLogger(__name__)
def _load_rules( def _load_rules(
rawrules: List[JsonDict], rawrules: List[Tuple[str, int, str, str]],
enabled_map: Dict[str, bool], enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig, experimental_config: ExperimentalConfig,
) -> FilteredPushRules: ) -> FilteredPushRules:
"""Take the DB rows returned from the DB and convert them into a full """Take the DB rows returned from the DB and convert them into a full
`FilteredPushRules` object. `FilteredPushRules` object.
Args:
rawrules: List of tuples of:
* rule ID
* Priority class
* Conditions (as serialized JSON)
* Actions (as serialized JSON)
enabled_map: A dictionary of rule ID to a boolean of whether the rule is
enabled. This might not include all rule IDs from rawrules.
experimental_config: The `experimental_features` section of the Synapse
config. (Used to check if various features are enabled.)
Returns:
A new FilteredPushRules object.
""" """
ruleslist = [ ruleslist = [
PushRule.from_db( PushRule.from_db(
rule_id=rawrule["rule_id"], rule_id=rawrule[0],
priority_class=rawrule["priority_class"], priority_class=rawrule[1],
conditions=rawrule["conditions"], conditions=rawrule[2],
actions=rawrule["actions"], actions=rawrule[3],
) )
for rawrule in rawrules for rawrule in rawrules
] ]
@ -183,7 +197,19 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id) enabled_map = await self.get_push_rules_enabled_for_user(user_id)
return _load_rules(rows, enabled_map, self.hs.config.experimental) return _load_rules(
[
(
row["rule_id"],
row["priority_class"],
row["conditions"],
row["actions"],
)
for row in rows
],
enabled_map,
self.hs.config.experimental,
)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list( results = await self.db_pool.simple_select_list(
@ -221,21 +247,36 @@ class PushRulesWorkerStore(
if not user_ids: if not user_ids:
return {} return {}
raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} raw_rules: Dict[str, List[Tuple[str, int, str, str]]] = {
user_id: [] for user_id in user_ids
}
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, str, int, int, str, str]],
await self.db_pool.simple_select_many_batch(
table="push_rules", table="push_rules",
column="user_name", column="user_name",
iterable=user_ids, iterable=user_ids,
retcols=("*",), retcols=(
"user_name",
"rule_id",
"priority_class",
"priority",
"conditions",
"actions",
),
desc="bulk_get_push_rules", desc="bulk_get_push_rules",
batch_size=1000, batch_size=1000,
),
) )
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) # Sort by highest priority_class, then highest priority.
rows.sort(key=lambda row: (-int(row[2]), -int(row[3])))
for row in rows: for user_name, rule_id, priority_class, _, conditions, actions in rows:
raw_rules.setdefault(row["user_name"], []).append(row) raw_rules.setdefault(user_name, []).append(
(rule_id, priority_class, conditions, actions)
)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
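The new comment and `rows.sort(key=lambda row: (-int(row[2]), -int(row[3])))` above sort on two columns in descending order by negating both keys, which avoids a mixed-direction sort. A standalone illustration with made-up push rules in `(user_name, rule_id, priority_class, priority, conditions, actions)` order:

rows = [
    ("@u:example.org", "rule_b", 1, 5, "[]", "[]"),
    ("@u:example.org", "rule_a", 5, 0, "[]", "[]"),
    ("@u:example.org", "rule_c", 1, 7, "[]", "[]"),
]
rows.sort(key=lambda row: (-int(row[2]), -int(row[3])))
assert [rule_id for _user, rule_id, *_rest in rows] == ["rule_a", "rule_c", "rule_b"]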
@ -256,17 +297,19 @@ class PushRulesWorkerStore(
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids} results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, str, Optional[int]]],
await self.db_pool.simple_select_many_batch(
table="push_rules_enable", table="push_rules_enable",
column="user_name", column="user_name",
iterable=user_ids, iterable=user_ids,
retcols=("user_name", "rule_id", "enabled"), retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled", desc="bulk_get_push_rules_enabled",
batch_size=1000, batch_size=1000,
),
) )
for row in rows: for user_name, rule_id, enabled in rows:
enabled = bool(row["enabled"]) results.setdefault(user_name, {})[rule_id] = bool(enabled)
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
return results return results
async def get_all_push_rule_updates( async def get_all_push_rule_updates(


@ -349,16 +349,19 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_with_types_txn( def get_all_relation_ids_for_event_with_types_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[str]: ) -> List[str]:
rows = self.db_pool.simple_select_many_txn( rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn=txn, txn=txn,
table="event_relations", table="event_relations",
column="relation_type", column="relation_type",
iterable=relation_types, iterable=relation_types,
keyvalues={"relates_to_id": event_id}, keyvalues={"relates_to_id": event_id},
retcols=["event_id"], retcols=["event_id"],
),
) )
return [row["event_id"] for row in rows] return [row[0] for row in rows]
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event_with_types", desc="get_all_relation_ids_for_event_with_types",


@ -1296,14 +1296,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
complete. complete.
""" """
rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="partial_state_rooms", table="partial_state_rooms",
column="room_id", column="room_id",
iterable=room_ids, iterable=room_ids,
retcols=("room_id",), retcols=("room_id",),
desc="is_partial_state_room_batched", desc="is_partial_state_room_batched",
),
) )
partial_state_rooms = {row_dict["room_id"] for row_dict in rows} partial_state_rooms = {row[0] for row in rows}
return {room_id: room_id in partial_state_rooms for room_id in room_ids} return {room_id: room_id in partial_state_rooms for room_id in room_ids}
async def get_join_event_id_and_device_lists_stream_id_for_partial_state( async def get_join_event_id_and_device_lists_stream_id_for_partial_state(


@ -27,6 +27,7 @@ from typing import (
Set, Set,
Tuple, Tuple,
Union, Union,
cast,
) )
import attr import attr
@ -683,7 +684,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from user_id to set of rooms that is currently in. Map from user_id to set of rooms that is currently in.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_many_batch(
table="current_state_events", table="current_state_events",
column="state_key", column="state_key",
iterable=user_ids, iterable=user_ids,
@ -696,12 +699,13 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
"membership": Membership.JOIN, "membership": Membership.JOIN,
}, },
desc="get_rooms_for_users", desc="get_rooms_for_users",
),
) )
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
for row in rows: for state_key, room_id in rows:
user_rooms[row["state_key"]].add(row["room_id"]) user_rooms[state_key].add(room_id)
return {key: frozenset(rooms) for key, rooms in user_rooms.items()} return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
@ -892,17 +896,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from event ID to `user_id`, or None if event is not a join. Map from event ID to `user_id`, or None if event is not a join.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_many_batch(
table="room_memberships", table="room_memberships",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
retcols=("user_id", "event_id"), retcols=("event_id", "user_id"),
keyvalues={"membership": Membership.JOIN}, keyvalues={"membership": Membership.JOIN},
batch_size=1000, batch_size=1000,
desc="_get_user_ids_from_membership_event_ids", desc="_get_user_ids_from_membership_event_ids",
),
) )
return {row["event_id"]: row["user_id"] for row in rows} return dict(rows)
@cached(max_entries=10000) @cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool: async def is_host_joined(self, room_id: str, host: str) -> bool:
@ -1202,7 +1209,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
membership event, otherwise the value is None. membership event, otherwise the value is None.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, str, str]],
await self.db_pool.simple_select_many_batch(
table="room_memberships", table="room_memberships",
column="event_id", column="event_id",
iterable=member_event_ids, iterable=member_event_ids,
@ -1210,13 +1219,12 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
keyvalues={}, keyvalues={},
batch_size=500, batch_size=500,
desc="get_membership_from_event_ids", desc="get_membership_from_event_ids",
),
) )
return { return {
row["event_id"]: EventIdMembership( event_id: EventIdMembership(membership=membership, user_id=user_id)
membership=row["membership"], user_id=row["user_id"] for user_id, membership, event_id in rows
)
for row in rows
} }
async def is_local_host_in_room_ignoring_users( async def is_local_host_in_room_ignoring_users(


@ -20,10 +20,12 @@ from typing import (
Collection, Collection,
Dict, Dict,
Iterable, Iterable,
List,
Mapping, Mapping,
Optional, Optional,
Set, Set,
Tuple, Tuple,
cast,
) )
import attr import attr
@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises: Raises:
RuntimeError if the state is unknown at any of the given events RuntimeError if the state is unknown at any of the given events
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, int]],
await self.db_pool.simple_select_many_batch(
table="event_to_state_groups", table="event_to_state_groups",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
keyvalues={}, keyvalues={},
retcols=("event_id", "state_group"), retcols=("event_id", "state_group"),
desc="_get_state_group_for_events", desc="_get_state_group_for_events",
),
) )
res = {row["event_id"]: row["state_group"] for row in rows} res = dict(rows)
for e in event_ids: for e in event_ids:
if e not in res: if e not in res:
raise RuntimeError("No state group for unknown or outlier event %s" % e) raise RuntimeError("No state group for unknown or outlier event %s" % e)
@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The subset of state groups that are referenced. The subset of state groups that are referenced.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[int]],
await self.db_pool.simple_select_many_batch(
table="event_to_state_groups", table="event_to_state_groups",
column="state_group", column="state_group",
iterable=state_groups, iterable=state_groups,
keyvalues={}, keyvalues={},
retcols=("DISTINCT state_group",), retcols=("DISTINCT state_group",),
desc="get_referenced_state_groups", desc="get_referenced_state_groups",
),
) )
return {row["state_group"] for row in rows} return {row[0] for row in rows}
async def update_state_for_partial_state_event( async def update_state_for_partial_state_event(
self, self,
@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
# potentially stale, since there may have been a period where the # potentially stale, since there may have been a period where the
# server didn't share a room with the remote user and therefore may # server didn't share a room with the remote user and therefore may
# have missed any device updates. # have missed any device updates.
rows = self.db_pool.simple_select_many_txn( rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="current_state_events", table="current_state_events",
column="room_id", column="room_id",
iterable=to_delete, iterable=to_delete,
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN}, keyvalues={
"type": EventTypes.Member,
"membership": Membership.JOIN,
},
retcols=("state_key",), retcols=("state_key",),
),
) )
potentially_left_users = {row["state_key"] for row in rows} potentially_left_users = {row[0] for row in rows}
# Now lets actually delete the rooms from the DB. # Now lets actually delete the rooms from the DB.
self.db_pool.simple_delete_many_txn( self.db_pool.simple_delete_many_txn(


@ -506,7 +506,9 @@ class StatsStore(StateDeltasStore):
) -> Tuple[List[str], Dict[str, int], int, List[str], int]: ) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined] pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
rows = self.db_pool.simple_select_many_txn( rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="current_state_events", table="current_state_events",
column="type", column="type",
@ -522,9 +524,10 @@ class StatsStore(StateDeltasStore):
], ],
keyvalues={"room_id": room_id, "state_key": ""}, keyvalues={"room_id": room_id, "state_key": ""},
retcols=["event_id"], retcols=["event_id"],
),
) )
event_ids = cast(List[str], [row["event_id"] for row in rows]) event_ids = [row[0] for row in rows]
txn.execute( txn.execute(
""" """


@ -211,18 +211,28 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
async def get_destination_retry_timings_batch( async def get_destination_retry_timings_batch(
self, destinations: StrCollection self, destinations: StrCollection
) -> Mapping[str, Optional[DestinationRetryTimings]]: ) -> Mapping[str, Optional[DestinationRetryTimings]]:
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str, Optional[int], Optional[int], Optional[int]]],
await self.db_pool.simple_select_many_batch(
table="destinations", table="destinations",
iterable=destinations, iterable=destinations,
column="destination", column="destination",
retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), retcols=(
"destination",
"failure_ts",
"retry_last_ts",
"retry_interval",
),
desc="get_destination_retry_timings_batch", desc="get_destination_retry_timings_batch",
),
) )
return { return {
row.pop("destination"): DestinationRetryTimings(**row) destination: DestinationRetryTimings(
for row in rows failure_ts, retry_last_ts, retry_interval
if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"] )
for destination, failure_ts, retry_last_ts, retry_interval in rows
if retry_last_ts and failure_ts and retry_interval
} }
async def set_destination_retry_timings( async def set_destination_retry_timings(
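`DestinationRetryTimings(failure_ts, retry_last_ts, retry_interval)` above constructs the result positionally, which only works because the argument order matches the class's field declaration order (which in turn matches `retcols`). A sketch of that pattern with a hypothetical stand-in class, since the real definition is not part of this diff:

from typing import Optional

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class RetryTimingsSketch:
    failure_ts: Optional[int]
    retry_last_ts: int
    retry_interval: int


# Positional construction relies on the field order; keyword arguments would
# be the safer choice if the retcols order ever changed.
timings = RetryTimingsSketch(1_000, 2_000, 60_000)
assert timings.retry_interval == 60_000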


@ -337,13 +337,16 @@ class UIAuthWorkerStore(SQLBaseStore):
# If a registration token was used, decrement the pending counter # If a registration token was used, decrement the pending counter
# before deleting the session. # before deleting the session.
rows = self.db_pool.simple_select_many_txn( rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="ui_auth_sessions_credentials", table="ui_auth_sessions_credentials",
column="session_id", column="session_id",
iterable=session_ids, iterable=session_ids,
keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
retcols=["result"], retcols=["result"],
),
) )
# Get the tokens used and how much pending needs to be decremented by. # Get the tokens used and how much pending needs to be decremented by.
@ -353,23 +356,25 @@ class UIAuthWorkerStore(SQLBaseStore):
# registration token stage for that session will be True. # registration token stage for that session will be True.
# If a token was used to authenticate, but registration was # If a token was used to authenticate, but registration was
# never completed, the result will be the token used. # never completed, the result will be the token used.
token = db_to_json(r["result"]) token = db_to_json(r[0])
if isinstance(token, str): if isinstance(token, str):
token_counts[token] = token_counts.get(token, 0) + 1 token_counts[token] = token_counts.get(token, 0) + 1
# Update the `pending` counters. # Update the `pending` counters.
if len(token_counts) > 0: if len(token_counts) > 0:
token_rows = self.db_pool.simple_select_many_txn( token_rows = cast(
List[Tuple[str, int]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="registration_tokens", table="registration_tokens",
column="token", column="token",
iterable=list(token_counts.keys()), iterable=list(token_counts.keys()),
keyvalues={}, keyvalues={},
retcols=["token", "pending"], retcols=["token", "pending"],
),
) )
for token_row in token_rows: for token, pending in token_rows:
token = token_row["token"] new_pending = pending - token_counts[token]
new_pending = token_row["pending"] - token_counts[token]
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
table="registration_tokens", table="registration_tokens",


@ -410,7 +410,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) )
# Next fetch their profiles. Note that not all users have profiles. # Next fetch their profiles. Note that not all users have profiles.
profile_rows = self.db_pool.simple_select_many_txn( profile_rows = cast(
List[Tuple[str, Optional[str], Optional[str]]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="profiles", table="profiles",
column="full_user_id", column="full_user_id",
@ -421,14 +423,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"avatar_url", "avatar_url",
), ),
keyvalues={}, keyvalues={},
),
) )
profiles = { profiles = {
row["full_user_id"]: _UserDirProfile( full_user_id: _UserDirProfile(full_user_id, displayname, avatar_url)
row["full_user_id"], for full_user_id, displayname, avatar_url in profile_rows
row["displayname"],
row["avatar_url"],
)
for row in profile_rows
} }
profiles_to_insert = [ profiles_to_insert = [
@ -517,7 +516,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined] and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined]
] ]
rows = self.db_pool.simple_select_many_txn( rows = cast(
List[Tuple[str, Optional[str]]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="users", table="users",
column="name", column="name",
@ -526,9 +527,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"deactivated": 0, "deactivated": 0,
}, },
retcols=("name", "user_type"), retcols=("name", "user_type"),
),
) )
return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT] return [name for name, user_type in rows if user_type != UserTypes.SUPPORT]
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable""" """Check if the room is either world_readable or publically joinable"""


@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, Mapping from typing import Iterable, List, Mapping, Tuple, cast
from synapse.storage.database import LoggingTransaction from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main import CacheInvalidationWorkerStore
@ -50,14 +50,17 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
Returns: Returns:
for each user, whether the user has requested erasure. for each user, whether the user has requested erasure.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="erased_users", table="erased_users",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,
retcols=("user_id",), retcols=("user_id",),
desc="are_users_erased", desc="are_users_erased",
),
) )
erased_users = {row["user_id"] for row in rows} erased_users = {row[0] for row in rows}
return {u: u in erased_users for u in user_ids} return {u: u in erased_users for u in user_ids}


@ -13,7 +13,17 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
cast,
)
import attr import attr
@ -730,19 +740,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"[purge] found %i state groups to delete", len(state_groups_to_delete) "[purge] found %i state groups to delete", len(state_groups_to_delete)
) )
rows = self.db_pool.simple_select_many_txn( rows = cast(
List[Tuple[int]],
self.db_pool.simple_select_many_txn(
txn, txn,
table="state_group_edges", table="state_group_edges",
column="prev_state_group", column="prev_state_group",
iterable=state_groups_to_delete, iterable=state_groups_to_delete,
keyvalues={}, keyvalues={},
retcols=("state_group",), retcols=("state_group",),
),
) )
remaining_state_groups = { remaining_state_groups = {
row["state_group"] state_group
for row in rows for state_group, in rows
if row["state_group"] not in state_groups_to_delete if state_group not in state_groups_to_delete
} }
logger.info( logger.info(
@ -799,16 +812,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
A mapping from state group to previous state group. A mapping from state group to previous state group.
""" """
rows = await self.db_pool.simple_select_many_batch( rows = cast(
List[Tuple[int, int]],
await self.db_pool.simple_select_many_batch(
table="state_group_edges", table="state_group_edges",
column="prev_state_group", column="prev_state_group",
iterable=state_groups, iterable=state_groups,
keyvalues={}, keyvalues={},
retcols=("prev_state_group", "state_group"), retcols=("state_group", "prev_state_group"),
desc="get_previous_state_groups", desc="get_previous_state_groups",
),
) )
return {row["state_group"]: row["prev_state_group"] for row in rows} return dict(rows)
async def purge_room_state( async def purge_room_state(
self, room_id: str, state_groups_to_delete: Collection[int] self, room_id: str, state_groups_to_delete: Collection[int]
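Note that `retcols` in the hunk above was deliberately reordered to ("state_group", "prev_state_group"): `dict(rows)` keys each mapping on the first element of every tuple, so the column order now has to match the intended key/value direction. A standalone illustration with made-up state group IDs:

rows = [(102, 101), (103, 102)]  # (state_group, prev_state_group)
prev_group_by_group = dict(rows)
assert prev_group_by_group == {102: 101, 103: 102}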


@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest from twisted.trial import unittest
@ -421,7 +421,9 @@ class EventChainStoreTestCase(HomeserverTestCase):
self, events: List[EventBase] self, events: List[EventBase]
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
# Fetch the map from event ID -> (chain ID, sequence number) # Fetch the map from event ID -> (chain ID, sequence number)
rows = self.get_success( rows = cast(
List[Tuple[str, int, int]],
self.get_success(
self.store.db_pool.simple_select_many_batch( self.store.db_pool.simple_select_many_batch(
table="event_auth_chains", table="event_auth_chains",
column="event_id", column="event_id",
@ -429,14 +431,18 @@ class EventChainStoreTestCase(HomeserverTestCase):
retcols=("event_id", "chain_id", "sequence_number"), retcols=("event_id", "chain_id", "sequence_number"),
keyvalues={}, keyvalues={},
) )
),
) )
chain_map = { chain_map = {
row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows event_id: (chain_id, sequence_number)
for event_id, chain_id, sequence_number in rows
} }
# Fetch all the links and pass them to the _LinkMap. # Fetch all the links and pass them to the _LinkMap.
rows = self.get_success( auth_chain_rows = cast(
List[Tuple[int, int, int, int]],
self.get_success(
self.store.db_pool.simple_select_many_batch( self.store.db_pool.simple_select_many_batch(
table="event_auth_chain_links", table="event_auth_chain_links",
column="origin_chain_id", column="origin_chain_id",
@ -449,13 +455,19 @@ class EventChainStoreTestCase(HomeserverTestCase):
), ),
keyvalues={}, keyvalues={},
) )
),
) )
link_map = _LinkMap() link_map = _LinkMap()
for row in rows: for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in auth_chain_rows:
added = link_map.add_link( added = link_map.add_link(
(row["origin_chain_id"], row["origin_sequence_number"]), (origin_chain_id, origin_sequence_number),
(row["target_chain_id"], row["target_sequence_number"]), (target_chain_id, target_sequence_number),
) )
# We shouldn't have persisted any redundant links # We shouldn't have persisted any redundant links