Disallow untyped defs in synapse._scripts (#12422)

Of note: 

* No untyped defs in `register_new_matrix_user`

This one might be contraversial. `request_registration` has three
dependency-injection arguments used for testing. I'm removing the
injection of the `requests` module and using `unitest.mock.patch` in the
test cases instead.

Doing `reveal_type(requests)` and `reveal_type(requests.get)` before the
change:

```
synapse/_scripts/register_new_matrix_user.py:45: note: Revealed type is "Any"
synapse/_scripts/register_new_matrix_user.py:46: note: Revealed type is "Any"
```

And after:

```
synapse/_scripts/register_new_matrix_user.py:44: note: Revealed type is "types.ModuleType"
synapse/_scripts/register_new_matrix_user.py:45: note: Revealed type is "def (url: Union[builtins.str, builtins.bytes], params: Union[Union[_typeshed.SupportsItems[Union[builtins.str, builtins.bytes, builtins.int, builtins.float], Union[builtins.str, builtins.bytes, builtins.int, builtins.float, typing.Iterable[Union[builtins.str, builtins.bytes, builtins.int, builtins.float]], None]], Tuple[Union[builtins.str, builtins.bytes, builtins.int, builtins.float], Union[builtins.str, builtins.bytes, builtins.int, builtins.float, typing.Iterable[Union[builtins.str, builtins.bytes, builtins.int, builtins.float]], None]], typing.Iterable[Tuple[Union[builtins.str, builtins.bytes, builtins.int, builtins.float], Union[builtins.str, builtins.bytes, builtins.int, builtins.float, typing.Iterable[Union[builtins.str, builtins.bytes, builtins.int, builtins.float]], None]]], builtins.str, builtins.bytes], None] =, data: Union[Any, None] =, headers: Union[Any, None] =, cookies: Union[Any, None] =, files: Union[Any, None] =, auth: Union[Any, None] =, timeout: Union[Any, None] =, allow_redirects: builtins.bool =, proxies: Union[Any, None] =, hooks: Union[Any, None] =, stream: Union[Any, None] =, verify: Union[Any, None] =, cert: Union[Any, None] =, json: Union[Any, None] =) -> requests.models.Response"
```

* Drive-by comment in `synapse.storage.types`

* No untyped defs in `synapse_port_db`

This was by far the most painful. I'm happy to break this up into
smaller pieces for review if it's not managable as-is.
This commit is contained in:
David Robertson 2022-04-11 12:41:55 +01:00 committed by GitHub
parent 5f72ea1bde
commit 961ee75a9b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 221 additions and 140 deletions

1
changelog.d/12422.misc Normal file
View file

@ -0,0 +1 @@
Make `synapse._scripts` pass type checks.

View file

@ -93,6 +93,9 @@ exclude = (?x)
|tests/utils.py |tests/utils.py
)$ )$
[mypy-synapse._scripts.*]
disallow_untyped_defs = True
[mypy-synapse.api.*] [mypy-synapse.api.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@ -15,19 +15,19 @@
import argparse import argparse
import sys import sys
import time import time
from typing import Optional from typing import NoReturn, Optional
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
from signedjson.types import VerifyKey from signedjson.types import VerifyKey
def exit(status: int = 0, message: Optional[str] = None): def exit(status: int = 0, message: Optional[str] = None) -> NoReturn:
if message: if message:
print(message, file=sys.stderr) print(message, file=sys.stderr)
sys.exit(status) sys.exit(status)
def format_plain(public_key: VerifyKey): def format_plain(public_key: VerifyKey) -> None:
print( print(
"%s:%s %s" "%s:%s %s"
% ( % (
@ -38,7 +38,7 @@ def format_plain(public_key: VerifyKey):
) )
def format_for_config(public_key: VerifyKey, expiry_ts: int): def format_for_config(public_key: VerifyKey, expiry_ts: int) -> None:
print( print(
' "%s:%s": { key: "%s", expired_ts: %i }' ' "%s:%s": { key: "%s", expired_ts: %i }'
% ( % (
@ -50,7 +50,7 @@ def format_for_config(public_key: VerifyKey, expiry_ts: int):
) )
def main(): def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
@ -94,7 +94,6 @@ def main():
message="Error reading key from file %s: %s %s" message="Error reading key from file %s: %s %s"
% (file.name, type(e), e), % (file.name, type(e), e),
) )
res = []
for key in res: for key in res:
formatter(get_verify_key(key)) formatter(get_verify_key(key))

View file

@ -7,7 +7,7 @@ import sys
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
def main(): def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--config-dir", "--config-dir",

View file

@ -20,7 +20,7 @@ import sys
from synapse.config.logger import DEFAULT_LOG_CONFIG from synapse.config.logger import DEFAULT_LOG_CONFIG
def main(): def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(

View file

@ -20,7 +20,7 @@ from signedjson.key import generate_signing_key, write_signing_keys
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
def main(): def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(

View file

@ -9,7 +9,7 @@ import bcrypt
import yaml import yaml
def prompt_for_pass(): def prompt_for_pass() -> str:
password = getpass.getpass("Password: ") password = getpass.getpass("Password: ")
if not password: if not password:
@ -23,7 +23,7 @@ def prompt_for_pass():
return password return password
def main(): def main() -> None:
bcrypt_rounds = 12 bcrypt_rounds = 12
password_pepper = "" password_pepper = ""

View file

@ -42,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
logger = logging.getLogger() logger = logging.getLogger()
def main(src_repo, dest_repo): def main(src_repo: str, dest_repo: str) -> None:
src_paths = MediaFilePaths(src_repo) src_paths = MediaFilePaths(src_repo)
dest_paths = MediaFilePaths(dest_repo) dest_paths = MediaFilePaths(dest_repo)
for line in sys.stdin: for line in sys.stdin:
@ -55,14 +55,19 @@ def main(src_repo, dest_repo):
move_media(parts[0], parts[1], src_paths, dest_paths) move_media(parts[0], parts[1], src_paths, dest_paths)
def move_media(origin_server, file_id, src_paths, dest_paths): def move_media(
origin_server: str,
file_id: str,
src_paths: MediaFilePaths,
dest_paths: MediaFilePaths,
) -> None:
"""Move the given file, and any thumbnails, to the dest repo """Move the given file, and any thumbnails, to the dest repo
Args: Args:
origin_server (str): origin_server:
file_id (str): file_id:
src_paths (MediaFilePaths): src_paths:
dest_paths (MediaFilePaths): dest_paths:
""" """
logger.info("%s/%s", origin_server, file_id) logger.info("%s/%s", origin_server, file_id)
@ -91,7 +96,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
) )
def mkdir_and_move(original_file, dest_file): def mkdir_and_move(original_file: str, dest_file: str) -> None:
dirname = os.path.dirname(dest_file) dirname = os.path.dirname(dest_file)
if not os.path.exists(dirname): if not os.path.exists(dirname):
logger.debug("mkdir %s", dirname) logger.debug("mkdir %s", dirname)

View file

@ -22,7 +22,7 @@ import logging
import sys import sys
from typing import Callable, Optional from typing import Callable, Optional
import requests as _requests import requests
import yaml import yaml
@ -33,7 +33,6 @@ def request_registration(
shared_secret: str, shared_secret: str,
admin: bool = False, admin: bool = False,
user_type: Optional[str] = None, user_type: Optional[str] = None,
requests=_requests,
_print: Callable[[str], None] = print, _print: Callable[[str], None] = print,
exit: Callable[[int], None] = sys.exit, exit: Callable[[int], None] = sys.exit,
) -> None: ) -> None:

View file

@ -22,10 +22,26 @@ import sys
import time import time
import traceback import traceback
from types import TracebackType from types import TracebackType
from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast from typing import (
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
NoReturn,
Optional,
Set,
Tuple,
Type,
TypeVar,
cast,
)
import yaml import yaml
from matrix_common.versionstring import get_distribution_version_string from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import TypedDict
from twisted.internet import defer, reactor as reactor_ from twisted.internet import defer, reactor as reactor_
@ -36,7 +52,7 @@ from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
run_in_background, run_in_background,
) )
from synapse.storage.database import DatabasePool, make_conn from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import PushRuleStore from synapse.storage.databases.main import PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
@ -173,6 +189,8 @@ end_error_exec_info: Optional[
Tuple[Type[BaseException], BaseException, TracebackType] Tuple[Type[BaseException], BaseException, TracebackType]
] = None ] = None
R = TypeVar("R")
class Store( class Store(
ClientIpBackgroundUpdateStore, ClientIpBackgroundUpdateStore,
@ -195,17 +213,19 @@ class Store(
PresenceBackgroundUpdateStore, PresenceBackgroundUpdateStore,
GroupServerWorkerStore, GroupServerWorkerStore,
): ):
def execute(self, f, *args, **kwargs): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql, *args): def execute_sql(self, sql: str, *args: object) -> Awaitable[List[Tuple]]:
def r(txn): def r(txn: LoggingTransaction) -> List[Tuple]:
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
return self.db_pool.runInteraction("execute_sql", r) return self.db_pool.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows): def insert_many_txn(
self, txn: LoggingTransaction, table: str, headers: List[str], rows: List[Tuple]
) -> None:
sql = "INSERT INTO %s (%s) VALUES (%s)" % ( sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table, table,
", ".join(k for k in headers), ", ".join(k for k in headers),
@ -218,14 +238,15 @@ class Store(
logger.exception("Failed to insert: %s", table) logger.exception("Failed to insert: %s", table)
raise raise
def set_room_is_public(self, room_id, is_public): # Note: the parent method is an `async def`.
def set_room_is_public(self, room_id: str, is_public: bool) -> NoReturn:
raise Exception( raise Exception(
"Attempt to set room_is_public during port_db: database not empty?" "Attempt to set room_is_public during port_db: database not empty?"
) )
class MockHomeserver: class MockHomeserver:
def __init__(self, config): def __init__(self, config: HomeServerConfig):
self.clock = Clock(reactor) self.clock = Clock(reactor)
self.config = config self.config = config
self.hostname = config.server.server_name self.hostname = config.server.server_name
@ -233,24 +254,30 @@ class MockHomeserver:
"matrix-synapse" "matrix-synapse"
) )
def get_clock(self): def get_clock(self) -> Clock:
return self.clock return self.clock
def get_reactor(self): def get_reactor(self) -> ISynapseReactor:
return reactor return reactor
def get_instance_name(self): def get_instance_name(self) -> str:
return "master" return "master"
class Porter: class Porter:
def __init__(self, sqlite_config, progress, batch_size, hs_config): def __init__(
self,
sqlite_config: Dict[str, Any],
progress: "Progress",
batch_size: int,
hs_config: HomeServerConfig,
):
self.sqlite_config = sqlite_config self.sqlite_config = sqlite_config
self.progress = progress self.progress = progress
self.batch_size = batch_size self.batch_size = batch_size
self.hs_config = hs_config self.hs_config = hs_config
async def setup_table(self, table): async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
if table in APPEND_ONLY_TABLES: if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting. # It's safe to just carry on inserting.
row = await self.postgres_store.db_pool.simple_select_one( row = await self.postgres_store.db_pool.simple_select_one(
@ -292,7 +319,7 @@ class Porter:
) )
else: else:
def delete_all(txn): def delete_all(txn: LoggingTransaction) -> None:
txn.execute( txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,) "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
) )
@ -317,7 +344,7 @@ class Porter:
async def get_table_constraints(self) -> Dict[str, Set[str]]: async def get_table_constraints(self) -> Dict[str, Set[str]]:
"""Returns a map of tables that have foreign key constraints to tables they depend on.""" """Returns a map of tables that have foreign key constraints to tables they depend on."""
def _get_constraints(txn): def _get_constraints(txn: LoggingTransaction) -> Dict[str, Set[str]]:
# We can pull the information about foreign key constraints out from # We can pull the information about foreign key constraints out from
# the postgres schema tables. # the postgres schema tables.
sql = """ sql = """
@ -343,8 +370,13 @@ class Porter:
) )
async def handle_table( async def handle_table(
self, table, postgres_size, table_size, forward_chunk, backward_chunk self,
): table: str,
postgres_size: int,
table_size: int,
forward_chunk: int,
backward_chunk: int,
) -> None:
logger.info( logger.info(
"Table %s: %i/%i (rows %i-%i) already ported", "Table %s: %i/%i (rows %i-%i) already ported",
table, table,
@ -391,7 +423,9 @@ class Porter:
while True: while True:
def r(txn): def r(
txn: LoggingTransaction,
) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]:
forward_rows = [] forward_rows = []
backward_rows = [] backward_rows = []
if do_forward[0]: if do_forward[0]:
@ -418,6 +452,7 @@ class Porter:
) )
if frows or brows: if frows or brows:
assert headers is not None
if frows: if frows:
forward_chunk = max(row[0] for row in frows) + 1 forward_chunk = max(row[0] for row in frows) + 1
if brows: if brows:
@ -426,7 +461,8 @@ class Porter:
rows = frows + brows rows = frows + brows
rows = self._convert_rows(table, headers, rows) rows = self._convert_rows(table, headers, rows)
def insert(txn): def insert(txn: LoggingTransaction) -> None:
assert headers is not None
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
self.postgres_store.db_pool.simple_update_one_txn( self.postgres_store.db_pool.simple_update_one_txn(
@ -448,8 +484,12 @@ class Porter:
return return
async def handle_search_table( async def handle_search_table(
self, postgres_size, table_size, forward_chunk, backward_chunk self,
): postgres_size: int,
table_size: int,
forward_chunk: int,
backward_chunk: int,
) -> None:
select = ( select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es" " FROM event_search as es"
@ -460,7 +500,7 @@ class Porter:
while True: while True:
def r(txn): def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
txn.execute(select, (forward_chunk, self.batch_size)) txn.execute(select, (forward_chunk, self.batch_size))
rows = txn.fetchall() rows = txn.fetchall()
headers = [column[0] for column in txn.description] headers = [column[0] for column in txn.description]
@ -474,7 +514,7 @@ class Porter:
# We have to treat event_search differently since it has a # We have to treat event_search differently since it has a
# different structure in the two different databases. # different structure in the two different databases.
def insert(txn): def insert(txn: LoggingTransaction) -> None:
sql = ( sql = (
"INSERT INTO event_search (event_id, room_id, key," "INSERT INTO event_search (event_id, room_id, key,"
" sender, vector, origin_server_ts, stream_ordering)" " sender, vector, origin_server_ts, stream_ordering)"
@ -528,7 +568,7 @@ class Porter:
self, self,
db_config: DatabaseConnectionConfig, db_config: DatabaseConnectionConfig,
allow_outdated_version: bool = False, allow_outdated_version: bool = False,
): ) -> Store:
"""Builds and returns a database store using the provided configuration. """Builds and returns a database store using the provided configuration.
Args: Args:
@ -556,7 +596,7 @@ class Porter:
return store return store
async def run_background_updates_on_postgres(self): async def run_background_updates_on_postgres(self) -> None:
# Manually apply all background updates on the PostgreSQL database. # Manually apply all background updates on the PostgreSQL database.
postgres_ready = ( postgres_ready = (
await self.postgres_store.db_pool.updates.has_completed_background_updates() await self.postgres_store.db_pool.updates.has_completed_background_updates()
@ -568,12 +608,12 @@ class Porter:
self.progress.set_state("Running background updates on PostgreSQL") self.progress.set_state("Running background updates on PostgreSQL")
while not postgres_ready: while not postgres_ready:
await self.postgres_store.db_pool.updates.do_next_background_update(100) await self.postgres_store.db_pool.updates.do_next_background_update(True)
postgres_ready = await ( postgres_ready = await (
self.postgres_store.db_pool.updates.has_completed_background_updates() self.postgres_store.db_pool.updates.has_completed_background_updates()
) )
async def run(self): async def run(self) -> None:
"""Ports the SQLite database to a PostgreSQL database. """Ports the SQLite database to a PostgreSQL database.
When a fatal error is met, its message is assigned to the global "end_error" When a fatal error is met, its message is assigned to the global "end_error"
@ -609,7 +649,7 @@ class Porter:
self.progress.set_state("Creating port tables") self.progress.set_state("Creating port tables")
def create_port_table(txn): def create_port_table(txn: LoggingTransaction) -> None:
txn.execute( txn.execute(
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" "CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE," " table_name varchar(100) NOT NULL UNIQUE,"
@ -622,7 +662,7 @@ class Porter:
# We want people to be able to rerun this script from an old port # We want people to be able to rerun this script from an old port
# so that they can pick up any missing events that were not # so that they can pick up any missing events that were not
# ported across. # ported across.
def alter_table(txn): def alter_table(txn: LoggingTransaction) -> None:
txn.execute( txn.execute(
"ALTER TABLE IF EXISTS port_from_sqlite3" "ALTER TABLE IF EXISTS port_from_sqlite3"
" RENAME rowid TO forward_rowid" " RENAME rowid TO forward_rowid"
@ -742,7 +782,9 @@ class Porter:
finally: finally:
reactor.stop() reactor.stop()
def _convert_rows(self, table, headers, rows): def _convert_rows(
self, table: str, headers: List[str], rows: List[Tuple]
) -> List[Tuple]:
bool_col_names = BOOLEAN_COLUMNS.get(table, []) bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names] bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
@ -750,7 +792,7 @@ class Porter:
class BadValueException(Exception): class BadValueException(Exception):
pass pass
def conv(j, col): def conv(j: int, col: object) -> object:
if j in bool_cols: if j in bool_cols:
return bool(col) return bool(col)
if isinstance(col, bytes): if isinstance(col, bytes):
@ -776,7 +818,7 @@ class Porter:
return outrows return outrows
async def _setup_sent_transactions(self): async def _setup_sent_transactions(self) -> Tuple[int, int, int]:
# Only save things from the last day # Only save things from the last day
yesterday = int(time.time() * 1000) - 86400000 yesterday = int(time.time() * 1000) - 86400000
@ -788,10 +830,10 @@ class Porter:
")" ")"
) )
def r(txn): def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
txn.execute(select) txn.execute(select)
rows = txn.fetchall() rows = txn.fetchall()
headers = [column[0] for column in txn.description] headers: List[str] = [column[0] for column in txn.description]
ts_ind = headers.index("ts") ts_ind = headers.index("ts")
@ -805,7 +847,7 @@ class Porter:
if inserted_rows: if inserted_rows:
max_inserted_rowid = max(r[0] for r in rows) max_inserted_rowid = max(r[0] for r in rows)
def insert(txn): def insert(txn: LoggingTransaction) -> None:
self.postgres_store.insert_many_txn( self.postgres_store.insert_many_txn(
txn, "sent_transactions", headers[1:], rows txn, "sent_transactions", headers[1:], rows
) )
@ -814,7 +856,7 @@ class Porter:
else: else:
max_inserted_rowid = 0 max_inserted_rowid = 0
def get_start_id(txn): def get_start_id(txn: LoggingTransaction) -> int:
txn.execute( txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?" "SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1", " ORDER BY rowid ASC LIMIT 1",
@ -839,12 +881,13 @@ class Porter:
}, },
) )
def get_sent_table_size(txn): def get_sent_table_size(txn: LoggingTransaction) -> int:
txn.execute( txn.execute(
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
) )
(size,) = txn.fetchone() result = txn.fetchone()
return int(size) assert result is not None
return int(result[0])
remaining_count = await self.sqlite_store.execute(get_sent_table_size) remaining_count = await self.sqlite_store.execute(get_sent_table_size)
@ -852,25 +895,35 @@ class Porter:
return next_chunk, inserted_rows, total_count return next_chunk, inserted_rows, total_count
async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): async def _get_remaining_count_to_port(
frows = await self.sqlite_store.execute_sql( self, table: str, forward_chunk: int, backward_chunk: int
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk ) -> int:
frows = cast(
List[Tuple[int]],
await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
),
) )
brows = await self.sqlite_store.execute_sql( brows = cast(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk List[Tuple[int]],
await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
),
) )
return frows[0][0] + brows[0][0] return frows[0][0] + brows[0][0]
async def _get_already_ported_count(self, table): async def _get_already_ported_count(self, table: str) -> int:
rows = await self.postgres_store.execute_sql( rows = await self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,) "SELECT count(*) FROM %s" % (table,)
) )
return rows[0][0] return rows[0][0]
async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): async def _get_total_count_to_port(
self, table: str, forward_chunk: int, backward_chunk: int
) -> Tuple[int, int]:
remaining, done = await make_deferred_yieldable( remaining, done = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
@ -891,14 +944,17 @@ class Porter:
return done, remaining + done return done, remaining + done
async def _setup_state_group_id_seq(self) -> None: async def _setup_state_group_id_seq(self) -> None:
curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol( curr_id: Optional[
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
) )
if not curr_id: if not curr_id:
return return
def r(txn): def r(txn: LoggingTransaction) -> None:
assert curr_id is not None
next_id = curr_id + 1 next_id = curr_id + 1
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
@ -909,7 +965,7 @@ class Porter:
"setup_user_id_seq", find_max_generated_user_id_localpart "setup_user_id_seq", find_max_generated_user_id_localpart
) )
def r(txn): def r(txn: LoggingTransaction) -> None:
next_id = curr_id + 1 next_id = curr_id + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
@ -931,7 +987,7 @@ class Porter:
allow_none=True, allow_none=True,
) )
def _setup_events_stream_seqs_set_pos(txn): def _setup_events_stream_seqs_set_pos(txn: LoggingTransaction) -> None:
if curr_forward_id: if curr_forward_id:
txn.execute( txn.execute(
"ALTER SEQUENCE events_stream_seq RESTART WITH %s", "ALTER SEQUENCE events_stream_seq RESTART WITH %s",
@ -955,17 +1011,20 @@ class Porter:
"""Set a sequence to the correct value.""" """Set a sequence to the correct value."""
current_stream_ids = [] current_stream_ids = []
for stream_id_table in stream_id_tables: for stream_id_table in stream_id_tables:
max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol( max_stream_id = cast(
table=stream_id_table, int,
keyvalues={}, await self.sqlite_store.db_pool.simple_select_one_onecol(
retcol="COALESCE(MAX(stream_id), 1)", table=stream_id_table,
allow_none=True, keyvalues={},
retcol="COALESCE(MAX(stream_id), 1)",
allow_none=True,
),
) )
current_stream_ids.append(max_stream_id) current_stream_ids.append(max_stream_id)
next_id = max(current_stream_ids) + 1 next_id = max(current_stream_ids) + 1
def r(txn): def r(txn: LoggingTransaction) -> None:
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,) sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
txn.execute(sql + " %s", (next_id,)) txn.execute(sql + " %s", (next_id,))
@ -974,14 +1033,18 @@ class Porter:
) )
async def _setup_auth_chain_sequence(self) -> None: async def _setup_auth_chain_sequence(self) -> None:
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol( curr_chain_id: Optional[
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains", table="event_auth_chains",
keyvalues={}, keyvalues={},
retcol="MAX(chain_id)", retcol="MAX(chain_id)",
allow_none=True, allow_none=True,
) )
def r(txn): def r(txn: LoggingTransaction) -> None:
# Presumably there is at least one row in event_auth_chains.
assert curr_chain_id is not None
txn.execute( txn.execute(
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s", "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
(curr_chain_id + 1,), (curr_chain_id + 1,),
@ -999,15 +1062,22 @@ class Porter:
############################################## ##############################################
class Progress(object): class TableProgress(TypedDict):
start: int
num_done: int
total: int
perc: int
class Progress:
"""Used to report progress of the port""" """Used to report progress of the port"""
def __init__(self): def __init__(self) -> None:
self.tables = {} self.tables: Dict[str, TableProgress] = {}
self.start_time = int(time.time()) self.start_time = int(time.time())
def add_table(self, table, cur, size): def add_table(self, table: str, cur: int, size: int) -> None:
self.tables[table] = { self.tables[table] = {
"start": cur, "start": cur,
"num_done": cur, "num_done": cur,
@ -1015,19 +1085,22 @@ class Progress(object):
"perc": int(cur * 100 / size), "perc": int(cur * 100 / size),
} }
def update(self, table, num_done): def update(self, table: str, num_done: int) -> None:
data = self.tables[table] data = self.tables[table]
data["num_done"] = num_done data["num_done"] = num_done
data["perc"] = int(num_done * 100 / data["total"]) data["perc"] = int(num_done * 100 / data["total"])
def done(self): def done(self) -> None:
pass
def set_state(self, state: str) -> None:
pass pass
class CursesProgress(Progress): class CursesProgress(Progress):
"""Reports progress to a curses window""" """Reports progress to a curses window"""
def __init__(self, stdscr): def __init__(self, stdscr: "curses.window"):
self.stdscr = stdscr self.stdscr = stdscr
curses.use_default_colors() curses.use_default_colors()
@ -1045,7 +1118,7 @@ class CursesProgress(Progress):
super(CursesProgress, self).__init__() super(CursesProgress, self).__init__()
def update(self, table, num_done): def update(self, table: str, num_done: int) -> None:
super(CursesProgress, self).update(table, num_done) super(CursesProgress, self).update(table, num_done)
self.total_processed = 0 self.total_processed = 0
@ -1056,7 +1129,7 @@ class CursesProgress(Progress):
self.render() self.render()
def render(self, force=False): def render(self, force: bool = False) -> None:
now = time.time() now = time.time()
if not force and now - self.last_update < 0.2: if not force and now - self.last_update < 0.2:
@ -1128,12 +1201,12 @@ class CursesProgress(Progress):
self.stdscr.refresh() self.stdscr.refresh()
self.last_update = time.time() self.last_update = time.time()
def done(self): def done(self) -> None:
self.finished = True self.finished = True
self.render(True) self.render(True)
self.stdscr.getch() self.stdscr.getch()
def set_state(self, state): def set_state(self, state: str) -> None:
self.stdscr.clear() self.stdscr.clear()
self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD) self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
self.stdscr.refresh() self.stdscr.refresh()
@ -1142,7 +1215,7 @@ class CursesProgress(Progress):
class TerminalProgress(Progress): class TerminalProgress(Progress):
"""Just prints progress to the terminal""" """Just prints progress to the terminal"""
def update(self, table, num_done): def update(self, table: str, num_done: int) -> None:
super(TerminalProgress, self).update(table, num_done) super(TerminalProgress, self).update(table, num_done)
data = self.tables[table] data = self.tables[table]
@ -1151,7 +1224,7 @@ class TerminalProgress(Progress):
"%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"]) "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
) )
def set_state(self, state): def set_state(self, state: str) -> None:
print(state + "...") print(state + "...")
@ -1159,7 +1232,7 @@ class TerminalProgress(Progress):
############################################## ##############################################
def main(): def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to" description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database." " a new PostgreSQL database."
@ -1225,7 +1298,7 @@ def main():
config = HomeServerConfig() config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "") config.parse_config_dict(hs_config, "", "")
def start(stdscr=None): def start(stdscr: Optional["curses.window"] = None) -> None:
progress: Progress progress: Progress
if stdscr: if stdscr:
progress = CursesProgress(stdscr) progress = CursesProgress(stdscr)
@ -1240,7 +1313,7 @@ def main():
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def run(): def run() -> Generator["defer.Deferred[Any]", Any, None]:
with LoggingContext("synapse_port_db_run"): with LoggingContext("synapse_port_db_run"):
yield defer.ensureDeferred(porter.run()) yield defer.ensureDeferred(porter.run())

