Convert simple_select_list and simple_select_list_txn to return lists of tuples (#16505)

This should use fewer allocations and improves type hints.
This commit is contained in:
Patrick Cloke 2023-10-26 13:01:36 -04:00 committed by GitHub
parent c14a7de6af
commit 9407d5ba78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 607 additions and 507 deletions

1
changelog.d/16505.misc Normal file
View file

@ -0,0 +1 @@
Reduce memory allocations.

View file

@ -103,10 +103,10 @@ class DeactivateAccountHandler:
# Attempt to unbind any known bound threepids to this account from identity # Attempt to unbind any known bound threepids to this account from identity
# server(s). # server(s).
bound_threepids = await self.store.user_get_bound_threepids(user_id) bound_threepids = await self.store.user_get_bound_threepids(user_id)
for threepid in bound_threepids: for medium, address in bound_threepids:
try: try:
result = await self._identity_handler.try_unbind_threepid( result = await self._identity_handler.try_unbind_threepid(
user_id, threepid["medium"], threepid["address"], id_server user_id, medium, address, id_server
) )
except Exception: except Exception:
# Do we want this to be a fatal error or should we carry on? # Do we want this to be a fatal error or should we carry on?

View file

@ -1206,10 +1206,7 @@ class SsoHandler:
# We have no guarantee that all the devices of that session are for the same # We have no guarantee that all the devices of that session are for the same
# `user_id`. Hence, we have to iterate over the list of devices and log them out # `user_id`. Hence, we have to iterate over the list of devices and log them out
# one by one. # one by one.
for device in devices: for user_id, device_id in devices:
user_id = device["user_id"]
device_id = device["device_id"]
# If the user_id associated with that device/session is not the one we got # If the user_id associated with that device/session is not the one we got
# out of the `sub` claim, skip that device and show log an error. # out of the `sub` claim, skip that device and show log an error.
if expected_user_id is not None and user_id != expected_user_id: if expected_user_id is not None and user_id != expected_user_id:

View file

@ -606,13 +606,16 @@ class DatabasePool:
If the background updates have not completed, wait 15 sec and check again. If the background updates have not completed, wait 15 sec and check again.
""" """
updates = await self.simple_select_list( updates = cast(
List[Tuple[str]],
await self.simple_select_list(
"background_updates", "background_updates",
keyvalues=None, keyvalues=None,
retcols=["update_name"], retcols=["update_name"],
desc="check_background_updates", desc="check_background_updates",
),
) )
background_update_names = [x["update_name"] for x in updates] background_update_names = [x[0] for x in updates]
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
if update_name not in background_update_names: if update_name not in background_update_names:
@ -1804,9 +1807,9 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]], keyvalues: Optional[Dict[str, Any]],
retcols: Collection[str], retcols: Collection[str],
desc: str = "simple_select_list", desc: str = "simple_select_list",
) -> List[Dict[str, Any]]: ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of tuples.
Args: Args:
table: the table name table: the table name
@ -1817,8 +1820,7 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics desc: description of the transaction, for logging and metrics
Returns: Returns:
A list of dictionaries, one per result row, each a mapping between the A list of tuples, one per result row, each the retcolumn's value for the row.
column names from `retcols` and that column's value for the row.
""" """
return await self.runInteraction( return await self.runInteraction(
desc, desc,
@ -1836,9 +1838,9 @@ class DatabasePool:
table: str, table: str,
keyvalues: Optional[Dict[str, Any]], keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str], retcols: Iterable[str],
) -> List[Dict[str, Any]]: ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of tuples.
Args: Args:
txn: Transaction object txn: Transaction object
@ -1849,8 +1851,7 @@ class DatabasePool:
retcols: the names of the columns to return retcols: the names of the columns to return
Returns: Returns:
A list of dictionaries, one per result row, each a mapping between the A list of tuples, one per result row, each the retcolumn's value for the row.
column names from `retcols` and that column's value for the row.
""" """
if keyvalues: if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % ( sql = "SELECT %s FROM %s WHERE %s" % (
@ -1863,7 +1864,7 @@ class DatabasePool:
sql = "SELECT %s FROM %s" % (", ".join(retcols), table) sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql) txn.execute(sql)
return cls.cursor_to_dict(txn) return txn.fetchall()
async def simple_select_many_batch( async def simple_select_many_batch(
self, self,

View file

@ -286,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_room_txn( def get_account_data_for_room_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Dict[str, JsonDict]: ) -> Dict[str, JsonMapping]:
rows = self.db_pool.simple_select_list_txn( rows = cast(
List[Tuple[str, str]],
self.db_pool.simple_select_list_txn(
txn, txn,
"room_account_data", table="room_account_data",
{"user_id": user_id, "room_id": room_id}, keyvalues={"user_id": user_id, "room_id": room_id},
["account_data_type", "content"], retcols=["account_data_type", "content"],
),
) )
return { return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows account_data_type: db_to_json(content)
for account_data_type, content in rows
} }
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(

View file

@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
Returns: Returns:
A list of ApplicationServices, which may be empty. A list of ApplicationServices, which may be empty.
""" """
results = await self.db_pool.simple_select_list( results = cast(
"application_services_state", {"state": state.value}, ["as_id"] List[Tuple[str]],
await self.db_pool.simple_select_list(
table="application_services_state",
keyvalues={"state": state.value},
retcols=("as_id",),
),
) )
# NB: This assumes this class is linked with ApplicationServiceStore # NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services() as_list = self.get_app_services()
services = [] services = []
for res in results: for (as_id,) in results:
for service in as_list: for service in as_list:
if service.id == res["as_id"]: if service.id == as_id:
services.append(service) services.append(service)
return services return services

View file

@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
if device_id is not None: if device_id is not None:
keyvalues["device_id"] = device_id keyvalues["device_id"] = device_id
res = await self.db_pool.simple_select_list( res = cast(
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
await self.db_pool.simple_select_list(
table="devices", table="devices",
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
) )
return { return {
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo( (user_id, device_id): DeviceLastConnectionInfo(
user_id=d["user_id"], user_id=user_id,
device_id=d["device_id"], device_id=device_id,
ip=d["ip"], ip=ip,
user_agent=d["user_agent"], user_agent=user_agent,
last_seen=d["last_seen"], last_seen=last_seen,
) )
for d in res for user_id, ip, user_agent, device_id, last_seen in res
} }
async def _get_user_ip_and_agents_from_database( async def _get_user_ip_and_agents_from_database(

View file

@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
allow_none=True, allow_none=True,
) )
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: async def get_devices_by_user(
self, user_id: str
) -> Dict[str, Dict[str, Optional[str]]]:
"""Retrieve all of a user's registered devices. Only returns devices """Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden. that are not marked as hidden.
@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id: user_id:
Returns: Returns:
A mapping from device_id to a dict containing "device_id", "user_id" A mapping from device_id to a dict containing "device_id", "user_id"
and "display_name" for each device. and "display_name" for each device. Display name may be null.
""" """
devices = await self.db_pool.simple_select_list( devices = cast(
List[Tuple[str, str, Optional[str]]],
await self.db_pool.simple_select_list(
table="devices", table="devices",
keyvalues={"user_id": user_id, "hidden": False}, keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"), retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user", desc="get_devices_by_user",
),
) )
return {d["device_id"]: d for d in devices} return {
d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
for d in devices
}
async def get_devices_by_auth_provider_session_id( async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str self, auth_provider_id: str, auth_provider_session_id: str
) -> List[Dict[str, Any]]: ) -> List[Tuple[str, str]]:
"""Retrieve the list of devices associated with a SSO IdP session ID. """Retrieve the list of devices associated with a SSO IdP session ID.
Args: Args:
@ -313,7 +321,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
Returns: Returns:
A list of dicts containing the device_id and the user_id of each device A list of dicts containing the device_id and the user_id of each device
""" """
return await self.db_pool.simple_select_list( return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="device_auth_providers", table="device_auth_providers",
keyvalues={ keyvalues={
"auth_provider_id": auth_provider_id, "auth_provider_id": auth_provider_id,
@ -321,6 +331,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
}, },
retcols=("user_id", "device_id"), retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id", desc="get_devices_by_auth_provider_session_id",
),
) )
@trace @trace
@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_cached_devices_for_user( async def get_cached_devices_for_user(
self, user_id: str self, user_id: str
) -> Mapping[str, JsonMapping]: ) -> Mapping[str, JsonMapping]:
devices = await self.db_pool.simple_select_list( devices = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("device_id", "content"), retcols=("device_id", "content"),
desc="get_cached_devices_for_user", desc="get_cached_devices_for_user",
),
) )
return { return {device[0]: db_to_json(device[1]) for device in devices}
device["device_id"]: db_to_json(device["content"]) for device in devices
}
def get_cached_device_list_changes( def get_cached_device_list_changes(
self, self,
@ -1080,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync. The IDs of users whose device lists need resync.
""" """
if user_ids: if user_ids:
row_tuples = cast( rows = cast(
List[Tuple[str]], List[Tuple[str]],
await self.db_pool.simple_select_many_batch( await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync", table="device_lists_remote_resync",
@ -1090,11 +1102,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable", desc="get_user_ids_requiring_device_list_resync_with_iterable",
), ),
) )
return {row[0] for row in row_tuples}
else: else:
rows = cast( rows = cast(
List[Dict[str, str]], List[Tuple[str]],
await self.db_pool.simple_select_list( await self.db_pool.simple_select_list(
table="device_lists_remote_resync", table="device_lists_remote_resync",
keyvalues=None, keyvalues=None,
@ -1103,7 +1113,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
), ),
) )
return {row["user_id"] for row in rows} return {row[0] for row in rows}
async def mark_remote_users_device_caches_as_stale( async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection self, user_ids: StrCollection

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict from typing_extensions import Literal, TypedDict
@ -274,11 +274,12 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
if session_id: if session_id:
keyvalues["session_id"] = session_id keyvalues["session_id"] = session_id
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[str, str, int, int, int, str]],
await self.db_pool.simple_select_list(
table="e2e_room_keys", table="e2e_room_keys",
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=( retcols=(
"user_id",
"room_id", "room_id",
"session_id", "session_id",
"first_message_index", "first_message_index",
@ -287,19 +288,27 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
"session_data", "session_data",
), ),
desc="get_e2e_room_keys", desc="get_e2e_room_keys",
),
) )
sessions: Dict[ sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]] Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}} ] = {"rooms": {}}
for row in rows: for (
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) room_id,
room_entry["sessions"][row["session_id"]] = { session_id,
"first_message_index": row["first_message_index"], first_message_index,
"forwarded_count": row["forwarded_count"], forwarded_count,
is_verified,
session_data,
) in rows:
room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
room_entry["sessions"][session_id] = {
"first_message_index": first_message_index,
"forwarded_count": forwarded_count,
# is_verified must be returned to the client as a boolean # is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]), "is_verified": bool(is_verified),
"session_data": db_to_json(row["session_data"]), "session_data": db_to_json(session_data),
} }
return sessions return sessions

