Add some type hints to datastore (#12423)

* Add some type hints to datastore

* newsfile

* change `Collection` to `List`

* refactor return type of `select_users_txn`

* correct type hint in `stream.py`

* Remove `Optional` in `select_users_txn`

* remove not needed return type in `__init__`

* Revert change in `get_stream_id_for_event_txn`

* Remove import from `Literal`
This commit is contained in:
Dirk Klimpel 2022-04-12 12:54:00 +02:00 committed by GitHub
parent 4e13743738
commit 1783156dbc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 123 additions and 77 deletions

1
changelog.d/12423.misc Normal file
View file

@ -0,0 +1 @@
Add some type hints to datastore.

View file

@ -180,9 +180,9 @@ class AccountValidityHandler:
expiring_users = await self.store.get_users_expiring_soon() expiring_users = await self.store.get_users_expiring_soon()
if expiring_users: if expiring_users:
for user in expiring_users: for user_id, expiration_ts_ms in expiring_users:
await self._send_renewal_email( await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] user_id=user_id, expiration_ts=expiration_ts_ms
) )
async def send_renewal_email_to_user(self, user_id: str) -> None: async def send_renewal_email_to_user(self, user_id: str) -> None:

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
from synapse.appservice import ( from synapse.appservice import (
ApplicationService, ApplicationService,
@ -26,7 +26,11 @@ from synapse.appservice import (
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import db_to_json from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
@ -92,7 +96,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
def get_app_services(self): def get_app_services(self) -> List[ApplicationService]:
return self.services_cache return self.services_cache
def get_if_app_services_interested_in_user(self, user_id: str) -> bool: def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
@ -256,7 +260,7 @@ class ApplicationServiceTransactionWorkerStore(
A new transaction. A new transaction.
""" """
def _create_appservice_txn(txn): def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn) new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)
# Insert new txn into txn table # Insert new txn into txn table
@ -291,7 +295,7 @@ class ApplicationServiceTransactionWorkerStore(
service: The application service which was sent this transaction. service: The application service which was sent this transaction.
""" """
def _complete_appservice_txn(txn): def _complete_appservice_txn(txn: LoggingTransaction) -> None:
# Set current txn_id for AS to 'txn_id' # Set current txn_id for AS to 'txn_id'
self.db_pool.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
@ -322,7 +326,9 @@ class ApplicationServiceTransactionWorkerStore(
An AppServiceTransaction or None. An AppServiceTransaction or None.
""" """
def _get_oldest_unsent_txn(txn): def _get_oldest_unsent_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
# Monotonically increasing txn ids, so just select the smallest # Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent) # one in the txns table (we delete them when they are sent)
txn.execute( txn.execute(
@ -364,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore(
) )
async def set_appservice_last_pos(self, pos: int) -> None: async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn): def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
txn.execute( txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
) )
@ -378,7 +384,9 @@ class ApplicationServiceTransactionWorkerStore(
) -> Tuple[int, List[EventBase]]: ) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice""" """Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn): def get_new_events_for_appservice_txn(
txn: LoggingTransaction,
) -> Tuple[int, List[str]]:
sql = ( sql = (
"SELECT e.stream_ordering, e.event_id" "SELECT e.stream_ordering, e.event_id"
" FROM events AS e" " FROM events AS e"
@ -416,7 +424,7 @@ class ApplicationServiceTransactionWorkerStore(
% (type,) % (type,)
) )
def get_type_stream_id_for_appservice_txn(txn): def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
stream_id_type = "%s_stream_id" % type stream_id_type = "%s_stream_id" % type
txn.execute( txn.execute(
# We do NOT want to escape `stream_id_type`. # We do NOT want to escape `stream_id_type`.
@ -444,7 +452,7 @@ class ApplicationServiceTransactionWorkerStore(
% (stream_type,) % (stream_type,)
) )
def set_appservice_stream_type_pos_txn(txn): def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
stream_id_type = "%s_stream_id" % stream_type stream_id_type = "%s_stream_id" % stream_type
txn.execute( txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?" "UPDATE application_services_state SET %s = ? WHERE as_id=?"

View file

@ -34,7 +34,7 @@ from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID, UserInfo from synapse.types import JsonDict, UserID, UserInfo
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING: if TYPE_CHECKING:
@ -79,7 +79,7 @@ class TokenLookupResult:
# Make the token owner default to the user ID, which is the common case. # Make the token owner default to the user ID, which is the common case.
@token_owner.default @token_owner.default
def _default_token_owner(self): def _default_token_owner(self) -> str:
return self.user_id return self.user_id
@ -299,7 +299,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
the account. the account.
""" """
def set_account_validity_for_user_txn(txn): def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_txn( self.db_pool.simple_update_txn(
txn=txn, txn=txn,
table="account_validity", table="account_validity",
@ -385,23 +385,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_renewal_token_for_user", desc="get_renewal_token_for_user",
) )
async def get_users_expiring_soon(self) -> List[Dict[str, Any]]: async def get_users_expiring_soon(self) -> List[Tuple[str, int]]:
"""Selects users whose account will expire in the [now, now + renew_at] time """Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at window (see configuration for account_validity for information on what renew_at
refers to). refers to).
Returns: Returns:
A list of dictionaries, each with a user ID and expiration time (in milliseconds). A list of tuples, each with a user ID and expiration time (in milliseconds).
""" """
def select_users_txn(txn, now_ms, renew_at): def select_users_txn(
txn: LoggingTransaction, now_ms: int, renew_at: int
) -> List[Tuple[str, int]]:
sql = ( sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity" "SELECT user_id, expiration_ts_ms FROM account_validity"
" WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
) )
values = [False, now_ms, renew_at] values = [False, now_ms, renew_at]
txn.execute(sql, values) txn.execute(sql, values)
return self.db_pool.cursor_to_dict(txn) return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_users_expiring_soon", "get_users_expiring_soon",
@ -466,7 +468,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
admin: true iff the user is to be a server admin, false otherwise. admin: true iff the user is to be a server admin, false otherwise.
""" """
def set_server_admin_txn(txn): def set_server_admin_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
) )
@ -515,7 +517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_type: type of the user or None for a user without a type. user_type: type of the user or None for a user without a type.
""" """
def set_user_type_txn(txn): def set_user_type_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"user_type": user_type} txn, "users", {"name": user.to_string()}, {"user_type": user_type}
) )
@ -525,7 +527,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_user_type", set_user_type_txn) await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: def _query_for_auth(
self, txn: LoggingTransaction, token: str
) -> Optional[TokenLookupResult]:
sql = """ sql = """
SELECT users.name as user_id, SELECT users.name as user_id,
users.is_guest, users.is_guest,
@ -582,7 +586,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"is_support_user", self.is_support_user_txn, user_id "is_support_user", self.is_support_user_txn, user_id
) )
def is_real_user_txn(self, txn, user_id): def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn( res = self.db_pool.simple_select_one_onecol_txn(
txn=txn, txn=txn,
table="users", table="users",
@ -592,7 +596,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) )
return res is None return res is None
def is_support_user_txn(self, txn, user_id): def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn( res = self.db_pool.simple_select_one_onecol_txn(
txn=txn, txn=txn,
table="users", table="users",
@ -609,10 +613,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A mapping of user_id -> password_hash. A mapping of user_id -> password_hash.
""" """
def f(txn): def f(txn: LoggingTransaction) -> Dict[str, str]:
sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)" sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return dict(txn) result = cast(List[Tuple[str, str]], txn.fetchall())
return dict(result)
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
@ -734,7 +739,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def _replace_user_external_id_txn( def _replace_user_external_id_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
): ) -> None:
_remove_user_external_ids_txn(txn, user_id) _remove_user_external_ids_txn(txn, user_id)
for auth_provider, external_id in record_external_ids: for auth_provider, external_id in record_external_ids:
@ -790,10 +795,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) )
return [(r["auth_provider"], r["external_id"]) for r in res] return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self): async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver.""" """Counts all users registered on the homeserver."""
def _count_users(txn): def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users") txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.db_pool.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if rows: if rows:
@ -810,7 +815,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
who registered on the homeserver in the past 24 hours who registered on the homeserver in the past 24 hours
""" """
def _count_daily_user_type(txn): def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]:
yesterday = int(self._clock.time()) - (60 * 60 * 24) yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """ sql = """
@ -835,23 +840,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"count_daily_user_type", _count_daily_user_type "count_daily_user_type", _count_daily_user_type
) )
async def count_nonbridged_users(self): async def count_nonbridged_users(self) -> int:
def _count_users(txn): def _count_users(txn: LoggingTransaction) -> int:
txn.execute( txn.execute(
""" """
SELECT COUNT(*) FROM users SELECT COUNT(*) FROM users
WHERE appservice_id IS NULL WHERE appservice_id IS NULL
""" """
) )
(count,) = txn.fetchone() (count,) = cast(Tuple[int], txn.fetchone())
return count return count
return await self.db_pool.runInteraction("count_users", _count_users) return await self.db_pool.runInteraction("count_users", _count_users)
async def count_real_users(self): async def count_real_users(self) -> int:
"""Counts all users without a special user_type registered on the homeserver.""" """Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn): def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
rows = self.db_pool.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if rows: if rows:
@ -888,7 +893,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return user_id return user_id
def get_user_id_by_threepid_txn( def get_user_id_by_threepid_txn(
self, txn, medium: str, address: str self, txn: LoggingTransaction, medium: str, address: str
) -> Optional[str]: ) -> Optional[str]:
"""Returns user id from threepid """Returns user id from threepid
@ -925,7 +930,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
) )
async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]: async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list( return await self.db_pool.simple_select_list(
"user_threepids", "user_threepids",
{"user_id": user_id}, {"user_id": user_id},
@ -957,7 +962,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def add_user_bound_threepid( async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str self, user_id: str, medium: str, address: str, id_server: str
): ) -> None:
"""The server proxied a bind request to the given identity server on """The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid. asks us to unbind the threepid.
@ -1116,7 +1121,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
assert address or sid assert address or sid
def get_threepid_validation_session_txn(txn): def get_threepid_validation_session_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
sql = """ sql = """
SELECT address, session_id, medium, client_secret, SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at last_send_attempt, validated_at
@ -1150,7 +1157,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
session_id: The ID of the session to delete session_id: The ID of the session to delete
""" """
def delete_threepid_session_txn(txn): def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="threepid_validation_token", table="threepid_validation_token",
@ -1170,7 +1177,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def cull_expired_threepid_validation_tokens(self) -> None: async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed""" """Remove threepid validation tokens with expiry dates that have passed"""
def cull_expired_threepid_validation_tokens_txn(txn, ts): def cull_expired_threepid_validation_tokens_txn(
txn: LoggingTransaction, ts: int
) -> None:
sql = """ sql = """
DELETE FROM threepid_validation_token WHERE DELETE FROM threepid_validation_token WHERE
expires < ? expires < ?
@ -1184,13 +1193,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) )
@wrap_as_background_process("account_validity_set_expiration_dates") @wrap_as_background_process("account_validity_set_expiration_dates")
async def _set_expiration_date_when_missing(self): async def _set_expiration_date_when_missing(self) -> None:
""" """
Retrieves the list of registered users that don't have an expiration date, and Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them. adds an expiration date for each of them.
""" """
def select_users_with_no_expiration_date_txn(txn): def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None:
"""Retrieves the list of registered users with no expiration date from the """Retrieves the list of registered users with no expiration date from the
database, filtering out deactivated users. database, filtering out deactivated users.
""" """
@ -1213,7 +1222,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
select_users_with_no_expiration_date_txn, select_users_with_no_expiration_date_txn,
) )
def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): def set_expiration_date_for_user_txn(
self, txn: LoggingTransaction, user_id: str, use_delta: bool = False
) -> None:
"""Sets an expiration date to the account with the given user ID. """Sets an expiration date to the account with the given user ID.
Args: Args:
@ -1344,7 +1355,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
token: The registration token pending use token: The registration token pending use
""" """
def _set_registration_token_pending_txn(txn): def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None:
pending = self.db_pool.simple_select_one_onecol_txn( pending = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
"registration_tokens", "registration_tokens",
@ -1358,7 +1369,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
updatevalues={"pending": pending + 1}, updatevalues={"pending": pending + 1},
) )
return await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"set_registration_token_pending", _set_registration_token_pending_txn "set_registration_token_pending", _set_registration_token_pending_txn
) )
@ -1372,7 +1383,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
token: The registration token to be 'used' token: The registration token to be 'used'
""" """
def _use_registration_token_txn(txn): def _use_registration_token_txn(txn: LoggingTransaction) -> None:
# Normally, res is Optional[Dict[str, Any]]. # Normally, res is Optional[Dict[str, Any]].
# Override type because the return type is only optional if # Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors # allow_none is True, and we don't want mypy throwing errors
@ -1398,7 +1409,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
}, },
) )
return await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"use_registration_token", _use_registration_token_txn "use_registration_token", _use_registration_token_txn
) )
@ -1416,7 +1427,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A list of dicts, each containing details of a token. A list of dicts, each containing details of a token.
""" """
def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]): def select_registration_tokens_txn(
txn: LoggingTransaction, now: int, valid: Optional[bool]
) -> List[Dict[str, Any]]:
if valid is None: if valid is None:
# Return all tokens regardless of validity # Return all tokens regardless of validity
txn.execute("SELECT * FROM registration_tokens") txn.execute("SELECT * FROM registration_tokens")
@ -1523,7 +1536,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Whether the row was inserted or not. Whether the row was inserted or not.
""" """
def _create_registration_token_txn(txn): def _create_registration_token_txn(txn: LoggingTransaction) -> bool:
row = self.db_pool.simple_select_one_txn( row = self.db_pool.simple_select_one_txn(
txn, txn,
"registration_tokens", "registration_tokens",
@ -1570,7 +1583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
A dict with all info about the token, or None if token doesn't exist. A dict with all info about the token, or None if token doesn't exist.
""" """
def _update_registration_token_txn(txn): def _update_registration_token_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
try: try:
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
@ -1651,7 +1666,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) -> Optional[RefreshTokenLookupResult]: ) -> Optional[RefreshTokenLookupResult]:
"""Lookup a refresh token with hints about its validity.""" """Lookup a refresh token with hints about its validity."""
def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: def _lookup_refresh_token_txn(
txn: LoggingTransaction,
) -> Optional[RefreshTokenLookupResult]:
txn.execute( txn.execute(
""" """
SELECT SELECT
@ -1807,14 +1824,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
unique=False, unique=False,
) )
async def _background_update_set_deactivated_flag(self, progress, batch_size): async def _background_update_set_deactivated_flag(
self, progress: JsonDict, batch_size: int
) -> int:
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them. for each of them.
""" """
last_user = progress.get("user_id", "") last_user = progress.get("user_id", "")
def _background_update_set_deactivated_flag_txn(txn): def _background_update_set_deactivated_flag_txn(
txn: LoggingTransaction,
) -> Tuple[bool, int]:
txn.execute( txn.execute(
""" """
SELECT SELECT
@ -1886,7 +1907,9 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
deactivated, deactivated,
) )
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool): def set_user_deactivated_status_txn(
self, txn: LoggingTransaction, user_id: str, deactivated: bool
) -> None:
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn=txn, txn=txn,
table="users", table="users",
@ -2005,7 +2028,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return next_id return next_id
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: def _set_device_for_access_token_txn(
self, txn: LoggingTransaction, token: str, device_id: str
) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn( old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id" txn, "access_tokens", {"token": token}, "device_id"
) )
@ -2084,7 +2109,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def _register_user( def _register_user(
self, self,
txn, txn: LoggingTransaction,
user_id: str, user_id: str,
password_hash: Optional[str], password_hash: Optional[str],
was_guest: bool, was_guest: bool,
@ -2094,7 +2119,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool, admin: bool,
user_type: Optional[str], user_type: Optional[str],
shadow_banned: bool, shadow_banned: bool,
): ) -> None:
user_id_obj = UserID.from_string(user_id) user_id_obj = UserID.from_string(user_id)
now = int(self._clock.time()) now = int(self._clock.time())
@ -2181,7 +2206,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
pointless. Use flush_user separately. pointless. Use flush_user separately.
""" """
def user_set_password_hash_txn(txn): def user_set_password_hash_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash} txn, "users", {"name": user_id}, {"password_hash": password_hash}
) )
@ -2204,7 +2229,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
StoreError(404) if user not found StoreError(404) if user not found
""" """
def f(txn): def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
table="users", table="users",
@ -2229,7 +2254,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
StoreError(404) if user not found StoreError(404) if user not found
""" """
def f(txn): def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
table="users", table="users",
@ -2259,7 +2284,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
A tuple of (token, token id, device id) for each of the deleted tokens A tuple of (token, token id, device id) for each of the deleted tokens
""" """
def f(txn): def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
keyvalues = {"user_id": user_id} keyvalues = {"user_id": user_id}
if device_id is not None: if device_id is not None:
keyvalues["device_id"] = device_id keyvalues["device_id"] = device_id
@ -2301,7 +2326,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return await self.db_pool.runInteraction("user_delete_access_tokens", f) return await self.db_pool.runInteraction("user_delete_access_tokens", f)
async def delete_access_token(self, access_token: str) -> None: async def delete_access_token(self, access_token: str) -> None:
def f(txn): def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn( self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token} txn, table="access_tokens", keyvalues={"token": access_token}
) )
@ -2313,7 +2338,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f) await self.db_pool.runInteraction("delete_access_token", f)
async def delete_refresh_token(self, refresh_token: str) -> None: async def delete_refresh_token(self, refresh_token: str) -> None:
def f(txn): def f(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_one_txn( self.db_pool.simple_delete_one_txn(
txn, table="refresh_tokens", keyvalues={"token": refresh_token} txn, table="refresh_tokens", keyvalues={"token": refresh_token}
) )
@ -2353,7 +2378,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
""" """
# Insert everything into a transaction in order to run atomically # Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn): def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
row = self.db_pool.simple_select_one_txn( row = self.db_pool.simple_select_one_txn(
txn, txn,
table="threepid_validation_session", table="threepid_validation_session",
@ -2450,7 +2475,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
longer be valid longer be valid
""" """
def start_or_continue_validation_session_txn(txn): def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
# Create or update a validation session # Create or update a validation session
self.db_pool.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,

View file

@ -742,7 +742,7 @@ class RelationsWorkerStore(SQLBaseStore):
%s; %s;
""" """
def _get_if_events_have_relations(txn) -> List[str]: def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
clauses: List[str] = [] clauses: List[str] = []
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids txn.database_engine, "relates_to_id", parent_ids

View file

@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureWorkerStore(EventsWorkerStore): class SignatureWorkerStore(EventsWorkerStore):
@cached() @cached()
def get_event_reference_hash(self, event_id): def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
# This is a dummy function to allow get_event_reference_hashes # This is a dummy function to allow get_event_reference_hashes
# to use its cache # to use its cache
raise NotImplementedError() raise NotImplementedError()

View file

@ -204,7 +204,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The current state of the room. The current state of the room.
""" """
def _get_current_state_ids_txn(txn): def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
txn.execute( txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events """SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ? WHERE room_id = ?

View file

@ -36,7 +36,17 @@ what sort order was used:
""" """
import logging import logging
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
List,
Optional,
Set,
Tuple,
cast,
)
import attr import attr
from frozendict import frozendict from frozendict import frozendict
@ -732,7 +742,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A tuple of (stream ordering, topological ordering, event_id) A tuple of (stream ordering, topological ordering, event_id)
""" """
def _f(txn): def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
sql = ( sql = (
"SELECT stream_ordering, topological_ordering, event_id" "SELECT stream_ordering, topological_ordering, event_id"
" FROM events" " FROM events"
@ -742,7 +752,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" LIMIT 1" " LIMIT 1"
) )
txn.execute(sql, (room_id, stream_ordering)) txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone() return cast(Optional[Tuple[int, int, str]], txn.fetchone())
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_room_event_before_stream_ordering", _f "get_room_event_before_stream_ordering", _f
@ -839,7 +849,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@staticmethod @staticmethod
def _set_before_and_after( def _set_before_and_after(
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
): ) -> None:
"""Inserts ordering information to events' internal metadata from """Inserts ordering information to events' internal metadata from
the DB rows. the DB rows.
@ -985,7 +995,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
the `current_id`). the `current_id`).
""" """
def get_all_new_events_stream_txn(txn): def get_all_new_events_stream_txn(
txn: LoggingTransaction,
) -> Tuple[int, List[str]]:
sql = ( sql = (
"SELECT e.stream_ordering, e.event_id" "SELECT e.stream_ordering, e.event_id"
" FROM events AS e" " FROM events AS e"
@ -1331,7 +1343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_id_for_instance(self, instance_name: str) -> int: async def get_id_for_instance(self, instance_name: str) -> int:
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance.""" """Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
def _get_id_for_instance_txn(txn): def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
instance_id = self.db_pool.simple_select_one_onecol_txn( instance_id = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="instance_map", table="instance_map",

View file

@ -97,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
) )
def get_tag_content( def get_tag_content(
txn: LoggingTransaction, tag_ids txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
) -> List[Tuple[int, Tuple[str, str, str]]]: ) -> List[Tuple[int, Tuple[str, str, str]]]:
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = [] results = []
@ -251,7 +251,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()
def _update_revision_txn( def _update_revision_txn(
self, txn, user_id: str, room_id: str, next_id: int self, txn: LoggingTransaction, user_id: str, room_id: str, next_id: int
) -> None: ) -> None:
"""Update the latest revision of the tags for the given user and room. """Update the latest revision of the tags for the given user and room.