Type hint the constructors of the data store classes (#11555)

This commit is contained in:
Sean Quah 2021-12-13 17:05:00 +00:00 committed by GitHub
parent 1abfb15f07
commit 5305a5e881
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 351 additions and 87 deletions

1
changelog.d/11555.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to storage classes.

View file

@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Optional
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@ -27,7 +27,12 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen: Optional[

View file

@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.lrucache import LruCache
@ -25,7 +25,12 @@ if TYPE_CHECKING:
class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.client_ip_last_seen: LruCache[tuple, int] = LruCache(

View file

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -27,7 +27,12 @@ if TYPE_CHECKING:
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.hs = hs

View file

@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
@ -58,7 +58,12 @@ class SlavedEventStore(
RelationsWorkerStore,
BaseSlavedStore,
):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()

View file

@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.filtering import FilteringStore
from ._base import BaseSlavedStore
@ -24,7 +24,12 @@ if TYPE_CHECKING:
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired

View file

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -26,7 +26,12 @@ if TYPE_CHECKING:
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.hs = hs

View file

@ -17,10 +17,8 @@ import logging
from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import get_domain_from_id
from synapse.util import json_decoder
@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine

View file

@ -175,7 +175,7 @@ class LoggingDatabaseConnection:
def rollback(self) -> None:
self.conn.rollback()
def __enter__(self) -> "Connection":
def __enter__(self) -> "LoggingDatabaseConnection":
self.conn.__enter__()
return self

View file

@ -18,7 +18,7 @@ import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
@ -129,7 +129,12 @@ class DataStore(
LockStore,
SessionStore,
):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine

View file

@ -24,9 +24,8 @@ from synapse.appservice import (
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
@ -58,7 +57,12 @@ def _make_exclusive_regex(
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self.services_cache = load_appservices(
hs.hostname, hs.config.appservice.app_service_config_files
)

View file

@ -25,7 +25,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@ -41,7 +41,12 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()

View file

@ -18,7 +18,11 @@ from typing import TYPE_CHECKING, Optional
from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder
@ -31,7 +35,12 @@ logger = logging.getLogger(__name__)
class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
if (

View file

@ -26,7 +26,6 @@ from synapse.storage.database import (
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
from synapse.storage.types import Connection
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache
@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict):
class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@ -394,7 +398,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.server.user_ips_max_age
@ -532,7 +541,12 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
# (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](

View file

@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(

View file

@ -38,6 +38,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
@ -61,7 +62,12 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@ -953,7 +959,12 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@ -1085,7 +1096,12 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies

View file

@ -24,7 +24,11 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
@ -62,7 +66,12 @@ class _NoChainCoverIndex(Exception):
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@ -1514,7 +1523,12 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(

View file

@ -20,7 +20,11 @@ from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@ -82,7 +86,12 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
@ -910,7 +919,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(

View file

@ -41,10 +41,13 @@ from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
@ -95,7 +98,7 @@ class PersistEventsStore:
hs: "HomeServer",
db: DatabasePool,
main_data_store: "DataStore",
db_conn: Connection,
db_conn: LoggingDatabaseConnection,
):
self.hs = hs
self.db_pool = db

View file

@ -23,6 +23,7 @@ from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
@ -83,7 +84,12 @@ class _CalculateChainCover:
class EventsBackgroundUpdatesStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(

View file

@ -19,8 +19,7 @@ from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import JsonDict
from synapse.util import json_encoder
@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict):
class GroupServerWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
database.updates.register_background_index_update(
update_name="local_group_updates_index",
index_name="local_group_updates_stream_id_index",

View file

@ -20,8 +20,11 @@ from twisted.internet.interfaces import IReactorCore
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.types import Connection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.util import Clock
from synapse.util.stringutils import random_string
@ -54,7 +57,12 @@ class LockStore(SQLBaseStore):
`last_renewed_ts` column with the current time.
"""
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._reactor = hs.get_reactor()

View file

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Dict
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics.
"""
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes

View file

@ -16,7 +16,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
make_in_list_sql_clause,
)
from synapse.util.caches.descriptors import cached
from synapse.util.threepids import canonicalise_email
@ -31,7 +35,12 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@ -213,7 +222,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._mau_stats_only = hs.config.server.mau_stats_only

View file

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: Connection,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
def __init__(
self,
database: DatabasePool,
db_conn: Connection,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

View file

@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
@ -81,7 +81,12 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:

View file

@ -32,7 +32,11 @@ from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
@ -47,7 +51,12 @@ logger = logging.getLogger(__name__)
class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):

View file

@ -24,7 +24,11 @@ from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict, ThirdPartyInstanceID
@ -72,7 +76,12 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.config = hs.config
@ -1050,7 +1059,12 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
class RoomBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.config = hs.config
@ -1435,7 +1449,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.config = hs.config

View file

@ -37,7 +37,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
@ -64,7 +64,12 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@ -985,7 +990,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@ -1135,7 +1145,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None:

View file

@ -20,7 +20,11 @@ from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@ -105,7 +109,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
if not hs.config.server.enable_search:
@ -358,7 +367,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys):

View file

@ -22,7 +22,11 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
@ -56,7 +60,12 @@ class _GetStateGroupDelta(
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers."""
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
@ -349,7 +358,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@ -536,5 +550,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

View file

@ -24,7 +24,7 @@ from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@ -96,7 +96,12 @@ class UserSortOrder(Enum):
class StatsStore(StateDeltasStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname

View file

@ -49,6 +49,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
@ -339,7 +340,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
which can be called in the initializer.
"""
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()

View file

@ -22,7 +22,11 @@ from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@ -71,7 +75,12 @@ class DestinationRetryTimings:
class TransactionWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:

View file

@ -32,11 +32,14 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
def __init__(
self,
database: DatabasePool,
db_conn: Connection,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
@ -592,7 +595,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(
self,
database: DatabasePool,
db_conn: Connection,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
) -> None:
super().__init__(database, db_conn, hs)