Remove more usages of cursor_to_dict. (#16551)

Mostly to improve type safety.
This commit is contained in:
Patrick Cloke 2023-10-26 15:12:28 -04:00 committed by GitHub
parent 85e5f2dc25
commit 679c691f6f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 193 additions and 134 deletions

1
changelog.d/16551.misc Normal file
View file

@ -0,0 +1 @@
Improve type hints.

View file

@ -19,6 +19,8 @@ import logging
import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
import attr
from synapse.api.errors import (
CodeMessageException,
Codes,
@ -357,9 +359,9 @@ class IdentityHandler:
# Check to see if a session already exists and that it is not yet
# marked as validated
if session and session.get("validated_at") is None:
session_id = session["session_id"]
last_send_attempt = session["last_send_attempt"]
if session and session.validated_at is None:
session_id = session.session_id
last_send_attempt = session.last_send_attempt
# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
@ -480,7 +482,6 @@ class IdentityHandler:
# We don't actually know which medium this 3PID is. Thus we first assume it's email,
# and if validation fails we try msisdn
validation_session = None
# Try to validate as email
if self.hs.config.email.can_verify_email:
@ -488,19 +489,18 @@ class IdentityHandler:
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)
if validation_session:
return validation_session
if validation_session:
return attr.asdict(validation_session)
# Try to validate as msisdn
if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = await self.threepid_from_creds(
return await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_msisdn,
threepid_creds,
)
return validation_session
return None
async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str

View file

@ -187,9 +187,9 @@ class _BaseThreepidAuthChecker:
if row:
threepid = {
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
"medium": row.medium,
"address": row.address,
"validated_at": row.validated_at,
}
# Valid threepid returned, delete from the db

View file

@ -949,10 +949,7 @@ class MediaRepository:
deleted = 0
for media in old_media:
origin = media["media_origin"]
media_id = media["media_id"]
file_id = media["filesystem_id"]
for origin, media_id, file_id in old_media:
key = (origin, media_id)
logger.info("Deleting: %r", key)

View file

@ -85,7 +85,19 @@ class ListDestinationsRestServlet(RestServlet):
destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
)
response = {"destinations": destinations, "total": total}
response = {
"destinations": [
{
"destination": r[0],
"retry_last_ts": r[1],
"retry_interval": r[2],
"failure_ts": r[3],
"last_successful_stream_ordering": r[4],
}
for r in destinations
],
"total": total,
}
if (start + limit) < total:
response["next_token"] = str(start + len(destinations))

View file

@ -724,7 +724,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
room_id, _ = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id)
return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
result = [
{
"event_id": ex[0],
"state_group": ex[1],
"depth": ex[2],
"received_ts": ex[3],
}
for ex in extremities
]
return HTTPStatus.OK, {"count": len(extremities), "results": result}
class RoomEventContextServlet(RestServlet):

View file

@ -108,7 +108,18 @@ class UserMediaStatisticsRestServlet(RestServlet):
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term
)
ret = {"users": users_media, "total": total}
ret = {
"users": [
{
"user_id": r[0],
"displayname": r[1],
"media_count": r[2],
"media_length": r[3],
}
for r in users_media
],
"total": total,
}
if (start + limit) < total:
ret["next_token"] = start + len(users_media)

View file

