Add type hints to synapse/storage/databases/main/events_bg_updates.py (#11654)

This commit is contained in:
Dirk Klimpel 2021-12-30 13:22:31 +01:00 committed by GitHub
parent 2c7f5e74e5
commit 07a3b5daba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 30 deletions

1
changelog.d/11654.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to storage classes.

View file

@ -28,7 +28,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
@ -200,6 +199,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.event_push_actions]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_bg_updates]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
import attr
@ -240,12 +240,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
################################################################################
async def _background_reindex_fields_sender(self, progress, batch_size):
async def _background_reindex_fields_sender(
self, progress: JsonDict, batch_size: int
) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
def reindex_txn(txn):
def reindex_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, json FROM events"
" INNER JOIN event_json USING (event_id)"
@ -307,12 +309,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
async def _background_reindex_origin_server_ts(self, progress, batch_size):
async def _background_reindex_origin_server_ts(
self, progress: JsonDict, batch_size: int
) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
def reindex_search_txn(txn):
def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
@ -381,7 +385,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
async def _cleanup_extremities_bg_update(self, progress, batch_size):
async def _cleanup_extremities_bg_update(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to clean out extremities that should have been
deleted previously.
@ -402,12 +408,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# have any descendants, but if they do then we should delete those
# extremities.
def _cleanup_extremities_bg_update_txn(txn):
def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
# The set of extremity event IDs that we're checking this round
original_set = set()
# A dict[str, set[str]] of event ID to their prev events.
graph = {}
# A dict[str, Set[str]] of event ID to their prev events.
graph: Dict[str, Set[str]] = {}
# The set of descendants of the original set that are not rejected
# nor soft-failed. Ancestors of these events should be removed
@ -536,7 +542,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
room_ids = {row["room_id"] for row in rows}
for room_id in room_ids:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
)
self.db_pool.simple_delete_many_txn(
@ -558,7 +564,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
_BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
def _drop_table_txn(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE _extremities_to_check")
await self.db_pool.runInteraction(
@ -567,11 +573,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return num_handled
async def _redactions_received_ts(self, progress, batch_size):
async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
"""Handles filling out the `received_ts` column in redactions."""
last_event_id = progress.get("last_event_id", "")
def _redactions_received_ts_txn(txn):
def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
# Fetch the set of event IDs that we want to update
sql = """
SELECT event_id FROM redactions
@ -622,10 +628,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return count
async def _event_fix_redactions_bytes(self, progress, batch_size):
async def _event_fix_redactions_bytes(
self, progress: JsonDict, batch_size: int
) -> int:
"""Undoes hex encoded censored redacted event JSON."""
def _event_fix_redactions_bytes_txn(txn):
def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
# This update is quite fast due to new index.
txn.execute(
"""
@ -650,11 +658,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return 1
async def _event_store_labels(self, progress, batch_size):
async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
def _event_store_labels_txn(txn):
def _event_store_labels_txn(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT event_id, json FROM event_json
@ -754,7 +762,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
),
)
return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
return cast(
List[Tuple[str, str, JsonDict, bool, bool]],
[(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
)
results = await self.db_pool.runInteraction(
desc="_rejected_events_metadata_get", func=get_rejected_events
@ -912,7 +923,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
def _calculate_chain_cover_txn(
self,
txn: Cursor,
txn: LoggingTransaction,
last_room_id: str,
last_depth: int,
last_stream: int,
@ -1023,10 +1034,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
self.event_chain_id_gen,
self.event_chain_id_gen, # type: ignore[attr-defined]
event_to_room_id,
event_to_types,
event_to_auth_chain,
cast(Dict[str, Sequence[str]], event_to_auth_chain),
)
return _CalculateChainCover(
@ -1046,7 +1057,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"""
current_event_id = progress.get("current_event_id", "")
def purged_chain_cover_txn(txn) -> int:
def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
# The event ID from events will be null if the chain ID / sequence
# number points to a purged event.
sql = """
@ -1181,14 +1192,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# Iterate the parent IDs and invalidate caches.
for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,)
self._invalidate_cache_and_stream(
txn, self.get_relations_for_event, cache_tuple
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
)
self._invalidate_cache_and_stream(
txn, self.get_aggregation_groups_for_event, cache_tuple
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined]
)
self._invalidate_cache_and_stream(
txn, self.get_thread_summary, cache_tuple
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
)
if results:
@ -1220,7 +1231,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"""
batch_size = max(batch_size, 1)
def process(txn: Cursor) -> int:
def process(txn: LoggingTransaction) -> int:
last_stream = progress.get("last_stream", -(1 << 31))
txn.execute(
"""