View file

@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# keeping only the forward extremities (i.e. the events not referenced # keeping only the forward extremities (i.e. the events not referenced
# by other events in the queue). We do this so that we can always # by other events in the queue). We do this so that we can always
# backpaginate in all the events we have dropped. # backpaginate in all the events we have dropped.
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="federation_inbound_events_staging", table="federation_inbound_events_staging",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcols=("event_id", "event_json"), retcols=("event_id", "event_json"),
desc="prune_staged_events_in_room_fetch", desc="prune_staged_events_in_room_fetch",
),
) )
# Find the set of events referenced by those in the queue, as well as # Find the set of events referenced by those in the queue, as well as
# collecting all the event IDs in the queue. # collecting all the event IDs in the queue.
referenced_events: Set[str] = set() referenced_events: Set[str] = set()
seen_events: Set[str] = set() seen_events: Set[str] = set()
for row in rows: for event_id, event_json in rows:
event_id = row["event_id"]
seen_events.add(event_id) seen_events.add(event_id)
event_d = db_to_json(row["event_json"]) event_d = db_to_json(event_json)
# We don't bother parsing the dicts into full blown event objects, # We don't bother parsing the dicts into full blown event objects,
# as that is needlessly expensive. # as that is needlessly expensive.

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict, FrozenSet from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main import CacheInvalidationWorkerStore
@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
Returns: Returns:
the features currently enabled for the user the features currently enabled for the user
""" """
enabled = await self.db_pool.simple_select_list( enabled = cast(
"per_user_experimental_features", List[Tuple[str]],
{"user_id": user_id, "enabled": True}, await self.db_pool.simple_select_list(
["feature"], table="per_user_experimental_features",
keyvalues={"user_id": user_id, "enabled": True},
retcols=("feature",),
),
) )
return frozenset(feature["feature"] for feature in enabled) return frozenset(feature[0] for feature in enabled)
async def set_features_for_user( async def set_features_for_user(
self, self,

View file

@ -248,7 +248,9 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent. If we have multiple entries for a given key ID, returns the most recent.
""" """
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
await self.db_pool.simple_select_list(
table="server_keys_json", table="server_keys_json",
keyvalues={"server_name": server_name}, keyvalues={"server_name": server_name},
retcols=( retcols=(
@ -259,6 +261,7 @@ class KeyStore(CacheInvalidationWorkerStore):
"key_json", "key_json",
), ),
desc="get_server_keys_json_for_remote", desc="get_server_keys_json_for_remote",
),
) )
if not rows: if not rows:
@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore):
# We sort the rows by ts_added_ms so that the most recently added entry # We sort the rows by ts_added_ms so that the most recently added entry
# will stomp over older entries in the dictionary. # will stomp over older entries in the dictionary.
rows.sort(key=lambda r: r["ts_added_ms"]) rows.sort(key=lambda r: r[2])
return { return {
row["key_id"]: FetchKeyResultForRemote( key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview. # Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]), key_json=bytes(key_json),
valid_until_ts=row["ts_valid_until_ms"], valid_until_ts=ts_valid_until_ms,
added_ts=row["ts_added_ms"], added_ts=ts_added_ms,
) )
for row in rows for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
} }