@ -35,7 +35,6 @@ from typing import (
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
@ -1047,43 +1046,20 @@ class DatabasePool:
results = [dict(zip(col_headers, row)) for row in cursor]
return results
@overload
async def execute(
self, desc: str, decoder: Literal[None], query: str, *args: Any
) -> List[Tuple[Any, ...]]:
...
@overload
async def execute(
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
) -> R:
...
async def execute(
self,
desc: str,
decoder: Optional[Callable[[Cursor], R]],
query: str,
*args: Any,
) -> Union[List[Tuple[Any, ...]], R]:
async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
"""Runs a single query for a result set.
Args:
desc: description of the transaction, for logging and metrics
decoder - The function which can resolve the cursor results to
something meaningful.
query - The query string to execute
*args - Query args.
Returns:
The result of decoder(results)
"""
def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
txn.execute(query, args)
if decoder:
return decoder(txn)
else:
return txn.fetchall()
return txn.fetchall()
return await self.runInteraction(desc, interaction)

View file

@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
"""
rows = await self.db_pool.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100
"_censor_redactions_fetch", sql, before_ts, 100
)
updates = []

View file

@ -894,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
rows = await self.db_pool.execute(
"get_all_devices_changed",
None,
sql,
from_key,
to_key,
@ -978,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
WHERE from_user_id = ? AND stream_id > ?
"""
rows = await self.db_pool.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
"get_users_whose_signatures_changed", sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
else:

View file

@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
None,
sql,
now_stream_id,
user_id,

View file

@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
# indexes on it.
# We need to pass execute a dummy function to handle the txn's result otherwise
# it tries to call fetchall() on it and fails because there's no result to fetch.
await self.db_pool.execute(
await self.db_pool.runInteraction(
"background_analyze_new_stream_ordering_column",
lambda txn: None,
"ANALYZE events(stream_ordering2)",
lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
)
await self.db_pool.runInteraction(

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import Any, Dict, List
from typing import List, Optional, Tuple, cast
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
async def get_forward_extremities_for_room(
self, room_id: str
) -> List[Dict[str, Any]]:
"""Get list of forward extremities for a room."""
) -> List[Tuple[str, int, int, Optional[int]]]:
"""
Get list of forward extremities for a room.
Returns:
A list of tuples of event_id, state_group, depth, and received_ts.
"""
def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
) -> List[Tuple[str, int, int, Optional[int]]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
"""
txn.execute(sql, (room_id,))
return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",

View file

@ -650,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
) -> List[Dict[str, str]]:
) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.
@ -664,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
* The filesystem ID.
"""
sql = """
SELECT media_origin, media_id, filesystem_id
FROM remote_media_cache
WHERE last_access_ts < ?
"""
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
if include_quarantined_media is False:
# Only include media that has not been quarantined
@ -677,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
AND quarantined_by IS NULL
"""
return await self.db_pool.execute(
"get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
return cast(
List[Tuple[str, str, str]],
await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:

View file

@ -151,6 +151,22 @@ class ThreepidResult:
added_at: int
@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThreepidValidationSession:
address: str
"""address of the 3pid"""
medium: str
"""medium of the 3pid"""
client_secret: str
"""a secret provided by the client for this validation session"""
session_id: str
"""ID of the validation session"""
last_send_attempt: int
"""a number serving to dedupe send attempts for this session"""
validated_at: Optional[int]
"""timestamp of when this session was validated if so"""
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@ -1172,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
) -> Optional[Dict[str, Any]]:
) -> Optional[ThreepidValidationSession]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
@ -1187,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
perform no filtering
Returns:
A dict containing the following:
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
* session_id - ID of the validation session
* send_attempt - a number serving to dedupe send attempts for this session
* validated_at - timestamp of when this session was validated if so
Otherwise None if a validation session is not found
A ThreepidValidationSession or None if a validation session is not found
"""
if not client_secret:
raise SynapseError(
@ -1214,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def get_threepid_validation_session_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
) -> Optional[ThreepidValidationSession]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
@ -1229,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
rows = self.db_pool.cursor_to_dict(txn)
if not rows:
row = txn.fetchone()
if not row:
return None
return rows[0]
return ThreepidValidationSession(
address=row[0],
session_id=row[1],
medium=row[2],
client_secret=row[3],
last_send_attempt=row[4],
validated_at=row[5],
)
return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn

View file

@ -940,7 +940,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
like_clause = "%:" + host
rows = await self.db_pool.execute(
"is_host_joined", None, sql, membership, room_id, like_clause
"is_host_joined", sql, membership, room_id, like_clause
)
if not rows:
@ -1168,7 +1168,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
AND forgotten = 0;
"""
rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
# `count(*)` returns always an integer
# If any rows still exist it means someone has not forgotten this room yet

View file

@ -26,6 +26,7 @@ from typing import (
Set,
Tuple,
Union,
cast,
)
import attr
@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
results = await self.db_pool.execute(
"search_msgs", self.db_pool.cursor_to_dict, sql, *args
# List of tuples of (rank, room_id, event_id).
results = cast(
List[Tuple[Union[int, float], str, str]],
await self.db_pool.execute("search_msgs", sql, *args),
)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
[r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
# List of tuples of (room_id, count).
count_results = cast(
List[Tuple[str, int]],
await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
{"event": event_map[r["event_id"]], "rank": r["rank"]}
{"event": event_map[r[2]], "rank": r[0]}
for r in results
if r["event_id"] in event_map
if r[2] in event_map
],
"highlights": highlights,
"count": count,
@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = search_term
sql = """
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
origin_server_ts, stream_ordering, room_id, event_id
room_id, event_id, origin_server_ts, stream_ordering
FROM event_search
WHERE vector @@ websearch_to_tsquery('english', ?) AND
"""
@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# mypy expects to append only a `str`, not an `int`
args.append(limit)
results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
# List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
results = cast(
List[Tuple[Union[int, float], str, str, int, int]],
await self.db_pool.execute("search_rooms", sql, *args),
)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
[r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
# List of tuples of (room_id, count).
count_results = cast(
List[Tuple[str, int]],
await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
{
"event": event_map[r["event_id"]],
"rank": r["rank"],
"pagination_token": "%s,%s"
% (r["origin_server_ts"], r["stream_ordering"]),
"event": event_map[r[2]],
"rank": r[0],
"pagination_token": "%s,%s" % (r[3], r[4]),
}
for r in results
if r["event_id"] in event_map
if r[2] in event_map
],
"highlights": highlights,
"count": count,

View file

@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore):
order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
(size and number). This will return a json list of users and the
total number of users matching the filter criteria.
@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
search_term: a string to filter user names by
Returns:
A list of user dicts and an integer representing the total number of
users that exist given this query
A tuple of:
A list of tuples of user information (the user ID, displayname,
total number of media, total length of media) and
An integer representing the total number of users that exist
given this query
"""
def get_users_media_usage_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
filters = []
args: list = []
@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore):
args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
return users, count

View file

@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
row = await self.db_pool.execute(
"get_current_topological_token", None, sql, room_id, room_id, stream_key
"get_current_topological_token", sql, room_id, room_id, stream_key
)
return row[0][0] if row else 0
@ -1636,7 +1636,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = await self.db_pool.execute(
"get_timeline_gaps",
None,
sql,
room_id,
from_token.stream if from_token else 0,

View file

@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value,
direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]:
) -> Tuple[
List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
int,
]:
"""Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the
total number of destinations matching the filter criteria.
@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
Returns:
A tuple of a list of mappings from destination to information
A tuple of a list of tuples of destination information:
* destination
* retry_last_ts
* retry_interval
* failure_ts
* last_successful_stream_ordering
and a count of total destinations.
"""
def get_destinations_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
) -> Tuple[
List[
Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
],
int,
]:
order_by_column = DestinationSortOrder(order_by).value
if direction == Direction.BACKWARDS:
@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
LIMIT ? OFFSET ?
"""
txn.execute(sql, args + [limit, start])
destinations = self.db_pool.cursor_to_dict(txn)
destinations = cast(
List[
Tuple[
str, Optional[int], Optional[int], Optional[int], Optional[int]
]
],
txn.fetchall(),
)
return destinations, count
return await self.db_pool.runInteraction(

View file

@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
raise Exception("Unrecognized database engine")
results = cast(
List[UserProfile],
await self.db_pool.execute(
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args
),
List[Tuple[str, Optional[str], Optional[str]]],
await self.db_pool.execute("search_user_dir", sql, *args),
)
limited = len(results) > limit
return {"limited": limited, "results": results[0:limit]}
return {
"limited": limited,
"results": [
{"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
for r in results[0:limit]
],
}
def _filter_text_for_index(text: str) -> str:

View file

@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
if max_group is None:
rows = await self.db_pool.execute(
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]

View file

@ -100,7 +100,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
event_id, stream_ordering = self.get_success(
self.hs.get_datastores().main.db_pool.execute(
"test:get_destination_rooms",
None,
"""
SELECT event_id, stream_ordering
FROM destination_rooms dr