View file

@ -24,7 +24,7 @@ import signal
import subprocess import subprocess
import sys import sys
import time import time
from typing import Iterable, Optional from typing import Iterable, NoReturn, Optional, TextIO
import yaml import yaml
@ -45,7 +45,7 @@ one of the following:
--------------------------------------------------------------------------------""" --------------------------------------------------------------------------------"""
def pid_running(pid): def pid_running(pid: int) -> bool:
try: try:
os.kill(pid, 0) os.kill(pid, 0)
except OSError as err: except OSError as err:
@ -68,7 +68,7 @@ def pid_running(pid):
return True return True
def write(message, colour=NORMAL, stream=sys.stdout): def write(message: str, colour: str = NORMAL, stream: TextIO = sys.stdout) -> None:
# Lets check if we're writing to a TTY before colouring # Lets check if we're writing to a TTY before colouring
should_colour = False should_colour = False
try: try:
@ -84,7 +84,7 @@ def write(message, colour=NORMAL, stream=sys.stdout):
stream.write(colour + message + NORMAL + "\n") stream.write(colour + message + NORMAL + "\n")
def abort(message, colour=RED, stream=sys.stderr): def abort(message: str, colour: str = RED, stream: TextIO = sys.stderr) -> NoReturn:
write(message, colour, stream) write(message, colour, stream)
sys.exit(1) sys.exit(1)
@ -166,7 +166,7 @@ Worker = collections.namedtuple(
) )
def main(): def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View file