View file

@ -437,7 +437,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
) )
async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]: async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[int, int, str, str, int]],
await self.db_pool.simple_select_list(
"local_media_repository_thumbnails", "local_media_repository_thumbnails",
{"media_id": media_id}, {"media_id": media_id},
( (
@ -448,14 +450,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"thumbnail_length", "thumbnail_length",
), ),
desc="get_local_media_thumbnails", desc="get_local_media_thumbnails",
),
) )
return [ return [
ThumbnailInfo( ThumbnailInfo(
width=row["thumbnail_width"], width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
) )
for row in rows for row in rows
] ]
@ -568,7 +567,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_thumbnails( async def get_remote_media_thumbnails(
self, origin: str, media_id: str self, origin: str, media_id: str
) -> List[ThumbnailInfo]: ) -> List[ThumbnailInfo]:
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[int, int, str, str, int]],
await self.db_pool.simple_select_list(
"remote_media_cache_thumbnails", "remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id}, {"media_origin": origin, "media_id": media_id},
( (
@ -579,14 +580,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"thumbnail_length", "thumbnail_length",
), ),
desc="get_remote_media_thumbnails", desc="get_remote_media_thumbnails",
),
) )
return [ return [
ThumbnailInfo( ThumbnailInfo(
width=row["thumbnail_width"], width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
) )
for row in rows for row in rows
] ]

View file

