Type hints for RegistrationStore (#8615)

This commit is contained in:
Erik Johnston 2020-10-22 11:56:58 +01:00 committed by GitHub
parent 2ac908f377
commit a9f90fa73a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 85 additions and 74 deletions

1
changelog.d/8615.misc Normal file
View file

@ -0,0 +1 @@
Type hints for `RegistrationStore`.

View file

@ -57,6 +57,7 @@ files =
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,

View file

@ -146,7 +146,6 @@ class DataStore(
db_conn, "e2e_cross_signing_keys", "stream_id"
)
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")

View file

@ -16,29 +16,33 @@
# limitations under the License.
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
self.clock = hs.get_clock()
# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is
@ -55,7 +59,7 @@ class RegistrationWorkerStore(SQLBaseStore):
# Create a background job for culling expired 3PID validity tokens
if hs.config.run_background_tasks:
self.clock.looping_call(
self._clock.looping_call(
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
)
@ -92,7 +96,7 @@ class RegistrationWorkerStore(SQLBaseStore):
if not info:
return False
now = self.clock.time_msec()
now = self._clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial
@ -257,7 +261,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
self._clock.time_msec(),
self.config.account_validity.renew_at,
)
@ -328,13 +332,17 @@ class RegistrationWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
sql = """
SELECT users.name,
users.is_guest,
users.shadow_banned,
access_tokens.id as token_id,
access_tokens.device_id,
access_tokens.valid_until_ms
FROM users
INNER JOIN access_tokens on users.name = access_tokens.user_id
WHERE token = ?
"""
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
@ -803,7 +811,7 @@ class RegistrationWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
self._clock.time_msec(),
)
@wrap_as_background_process("account_validity_set_expiration_dates")
@ -890,10 +898,10 @@ class RegistrationWorkerStore(SQLBaseStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.clock = hs.get_clock()
self._clock = hs.get_clock()
self.config = hs.config
self.db_pool.updates.register_background_index_update(
@ -1016,13 +1024,56 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return 1
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
Args:
user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""
await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
return res if res else False
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
async def add_access_token_to_user(
self,
user_id: str,
@ -1138,19 +1189,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def _register_user(
self,
txn,
user_id,
password_hash,
was_guest,
make_guest,
appservice_id,
create_profile_with_displayname,
admin,
user_type,
shadow_banned,
user_id: str,
password_hash: Optional[str],
was_guest: bool,
make_guest: bool,
appservice_id: Optional[str],
create_profile_with_displayname: Optional[str],
admin: bool,
user_type: Optional[str],
shadow_banned: bool,
):
user_id_obj = UserID.from_string(user_id)
now = int(self.clock.time())
now = int(self._clock.time())
try:
if was_guest:
@ -1374,18 +1425,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f)
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
return res if res else False
async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
@ -1479,7 +1518,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
updatevalues={"validated_at": self.clock.time_msec()},
updatevalues={"validated_at": self._clock.time_msec()},
)
return next_link
@ -1547,35 +1586,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
Args:
user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""
await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""