View file

@ -457,8 +457,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
);
"""
self.get_success(
self.store.db_pool.execute(
"test_not_null_constraint", lambda _: None, table_sql
self.store.db_pool.runInteraction(
"test_not_null_constraint", lambda txn: txn.execute(table_sql)
)
)
@ -466,8 +466,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
# using SQLite.
index_sql = "CREATE INDEX test_index ON test_constraint(a)"
self.get_success(
self.store.db_pool.execute(
"test_not_null_constraint", lambda _: None, index_sql
self.store.db_pool.runInteraction(
"test_not_null_constraint", lambda txn: txn.execute(index_sql)
)
)
@ -574,13 +574,13 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
);
"""
self.get_success(
self.store.db_pool.execute(
"test_foreign_key_constraint", lambda _: None, base_sql
self.store.db_pool.runInteraction(
"test_foreign_key_constraint", lambda txn: txn.execute(base_sql)
)
)
self.get_success(
self.store.db_pool.execute(
"test_foreign_key_constraint", lambda _: None, table_sql
self.store.db_pool.runInteraction(
"test_foreign_key_constraint", lambda txn: txn.execute(table_sql)
)
)

View file

@ -120,7 +120,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.store.db_pool.execute(
"", None, "SELECT full_user_id from profiles ORDER BY full_user_id"
"", "SELECT full_user_id from profiles ORDER BY full_user_id"
)
)
self.assertEqual(len(res), len(expected_values))

View file

@ -87,7 +87,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.store.db_pool.execute(
"", None, "SELECT full_user_id from user_filters ORDER BY full_user_id"
"", "SELECT full_user_id from user_filters ORDER BY full_user_id"
)
)
self.assertEqual(len(res), len(expected_values))