@ -179,11 +179,12 @@ class PushRulesWorkerStore(
@cached(max_entries=5000) @cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[str, int, int, str, str]],
await self.db_pool.simple_select_list(
table="push_rules", table="push_rules",
keyvalues={"user_name": user_id}, keyvalues={"user_name": user_id},
retcols=( retcols=(
"user_name",
"rule_id", "rule_id",
"priority_class", "priority_class",
"priority", "priority",
@ -191,34 +192,31 @@ class PushRulesWorkerStore(
"actions", "actions",
), ),
desc="get_push_rules_for_user", desc="get_push_rules_for_user",
),
) )
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) # Sort by highest priority_class, then highest priority.
rows.sort(key=lambda row: (-int(row[1]), -int(row[2])))
enabled_map = await self.get_push_rules_enabled_for_user(user_id) enabled_map = await self.get_push_rules_enabled_for_user(user_id)
return _load_rules( return _load_rules(
[ [(row[0], row[1], row[3], row[4]) for row in rows],
(
row["rule_id"],
row["priority_class"],
row["conditions"],
row["actions"],
)
for row in rows
],
enabled_map, enabled_map,
self.hs.config.experimental, self.hs.config.experimental,
) )
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list( results = cast(
List[Tuple[str, Optional[Union[int, bool]]]],
await self.db_pool.simple_select_list(
table="push_rules_enable", table="push_rules_enable",
keyvalues={"user_name": user_id}, keyvalues={"user_name": user_id},
retcols=("rule_id", "enabled"), retcols=("rule_id", "enabled"),
desc="get_push_rules_enabled_for_user", desc="get_push_rules_enabled_for_user",
),
) )
return {r["rule_id"]: bool(r["enabled"]) for r in results} return {r[0]: bool(r[1]) for r in results}
async def have_push_rules_changed_for_user( async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int self, user_id: str, last_id: int

View file

@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore):
async def get_throttle_params_by_room( async def get_throttle_params_by_room(
self, pusher_id: int self, pusher_id: int
) -> Dict[str, ThrottleParams]: ) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list( res = cast(
List[Tuple[str, Optional[int], Optional[int]]],
await self.db_pool.simple_select_list(
"pusher_throttle", "pusher_throttle",
{"pusher": pusher_id}, {"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"], ["room_id", "last_sent_ts", "throttle_ms"],
desc="get_throttle_params_by_room", desc="get_throttle_params_by_room",
),
) )
params_by_room = {} params_by_room = {}
for row in res: for room_id, last_sent_ts, throttle_ms in res:
params_by_room[row["room_id"]] = ThrottleParams( params_by_room[room_id] = ThrottleParams(
row["last_sent_ts"], last_sent_ts or 0, throttle_ms or 0
row["throttle_ms"],
) )
return params_by_room return params_by_room

View file

