Improve typing in user_directory files (#10891)

* Improve typing in user_directory files

This makes the storage layer's user_directory.py pass most of mypy's
checks (including `no-untyped-defs`). Unfortunately, that file is part of the
tangled web of Store class inheritance, so it doesn't fully pass mypy at the moment.

The handlers directory is already type-checked by mypy.

Co-authored-by: reivilibre <olivier@librepush.net>
This commit is contained in:
David Robertson 2021-09-24 10:38:22 +01:00 committed by GitHub
parent e704cc2a48
commit 7f3352743e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 37 deletions

1
changelog.d/10891.misc Normal file
View file

@ -0,0 +1 @@
Improve type hinting in the user directory code.

View file

@ -85,9 +85,11 @@ files =
tests/handlers/test_room_summary.py, tests/handlers/test_room_summary.py,
tests/handlers/test_send_email.py, tests/handlers/test_send_email.py,
tests/handlers/test_sync.py, tests/handlers/test_sync.py,
tests/handlers/test_user_directory.py,
tests/rest/client/test_login.py, tests/rest/client/test_login.py,
tests/rest/client/test_auth.py, tests/rest/client/test_auth.py,
tests/storage/test_state.py, tests/storage/test_state.py,
tests/storage/test_user_directory.py,
tests/util/test_itertools.py, tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py tests/util/test_stream_change_cache.py

View file

@ -14,14 +14,28 @@
import logging import logging
import re import re
from typing import Any, Dict, Iterable, Optional, Set, Tuple from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
cast,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.storage.types import Connection
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# add_users_who_share_private_rooms? # add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500 SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(
self,
database: DatabasePool,
db_conn: Connection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name = hs.hostname
@ -57,10 +76,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"populate_user_directory_cleanup", self._populate_user_directory_cleanup "populate_user_directory_cleanup", self._populate_user_directory_cleanup
) )
async def _populate_user_directory_createtables(self, progress, batch_size): async def _populate_user_directory_createtables(
self, progress: JsonDict, batch_size: int
) -> int:
# Get all the rooms that we want to process. # Get all the rooms that we want to process.
def _make_staging_area(txn): def _make_staging_area(txn: LoggingTransaction) -> None:
sql = ( sql = (
"CREATE TABLE IF NOT EXISTS " "CREATE TABLE IF NOT EXISTS "
+ TEMP_TABLE + TEMP_TABLE
@ -110,16 +131,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) )
return 1 return 1
async def _populate_user_directory_cleanup(self, progress, batch_size): async def _populate_user_directory_cleanup(
self,
progress: JsonDict,
batch_size: int,
) -> int:
""" """
Update the user directory stream position, then clean up the old tables. Update the user directory stream position, then clean up the old tables.
""" """
position = await self.db_pool.simple_select_one_onecol( position = await self.db_pool.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position" TEMP_TABLE + "_position", {}, "position"
) )
await self.update_user_directory_stream_pos(position) await self.update_user_directory_stream_pos(position)
def _delete_staging_area(txn): def _delete_staging_area(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
@ -133,18 +158,32 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) )
return 1 return 1
async def _populate_user_directory_process_rooms(self, progress, batch_size): async def _populate_user_directory_process_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
""" """
Rescan the state of all rooms so we can track
- who's in a public room;
- which local users share a private room with other users (local
and remote); and
- who should be in the user_directory.
Args: Args:
progress (dict) progress (dict)
batch_size (int): Maximum number of state events to process batch_size (int): Maximum number of state events to process
per cycle. per cycle.
Returns:
number of events processed.
""" """
# If we don't have progress filed, delete everything. # If we don't have progress filed, delete everything.
if not progress: if not progress:
await self.delete_all_from_user_dir() await self.delete_all_from_user_dir()
def _get_next_batch(txn): def _get_next_batch(
txn: LoggingTransaction,
) -> Optional[Sequence[Tuple[str, int]]]:
# Only fetch 250 rooms, so we don't fetch too many at once, even # Only fetch 250 rooms, so we don't fetch too many at once, even
# if those 250 rooms have less than batch_size state events. # if those 250 rooms have less than batch_size state events.
sql = """ sql = """
@ -155,7 +194,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
TEMP_TABLE + "_rooms", TEMP_TABLE + "_rooms",
) )
txn.execute(sql) txn.execute(sql)
rooms_to_work_on = txn.fetchall() rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())
if not rooms_to_work_on: if not rooms_to_work_on:
return None return None
@ -163,7 +202,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# Get how many are left to process, so we can give status on how # Get how many are left to process, so we can give status on how
# far we are in processing # far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
progress["remaining"] = txn.fetchone()[0] result = txn.fetchone()
assert result is not None
progress["remaining"] = result[0]
return rooms_to_work_on return rooms_to_work_on
@ -261,29 +302,33 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return processed_event_count return processed_event_count
async def _populate_user_directory_process_users(self, progress, batch_size): async def _populate_user_directory_process_users(
self, progress: JsonDict, batch_size: int
) -> int:
""" """
Add all local users to the user directory. Add all local users to the user directory.
""" """
def _get_next_batch(txn): def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
sql = "SELECT user_id FROM %s LIMIT %s" % ( sql = "SELECT user_id FROM %s LIMIT %s" % (
TEMP_TABLE + "_users", TEMP_TABLE + "_users",
str(batch_size), str(batch_size),
) )
txn.execute(sql) txn.execute(sql)
users_to_work_on = txn.fetchall() user_result = cast(List[Tuple[str]], txn.fetchall())
if not users_to_work_on: if not user_result:
return None return None
users_to_work_on = [x[0] for x in users_to_work_on] users_to_work_on = [x[0] for x in user_result]
# Get how many are left to process, so we can give status on how # Get how many are left to process, so we can give status on how
# far we are in processing # far we are in processing
sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users" sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
txn.execute(sql) txn.execute(sql)
progress["remaining"] = txn.fetchone()[0] count_result = txn.fetchone()
assert count_result is not None
progress["remaining"] = count_result[0]
return users_to_work_on return users_to_work_on
@ -324,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on) return len(users_to_work_on)
async def is_room_world_readable_or_publicly_joinable(self, room_id): async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable""" """Check if the room is either world_readable or publically joinable"""
# Create a state filter that only queries join and history state event # Create a state filter that only queries join and history state event
@ -368,7 +413,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if not isinstance(avatar_url, str): if not isinstance(avatar_url, str):
avatar_url = None avatar_url = None
def _update_profile_in_user_dir_txn(txn): def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="user_directory", table="user_directory",
@ -435,7 +480,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id, other_user_id in user_id_tuples for user_id, other_user_id in user_id_tuples
], ],
value_names=(), value_names=(),
value_values=None, value_values=(),
desc="add_users_who_share_room", desc="add_users_who_share_room",
) )
@ -454,14 +499,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
key_names=["user_id", "room_id"], key_names=["user_id", "room_id"],
key_values=[(user_id, room_id) for user_id in user_ids], key_values=[(user_id, room_id) for user_id in user_ids],
value_names=(), value_names=(),
value_values=None, value_values=(),
desc="add_users_in_public_rooms", desc="add_users_in_public_rooms",
) )
async def delete_all_from_user_dir(self) -> None: async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory""" """Delete the entire user directory"""
def _delete_all_from_user_dir_txn(txn): def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM user_directory") txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search") txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms") txn.execute("DELETE FROM users_in_public_rooms")
@ -473,7 +518,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) )
@cached() @cached()
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
return await self.db_pool.simple_select_one( return await self.db_pool.simple_select_one(
table="user_directory", table="user_directory",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
@ -497,7 +542,12 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# add_users_who_share_private_rooms? # add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500 SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(
self,
database: DatabasePool,
db_conn: Connection,
hs: "HomeServer",
) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._prefer_local_users_in_search = ( self._prefer_local_users_in_search = (
@ -506,7 +556,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
self._server_name = hs.config.server.server_name self._server_name = hs.config.server.server_name
async def remove_from_user_dir(self, user_id: str) -> None: async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn): def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id} txn, table="user_directory", keyvalues={"user_id": user_id}
) )
@ -532,7 +582,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_from_user_dir", _remove_from_user_dir_txn "remove_from_user_dir", _remove_from_user_dir_txn
) )
async def get_users_in_dir_due_to_room(self, room_id): async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
"""Get all user_ids that are in the room directory because they're """Get all user_ids that are in the room directory because they're
in the given room_id in the given room_id
""" """
@ -565,7 +615,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
room_id room_id
""" """
def _remove_user_who_share_room_txn(txn): def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="users_who_share_private_rooms", table="users_who_share_private_rooms",
@ -586,7 +636,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_user_who_share_room", _remove_user_who_share_room_txn "remove_user_who_share_room", _remove_user_who_share_room_txn
) )
async def get_user_dir_rooms_user_is_in(self, user_id): async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
""" """
Returns the rooms that a user is in. Returns the rooms that a user is in.
@ -628,7 +678,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room ID's that the users share. A set of room ID's that the users share.
""" """
def _get_shared_rooms_for_users_txn(txn): def _get_shared_rooms_for_users_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
txn.execute( txn.execute(
""" """
SELECT p1.room_id SELECT p1.room_id
@ -669,7 +721,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
desc="get_user_directory_stream_pos", desc="get_user_directory_stream_pos",
) )
async def search_user_dir(self, user_id, search_term, limit): async def search_user_dir(
self, user_id: str, search_term: str, limit: int
) -> JsonDict:
"""Searches for users in directory """Searches for users in directory
Returns: Returns:
@ -705,7 +759,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# We allow manipulating the ranking algorithm by injecting statements # We allow manipulating the ranking algorithm by injecting statements
# based on config options. # based on config options.
additional_ordering_statements = [] additional_ordering_statements = []
ordering_arguments = () ordering_arguments: Tuple[str, ...] = ()
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term) full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@ -811,7 +865,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {"limited": limited, "results": results} return {"limited": limited, "results": results}
def _parse_query_sqlite(search_term): def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form """Takes a plain unicode string from the user and converts it into a form
that can be passed to database. that can be passed to database.
We use this so that we can add prefix matching, which isn't something We use this so that we can add prefix matching, which isn't something
@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term):
return " & ".join("(%s* OR %s)" % (result, result) for result in results) return " & ".join("(%s* OR %s)" % (result, result) for result in results)
def _parse_query_postgres(search_term): def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
"""Takes a plain unicode string from the user and converts it into a form """Takes a plain unicode string from the user and converts it into a form
that can be passed to database. that can be passed to database.
We use this so that we can add prefix matching, which isn't something We use this so that we can add prefix matching, which isn't something

View file

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Tuple
from unittest.mock import Mock from unittest.mock import Mock
from urllib.parse import quote from urllib.parse import quote
@ -325,7 +326,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
r.add((i["user_id"], i["other_user_id"], i["room_id"])) r.add((i["user_id"], i["other_user_id"], i["room_id"]))
return r return r
def get_users_in_public_rooms(self): def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
r = self.get_success( r = self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id") "users_in_public_rooms", None, ("user_id", "room_id")
@ -336,7 +337,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
retval.append((i["user_id"], i["room_id"])) retval.append((i["user_id"], i["room_id"]))
return retval return retval
def get_users_who_share_private_rooms(self): def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
return self.get_success( return self.get_success(
self.store.db_pool.simple_select_list( self.store.db_pool.simple_select_list(
"users_who_share_private_rooms", "users_who_share_private_rooms",