@ -38,25 +38,25 @@ logger = logging.getLogger("update_database")
class MockHomeserver(HomeServer): class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore # type: ignore [assignment] DATASTORE_CLASS = DataStore # type: ignore [assignment]
def __init__(self, config, **kwargs): def __init__(self, config: HomeServerConfig):
super(MockHomeserver, self).__init__( super(MockHomeserver, self).__init__(
config.server.server_name, reactor=reactor, config=config, **kwargs hostname=config.server.server_name,
) config=config,
reactor=reactor,
self.version_string = "Synapse/" + get_distribution_version_string( version_string="Synapse/"
"matrix-synapse" + get_distribution_version_string("matrix-synapse"),
) )
def run_background_updates(hs): def run_background_updates(hs: HomeServer) -> None:
store = hs.get_datastores().main store = hs.get_datastores().main
async def run_background_updates(): async def run_background_updates() -> None:
await store.db_pool.updates.run_background_updates(sleep=False) await store.db_pool.updates.run_background_updates(sleep=False)
# Stop the reactor to exit the script once every background update is run. # Stop the reactor to exit the script once every background update is run.
reactor.stop() reactor.stop()
def run(): def run() -> None:
# Apply all background updates on the database. # Apply all background updates on the database.
defer.ensureDeferred( defer.ensureDeferred(
run_as_background_process("background_updates", run_background_updates) run_as_background_process("background_updates", run_background_updates)
@ -67,7 +67,7 @@ def run_background_updates(hs):
reactor.run() reactor.run()
def main(): def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=( description=(
"Updates a synapse database to the latest schema and optionally runs background updates" "Updates a synapse database to the latest schema and optionally runs background updates"

View file

@ -45,6 +45,7 @@ class Cursor(Protocol):
Sequence[ Sequence[
# Note that this is an approximate typing based on sqlite3 and other # Note that this is an approximate typing based on sqlite3 and other
# drivers, and may not be entirely accurate. # drivers, and may not be entirely accurate.
# FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
Tuple[ Tuple[
str, str,
Optional[Any], Optional[Any],

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest.mock import Mock from unittest.mock import Mock, patch
from synapse._scripts.register_new_matrix_user import request_registration from synapse._scripts.register_new_matrix_user import request_registration
@ -52,16 +52,16 @@ class RegisterTestCase(TestCase):
out = [] out = []
err_code = [] err_code = []
request_registration( with patch("synapse._scripts.register_new_matrix_user.requests", requests):
"user", request_registration(
"pass", "user",
"matrix.org", "pass",
"shared", "matrix.org",
admin=False, "shared",
requests=requests, admin=False,
_print=out.append, _print=out.append,
exit=err_code.append, exit=err_code.append,
) )
# We should get the success message making sure everything is OK. # We should get the success message making sure everything is OK.
self.assertIn("Success!", out) self.assertIn("Success!", out)
@ -88,16 +88,16 @@ class RegisterTestCase(TestCase):
out = [] out = []
err_code = [] err_code = []
request_registration( with patch("synapse._scripts.register_new_matrix_user.requests", requests):
"user", request_registration(
"pass", "user",
"matrix.org", "pass",
"shared", "matrix.org",
admin=False, "shared",
requests=requests, admin=False,
_print=out.append, _print=out.append,
exit=err_code.append, exit=err_code.append,
) )
# Exit was called # Exit was called
self.assertEqual(err_code, [1]) self.assertEqual(err_code, [1])
@ -140,16 +140,16 @@ class RegisterTestCase(TestCase):
out = [] out = []
err_code = [] err_code = []
request_registration( with patch("synapse._scripts.register_new_matrix_user.requests", requests):
"user", request_registration(
"pass", "user",
"matrix.org", "pass",
"shared", "matrix.org",
admin=False, "shared",
requests=requests, admin=False,
_print=out.append, _print=out.append,
exit=err_code.append, exit=err_code.append,
) )
# Exit was called # Exit was called
self.assertEqual(err_code, [1]) self.assertEqual(err_code, [1])