@ -855,13 +855,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns: Returns:
Tuples of (auth_provider, external_id) Tuples of (auth_provider, external_id)
""" """
res = await self.db_pool.simple_select_list( return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="user_external_ids", table="user_external_ids",
keyvalues={"user_id": mxid}, keyvalues={"user_id": mxid},
retcols=("auth_provider", "external_id"), retcols=("auth_provider", "external_id"),
desc="get_external_ids_by_user", desc="get_external_ids_by_user",
),
) )
return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self) -> int: async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver.""" """Counts all users registered on the homeserver."""
@ -997,13 +999,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) )
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]: async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
results = await self.db_pool.simple_select_list( results = cast(
List[Tuple[str, str, int, int]],
await self.db_pool.simple_select_list(
"user_threepids", "user_threepids",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=["medium", "address", "validated_at", "added_at"], retcols=["medium", "address", "validated_at", "added_at"],
desc="user_get_threepids", desc="user_get_threepids",
),
) )
return [ThreepidResult(**r) for r in results] return [
ThreepidResult(
medium=r[0],
address=r[1],
validated_at=r[2],
added_at=r[3],
)
for r in results
]
async def user_delete_threepid( async def user_delete_threepid(
self, user_id: str, medium: str, address: str self, user_id: str, medium: str, address: str
@ -1042,7 +1055,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="add_user_bound_threepid", desc="add_user_bound_threepid",
) )
async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
"""Get the threepids that a user has bound to an identity server through the homeserver """Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids. method can retrieve those threepids.
@ -1051,15 +1064,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_id: The ID of the user to retrieve threepids for user_id: The ID of the user to retrieve threepids for
Returns: Returns:
List of dictionaries containing the following keys: List of tuples of two strings:
medium (str): The medium of the threepid (e.g "email") medium: The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com") address: The address of the threepid (e.g "bob@example.com")
""" """
return await self.db_pool.simple_select_list( return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="user_threepid_id_server", table="user_threepid_id_server",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=["medium", "address"], retcols=["medium", "address"],
desc="user_get_bound_threepids", desc="user_get_bound_threepids",
),
) )
async def remove_user_bound_threepid( async def remove_user_bound_threepid(

View file

@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_txn( def get_all_relation_ids_for_event_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[str]: ) -> List[str]:
rows = self.db_pool.simple_select_list_txn( rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_list_txn(
txn=txn, txn=txn,
table="event_relations", table="event_relations",
keyvalues={"relates_to_id": event_id}, keyvalues={"relates_to_id": event_id},
retcols=["event_id"], retcols=["event_id"],
),
) )
return [row["event_id"] for row in rows] return [row[0] for row in rows]
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event", desc="get_all_relation_ids_for_event",

View file

@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
""" """
room_servers: Dict[str, PartialStateResyncInfo] = {} room_servers: Dict[str, PartialStateResyncInfo] = {}
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="partial_state_rooms", table="partial_state_rooms",
keyvalues={}, keyvalues={},
retcols=("room_id", "joined_via"), retcols=("room_id", "joined_via"),
desc="get_server_which_served_partial_join", desc="get_server_which_served_partial_join",
),
) )
for row in rows: for room_id, joined_via in rows:
room_id = row["room_id"]
joined_via = row["joined_via"]
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via) room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
"partial_state_rooms_servers", "partial_state_rooms_servers",
keyvalues=None, keyvalues=None,
retcols=("room_id", "server_name"), retcols=("room_id", "server_name"),
desc="get_partial_state_rooms", desc="get_partial_state_rooms",
),
) )
for row in rows: for room_id, server_name in rows:
room_id = row["room_id"]
server_name = row["server_name"]
entry = room_servers.get(room_id) entry = room_servers.get(room_id)
if entry is None: if entry is None:
# There is a foreign key constraint which enforces that every room_id in # There is a foreign key constraint which enforces that every room_id in

View file

@ -1070,13 +1070,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
for fully-joined rooms. for fully-joined rooms.
""" """
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[str, Optional[str]]],
await self.db_pool.simple_select_list(
"current_state_events", "current_state_events",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcols=("event_id", "membership"), retcols=("event_id", "membership"),
desc="has_completed_background_updates", desc="has_completed_background_updates",
),
) )
return {row["event_id"]: row["membership"] for row in rows} return dict(rows)
# TODO This returns a mutable object, which is generally confusing when using a cache. # TODO This returns a mutable object, which is generally confusing when using a cache.
@cached(max_entries=10000) # type: ignore[synapse-@cached-mutable] @cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]

View file

@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag content. tag content.
""" """
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[str, str, str]],
await self.db_pool.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
),
) )
tags_by_room: Dict[str, Dict[str, JsonDict]] = {} tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows: for room_id, tag, content in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {}) room_tags = tags_by_room.setdefault(room_id, {})
room_tags[row["tag"]] = db_to_json(row["content"]) room_tags[tag] = db_to_json(content)
return tags_by_room return tags_by_room
async def get_all_updated_tags( async def get_all_updated_tags(
@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns: Returns:
A mapping of tags to tag content. A mapping of tags to tag content.
""" """
rows = await self.db_pool.simple_select_list( rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="room_tags", table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id}, keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"), retcols=("tag", "content"),
desc="get_tags_for_room", desc="get_tags_for_room",
),
) )
return {row["tag"]: db_to_json(row["content"]) for row in rows} return {tag: db_to_json(content) for tag, content in rows}
async def add_tag_to_room( async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict self, user_id: str, room_id: str, tag: str, content: JsonDict

View file

@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
that auth-type. that auth-type.
""" """
results = {} results = {}
for row in await self.db_pool.simple_select_list( rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="ui_auth_sessions_credentials", table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
retcols=("stage_type", "result"), retcols=("stage_type", "result"),
desc="get_completed_ui_auth_stages", desc="get_completed_ui_auth_stages",
): ),
results[row["stage_type"]] = db_to_json(row["result"]) )
for stage_type, result in rows:
results[stage_type] = db_to_json(result)
return results return results
@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
Returns: Returns:
List of user_agent/ip pairs List of user_agent/ip pairs
""" """
rows = await self.db_pool.simple_select_list( return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="ui_auth_sessions_ips", table="ui_auth_sessions_ips",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
retcols=("user_agent", "ip"), retcols=("user_agent", "ip"),
desc="get_user_agents_ips_to_ui_auth_session", desc="get_user_agents_ips_to_ui_auth_session",
),
) )
return [(row["user_agent"], row["ip"]) for row in rows]
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
""" """

View file

@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not prev_group: if not prev_group:
return _GetStateGroupDelta(None, None) return _GetStateGroupDelta(None, None)
delta_ids = self.db_pool.simple_select_list_txn( delta_ids = cast(
List[Tuple[str, str, str]],
self.db_pool.simple_select_list_txn(
txn, txn,
table="state_groups_state", table="state_groups_state",
keyvalues={"state_group": state_group}, keyvalues={"state_group": state_group},
retcols=("type", "state_key", "event_id"), retcols=("type", "state_key", "event_id"),
),
) )
return _GetStateGroupDelta( return _GetStateGroupDelta(
prev_group, prev_group,
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, {
(event_type, state_key): event_id
for event_type, state_key, event_id in delta_ids
},
) )
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
) )
) )
async def get_all_room_state(self) -> List[Dict[str, Any]]: async def get_all_room_state(self) -> List[Optional[str]]:
return await self.store.db_pool.simple_select_list( rows = cast(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias") List[Tuple[Optional[str]]],
await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("topic",)
),
) )
return [r[0] for r in rows]
def _get_current_stats( def _get_current_stats(
self, stats_type: str, stat_id: str self, stats_type: str, stat_id: str
@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r = self.get_success(self.get_all_room_state()) r = self.get_success(self.get_all_room_state())
self.assertEqual(len(r), 1) self.assertEqual(len(r), 1)
self.assertEqual(r[0]["topic"], "foo") self.assertEqual(r[0], "foo")
def test_create_user(self) -> None: def test_create_user(self) -> None:
""" """

View file

@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
if expected_row is not None: if expected_row is not None:
columns += expected_row.keys() columns += expected_row.keys()
rows = self.get_success( row_tuples = self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table=table, table=table,
keyvalues={ keyvalues={
@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
if expected_row is not None: if expected_row is not None:
self.assertEqual( self.assertEqual(
len(rows), len(row_tuples),
1, 1,
f"Background update did not leave behind latest receipt in {table}", f"Background update did not leave behind latest receipt in {table}",
) )
self.assertEqual( self.assertEqual(
rows[0], row_tuples[0],
{ (
"room_id": room_id, room_id,
"receipt_type": receipt_type, receipt_type,
"user_id": user_id, user_id,
**expected_row, *expected_row.values(),
}, ),
) )
else: else:
self.assertEqual( self.assertEqual(
len(rows), len(row_tuples),
0, 0,
f"Background update did not remove all duplicate receipts from {table}", f"Background update did not remove all duplicate receipts from {table}",
) )

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import secrets import secrets
from typing import Generator, Tuple from typing import Generator, List, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
) )
def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]: def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
res = self.get_success( yield from cast(
List[Tuple[int, str, str]],
self.get_success(
self.storage.db_pool.simple_select_list( self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"] self.table_name, None, ["id, username, value"]
) )
),
) )
for i in res:
yield (i["id"], i["username"], i["value"])
def test_upsert_many(self) -> None: def test_upsert_many(self) -> None:
""" """
Upsert_many will perform the upsert operation across a batch of data. Upsert_many will perform the upsert operation across a batch of data.

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Tuple, cast
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
import yaml import yaml
@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
self.wait_for_background_updates() self.wait_for_background_updates()
# Check the correct values are in the new table. # Check the correct values are in the new table.
rows = self.get_success( rows = cast(
List[Tuple[int, int]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="test_constraint", table="test_constraint",
keyvalues={}, keyvalues={},
retcols=("a", "b"), retcols=("a", "b"),
) )
),
) )
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) self.assertCountEqual(rows, [(1, 1), (3, 3)])
# And check that invalid rows get correctly rejected. # And check that invalid rows get correctly rejected.
self.get_failure( self.get_failure(
@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
self.wait_for_background_updates() self.wait_for_background_updates()
# Check the correct values are in the new table. # Check the correct values are in the new table.
rows = self.get_success( rows = cast(
List[Tuple[int, int]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="test_constraint", table="test_constraint",
keyvalues={}, keyvalues={},
retcols=("a", "b"), retcols=("a", "b"),
) )
),
) )
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) self.assertCountEqual(rows, [(1, 1), (3, 3)])
# And check that invalid rows get correctly rejected. # And check that invalid rows get correctly rejected.
self.get_failure( self.get_failure(

View file

@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 3 self.mock_txn.rowcount = 3
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)]
self.mock_txn.description = (("colA", None, None, None, None, None, None),) self.mock_txn.description = (("colA", None, None, None, None, None, None),)
ret = yield defer.ensureDeferred( ret = yield defer.ensureDeferred(
@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
) )
) )
self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.assertEqual([(1,), (2,), (3,)], ret)
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"SELECT colA FROM tablename WHERE keycol = ?", ["A set"] "SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
) )

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict from typing import Any, Dict, List, Optional, Tuple, cast
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from parameterized import parameterized from parameterized import parameterized
@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200) self.reactor.advance(200)
self.pump(0) self.pump(0)
result = self.get_success( result = cast(
List[Tuple[str, str, str, Optional[str], int]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="user_ips", table="user_ips",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], retcols=[
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
],
desc="get_user_ip_and_agents", desc="get_user_ip_and_agents",
) )
),
) )
self.assertEqual( self.assertEqual(
result, result, [("access_token", "ip", "user_agent", None, 12345678000)]
[
{
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": None,
"last_seen": 12345678000,
}
],
) )
# Add another & trigger the storage loop # Add another & trigger the storage loop
@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10) self.reactor.advance(10)
self.pump(0) self.pump(0)
result = self.get_success( result = cast(
List[Tuple[str, str, str, Optional[str], int]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="user_ips", table="user_ips",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], retcols=[
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
],
desc="get_user_ip_and_agents", desc="get_user_ip_and_agents",
) )
),
) )
# Only one result, has been upserted. # Only one result, has been upserted.
self.assertEqual( self.assertEqual(
result, result, [("access_token", "ip", "user_agent", None, 12345878000)]
[
{
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": None,
"last_seen": 12345878000,
}
],
) )
@parameterized.expand([(False,), (True,)]) @parameterized.expand([(False,), (True,)])
@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10) self.reactor.advance(10)
else: else:
# Check that the new IP and user agent has not been stored yet # Check that the new IP and user agent has not been stored yet
db_result = self.get_success( db_result = cast(
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="devices", table="devices",
keyvalues={}, keyvalues={},
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), retcols=(
"user_id",
"ip",
"user_agent",
"device_id",
"last_seen",
),
),
), ),
) )
self.assertEqual( self.assertEqual(db_result, [(user_id, None, None, device_id, None)])
db_result,
[
{
"user_id": user_id,
"device_id": device_id,
"ip": None,
"user_agent": None,
"last_seen": None,
},
],
)
result = self.get_success( result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id) self.store.get_last_client_ip_by_device(user_id, device_id)
@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
) )
# Check that the new IP and user agent has not been stored yet # Check that the new IP and user agent has not been stored yet
db_result = self.get_success( db_result = cast(
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="devices", table="devices",
keyvalues={}, keyvalues={},
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
), ),
),
) )
self.assertCountEqual( self.assertCountEqual(
db_result, db_result,
[ [
{ (user_id, "ip_1", "user_agent_1", device_id_1, 12345678000),
"user_id": user_id, (user_id, "ip_2", "user_agent_2", device_id_2, 12345678000),
"device_id": device_id_1,
"ip": "ip_1",
"user_agent": "user_agent_1",
"last_seen": 12345678000,
},
{
"user_id": user_id,
"device_id": device_id_2,
"ip": "ip_2",
"user_agent": "user_agent_2",
"last_seen": 12345678000,
},
], ],
) )
@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
) )
# Check that the new IP and user agent has not been stored yet # Check that the new IP and user agent has not been stored yet
db_result = self.get_success( db_result = cast(
List[Tuple[str, str, str, int]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="user_ips", table="user_ips",
keyvalues={}, keyvalues={},
retcols=("access_token", "ip", "user_agent", "last_seen"), retcols=("access_token", "ip", "user_agent", "last_seen"),
), ),
),
) )
self.assertEqual( self.assertEqual(
db_result, db_result,
[ [
{ ("access_token", "ip_1", "user_agent_1", 12345678000),
"access_token": "access_token", ("access_token", "ip_2", "user_agent_2", 12345678000),
"ip": "ip_1",
"user_agent": "user_agent_1",
"last_seen": 12345678000,
},
{
"access_token": "access_token",
"ip": "ip_2",
"user_agent": "user_agent_2",
"last_seen": 12345678000,
},
], ],
) )
@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200) self.reactor.advance(200)
# We should see that in the DB # We should see that in the DB
result = self.get_success( result = cast(
List[Tuple[str, str, str, Optional[str], int]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="user_ips", table="user_ips",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], retcols=[
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
],
desc="get_user_ip_and_agents", desc="get_user_ip_and_agents",
) )
),
) )
self.assertEqual( self.assertEqual(
result, result,
[ [("access_token", "ip", "user_agent", device_id, 0)],
{
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": device_id,
"last_seen": 0,
}
],
) )
# Now advance by a couple of months # Now advance by a couple of months
self.reactor.advance(60 * 24 * 60 * 60) self.reactor.advance(60 * 24 * 60 * 60)
# We should get no results. # We should get no results.
result = self.get_success( result = cast(
List[Tuple[str, str, str, Optional[str], int]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="user_ips", table="user_ips",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], retcols=[
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
],
desc="get_user_ip_and_agents", desc="get_user_ip_and_agents",
) )
),
) )
self.assertEqual(result, []) self.assertEqual(result, [])
@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200) self.reactor.advance(200)
# We should see that in the DB # We should see that in the DB
result = self.get_success( result = cast(
List[Tuple[str, str, str, Optional[str], int]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="user_ips", table="user_ips",
keyvalues={}, keyvalues={},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], retcols=[
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
],
desc="get_user_ip_and_agents", desc="get_user_ip_and_agents",
) )
),
) )
# ensure user1 is filtered out # ensure user1 is filtered out
self.assertEqual( self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)])
result,
[
{
"access_token": access_token2,
"ip": "ip",
"user_agent": "user_agent",
"device_id": device_id2,
"last_seen": 0,
}
],
)
class ClientIpAuthTestCase(unittest.HomeserverTestCase): class ClientIpAuthTestCase(unittest.HomeserverTestCase):

View file

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import Membership from synapse.api.constants import Membership
@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
def test__null_byte_in_display_name_properly_handled(self) -> None: def test__null_byte_in_display_name_properly_handled(self) -> None:
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
res = self.get_success( res = cast(
List[Tuple[Optional[str], str]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
"room_memberships", "room_memberships",
{"user_id": "@alice:test"}, {"user_id": "@alice:test"},
["display_name", "event_id"], ["display_name", "event_id"],
) )
),
) )
# Check that we only got one result back # Check that we only got one result back
self.assertEqual(len(res), 1) self.assertEqual(len(res), 1)
# Check that alice's display name is "alice" # Check that alice's display name is "alice"
self.assertEqual(res[0]["display_name"], "alice") self.assertEqual(res[0][0], "alice")
# Grab the event_id to use later # Grab the event_id to use later
event_id = res[0]["event_id"] event_id = res[0][1]
# Create a profile with the offending null byte in the display name # Create a profile with the offending null byte in the display name
new_profile = {"displayname": "ali\u0000ce"} new_profile = {"displayname": "ali\u0000ce"}
@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
tok=self.t_alice, tok=self.t_alice,
) )
res2 = self.get_success( res2 = cast(
List[Tuple[Optional[str], str]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
"room_memberships", "room_memberships",
{"user_id": "@alice:test"}, {"user_id": "@alice:test"},
["display_name", "event_id"], ["display_name", "event_id"],
) )
),
) )
# Check that we only have two results # Check that we only have two results
self.assertEqual(len(res2), 2) self.assertEqual(len(res2), 2)
# Filter out the previous event using the event_id we grabbed above # Filter out the previous event using the event_id we grabbed above
row = [row for row in res2 if row["event_id"] != event_id] row = [row for row in res2 if row[1] != event_id]
# Check that alice's display name is now None # Check that alice's display name is now None
self.assertEqual(row[0]["display_name"], None) self.assertIsNone(row[0][0])
def test_room_is_locally_forgotten(self) -> None: def test_room_is_locally_forgotten(self) -> None:
"""Test that when the last local user has forgotten a room it is known as forgotten.""" """Test that when the last local user has forgotten a room it is known as forgotten."""

View file

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Tuple, cast
from immutabledict import immutabledict from immutabledict import immutabledict
@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase):
) )
# check that only state events are in state_groups, and all state events are in state_groups # check that only state events are in state_groups, and all state events are in state_groups
res = self.get_success( res = cast(
List[Tuple[str]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="state_groups", table="state_groups",
keyvalues=None, keyvalues=None,
retcols=("event_id",), retcols=("event_id",),
) )
),
) )
events = [] events = []
for result in res: for result in res:
self.assertNotIn(event3.event_id, result) self.assertNotIn(event3.event_id, result) # XXX
events.append(result.get("event_id")) events.append(result[0])
for event, _ in processed_events_and_context: for event, _ in processed_events_and_context:
if event.is_state(): if event.is_state():
@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase):
# has an entry and prev event in state_group_edges # has an entry and prev event in state_group_edges
for event, context in processed_events_and_context: for event, context in processed_events_and_context:
if event.is_state(): if event.is_state():
state = self.get_success( state = cast(
List[Tuple[str, str]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="state_groups_state", table="state_groups_state",
keyvalues={"state_group": context.state_group_after_event}, keyvalues={"state_group": context.state_group_after_event},
retcols=("type", "state_key"), retcols=("type", "state_key"),
) )
),
) )
self.assertEqual(event.type, state[0].get("type")) self.assertEqual(event.type, state[0][0])
self.assertEqual(event.state_key, state[0].get("state_key")) self.assertEqual(event.state_key, state[0][1])
groups = self.get_success( groups = cast(
List[Tuple[str]],
self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
table="state_group_edges", table="state_group_edges",
keyvalues={"state_group": str(context.state_group_after_event)}, keyvalues={
retcols=("*",), "state_group": str(context.state_group_after_event)
},
retcols=("prev_state_group",),
) )
),
) )
self.assertEqual( self.assertEqual(context.state_group_before_event, groups[0][0])
context.state_group_before_event, groups[0].get("prev_state_group")
)

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re import re
from typing import Any, Dict, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple, cast
from unittest import mock from unittest import mock
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@ -62,14 +62,13 @@ class GetUserDirectoryTables:
Returns a list of tuples (user_id, room_id) where room_id is public and Returns a list of tuples (user_id, room_id) where room_id is public and
contains the user with the given id. contains the user with the given id.
""" """
r = await self.store.db_pool.simple_select_list( r = cast(
List[Tuple[str, str]],
await self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id") "users_in_public_rooms", None, ("user_id", "room_id")
),
) )
return set(r)
retval = set()
for i in r:
retval.add((i["user_id"], i["room_id"]))
return retval
async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]: async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
"""Fetch the entire `users_who_share_private_rooms` table. """Fetch the entire `users_who_share_private_rooms` table.
@ -78,27 +77,30 @@ class GetUserDirectoryTables:
to the rows of `users_who_share_private_rooms`. to the rows of `users_who_share_private_rooms`.
""" """
rows = await self.store.db_pool.simple_select_list( rows = cast(
List[Tuple[str, str, str]],
await self.store.db_pool.simple_select_list(
"users_who_share_private_rooms", "users_who_share_private_rooms",
None, None,
["user_id", "other_user_id", "room_id"], ["user_id", "other_user_id", "room_id"],
),
) )
rv = set() return set(rows)
for row in rows:
rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
return rv
async def get_users_in_user_directory(self) -> Set[str]: async def get_users_in_user_directory(self) -> Set[str]:
"""Fetch the set of users in the `user_directory` table. """Fetch the set of users in the `user_directory` table.
This is useful when checking we've correctly excluded users from the directory. This is useful when checking we've correctly excluded users from the directory.
""" """
result = await self.store.db_pool.simple_select_list( result = cast(
List[Tuple[str]],
await self.store.db_pool.simple_select_list(
"user_directory", "user_directory",
None, None,
["user_id"], ["user_id"],
),
) )
return {row["user_id"] for row in result} return {row[0] for row in result}
async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]: async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
"""Fetch users and their profiles from the `user_directory` table. """Fetch users and their profiles from the `user_directory` table.
@ -107,16 +109,17 @@ class GetUserDirectoryTables:
It's almost the entire contents of the `user_directory` table: the only It's almost the entire contents of the `user_directory` table: the only
thing missing is an unused room_id column. thing missing is an unused room_id column.
""" """
rows = await self.store.db_pool.simple_select_list( rows = cast(
List[Tuple[str, Optional[str], Optional[str]]],
await self.store.db_pool.simple_select_list(
"user_directory", "user_directory",
None, None,
("user_id", "display_name", "avatar_url"), ("user_id", "display_name", "avatar_url"),
),
) )
return { return {
row["user_id"]: ProfileInfo( user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url)
display_name=row["display_name"], avatar_url=row["avatar_url"] for user_id, display_name, avatar_url in rows
)
for row in rows
} }
async def get_tables( async def get_tables(