Add type hints to synapse/storage/databases/main/events_bg_updates.py (#11654)

This commit is contained in:
Dirk Klimpel 2021-12-30 13:22:31 +01:00 committed by GitHub
parent 2c7f5e74e5
commit 07a3b5daba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 30 deletions

1
changelog.d/11654.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to storage classes.

View file

@ -28,7 +28,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py |synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py |synapse/storage/databases/main/monthly_active_users.py
@ -200,6 +199,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.event_push_actions] [mypy-synapse.storage.databases.main.event_push_actions]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_bg_updates]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_worker] [mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
import attr import attr
@ -240,12 +240,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
################################################################################ ################################################################################
async def _background_reindex_fields_sender(self, progress, batch_size): async def _background_reindex_fields_sender(
self, progress: JsonDict, batch_size: int
) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"] max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0) rows_inserted = progress.get("rows_inserted", 0)
def reindex_txn(txn): def reindex_txn(txn: LoggingTransaction) -> int:
sql = ( sql = (
"SELECT stream_ordering, event_id, json FROM events" "SELECT stream_ordering, event_id, json FROM events"
" INNER JOIN event_json USING (event_id)" " INNER JOIN event_json USING (event_id)"
@ -307,12 +309,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result return result
async def _background_reindex_origin_server_ts(self, progress, batch_size): async def _background_reindex_origin_server_ts(
self, progress: JsonDict, batch_size: int
) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"] max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0) rows_inserted = progress.get("rows_inserted", 0)
def reindex_search_txn(txn): def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = ( sql = (
"SELECT stream_ordering, event_id FROM events" "SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?" " WHERE ? <= stream_ordering AND stream_ordering < ?"
@ -381,7 +385,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result return result
async def _cleanup_extremities_bg_update(self, progress, batch_size): async def _cleanup_extremities_bg_update(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to clean out extremities that should have been """Background update to clean out extremities that should have been
deleted previously. deleted previously.
@ -402,12 +408,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# have any descendants, but if they do then we should delete those # have any descendants, but if they do then we should delete those
# extremities. # extremities.
def _cleanup_extremities_bg_update_txn(txn): def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
# The set of extremity event IDs that we're checking this round # The set of extremity event IDs that we're checking this round
original_set = set() original_set = set()
# A dict[str, set[str]] of event ID to their prev events. # A dict[str, Set[str]] of event ID to their prev events.
graph = {} graph: Dict[str, Set[str]] = {}
# The set of descendants of the original set that are not rejected # The set of descendants of the original set that are not rejected
# nor soft-failed. Ancestors of these events should be removed # nor soft-failed. Ancestors of these events should be removed
@ -536,7 +542,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
room_ids = {row["room_id"] for row in rows} room_ids = {row["room_id"] for row in rows}
for room_id in room_ids: for room_id in room_ids:
txn.call_after( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,) self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
) )
self.db_pool.simple_delete_many_txn( self.db_pool.simple_delete_many_txn(
@ -558,7 +564,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
_BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
) )
def _drop_table_txn(txn): def _drop_table_txn(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE _extremities_to_check") txn.execute("DROP TABLE _extremities_to_check")
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@ -567,11 +573,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return num_handled return num_handled
async def _redactions_received_ts(self, progress, batch_size): async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
"""Handles filling out the `received_ts` column in redactions.""" """Handles filling out the `received_ts` column in redactions."""
last_event_id = progress.get("last_event_id", "") last_event_id = progress.get("last_event_id", "")
def _redactions_received_ts_txn(txn): def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
# Fetch the set of event IDs that we want to update # Fetch the set of event IDs that we want to update
sql = """ sql = """
SELECT event_id FROM redactions SELECT event_id FROM redactions
@ -622,10 +628,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return count return count
async def _event_fix_redactions_bytes(self, progress, batch_size): async def _event_fix_redactions_bytes(
self, progress: JsonDict, batch_size: int
) -> int:
"""Undoes hex encoded censored redacted event JSON.""" """Undoes hex encoded censored redacted event JSON."""
def _event_fix_redactions_bytes_txn(txn): def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
# This update is quite fast due to new index. # This update is quite fast due to new index.
txn.execute( txn.execute(
""" """
@ -650,11 +658,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return 1 return 1
async def _event_store_labels(self, progress, batch_size): async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
"""Background update handler which will store labels for existing events.""" """Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "") last_event_id = progress.get("last_event_id", "")
def _event_store_labels_txn(txn): def _event_store_labels_txn(txn: LoggingTransaction) -> int:
txn.execute( txn.execute(
""" """
SELECT event_id, json FROM event_json SELECT event_id, json FROM event_json
@ -754,7 +762,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
), ),
) )
return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore return cast(
List[Tuple[str, str, JsonDict, bool, bool]],
[(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
)
results = await self.db_pool.runInteraction( results = await self.db_pool.runInteraction(
desc="_rejected_events_metadata_get", func=get_rejected_events desc="_rejected_events_metadata_get", func=get_rejected_events
@ -912,7 +923,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
def _calculate_chain_cover_txn( def _calculate_chain_cover_txn(
self, self,
txn: Cursor, txn: LoggingTransaction,
last_room_id: str, last_room_id: str,
last_depth: int, last_depth: int,
last_stream: int, last_stream: int,
@ -1023,10 +1034,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
PersistEventsStore._add_chain_cover_index( PersistEventsStore._add_chain_cover_index(
txn, txn,
self.db_pool, self.db_pool,
self.event_chain_id_gen, self.event_chain_id_gen, # type: ignore[attr-defined]
event_to_room_id, event_to_room_id,
event_to_types, event_to_types,
event_to_auth_chain, cast(Dict[str, Sequence[str]], event_to_auth_chain),
) )
return _CalculateChainCover( return _CalculateChainCover(
@ -1046,7 +1057,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
""" """
current_event_id = progress.get("current_event_id", "") current_event_id = progress.get("current_event_id", "")
def purged_chain_cover_txn(txn) -> int: def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
# The event ID from events will be null if the chain ID / sequence # The event ID from events will be null if the chain ID / sequence
# number points to a purged event. # number points to a purged event.
sql = """ sql = """
@ -1181,14 +1192,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# Iterate the parent IDs and invalidate caches. # Iterate the parent IDs and invalidate caches.
for parent_id in {r[1] for r in relations_to_insert}: for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,) cache_tuple = (parent_id,)
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_relations_for_event, cache_tuple txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_aggregation_groups_for_event, cache_tuple txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined]
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_thread_summary, cache_tuple txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
) )
if results: if results:
@ -1220,7 +1231,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
""" """
batch_size = max(batch_size, 1) batch_size = max(batch_size, 1)
def process(txn: Cursor) -> int: def process(txn: LoggingTransaction) -> int:
last_stream = progress.get("last_stream", -(1 << 31)) last_stream = progress.get("last_stream", -(1 << 31))
txn.execute( txn.execute(
""" """