Rename database classes to make some sense (#8033)

This commit is contained in:
Erik Johnston 2020-08-05 21:38:57 +01:00 committed by GitHub
parent 0a86850ba3
commit a7bdf98d01
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
337 changed files with 1408 additions and 1323 deletions

1
changelog.d/8033.misc Normal file
View file

@ -0,0 +1 @@
Rename storage layer objects to be more sensible.

View file

@ -7,6 +7,6 @@ who are present in a publicly viewable room present on the server.
The directory info is stored in various tables, which can (typically after The directory info is stored in various tables, which can (typically after
DB corruption) get stale or out of sync. If this happens, for now the DB corruption) get stale or out of sync. If this happens, for now the
solution to fix it is to execute the SQL [here](../synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql) solution to fix it is to execute the SQL [here](../synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql)
and then restart synapse. This should then start a background task to and then restart synapse. This should then start a background task to
flush the current tables and regenerate the directory. flush the current tables and regenerate the directory.

View file

@ -40,7 +40,7 @@ class MockHomeserver(HomeServer):
config.server_name, reactor=reactor, config=config, **kwargs config.server_name, reactor=reactor, config=config, **kwargs
) )
self.version_string = "Synapse/"+get_version_string(synapse) self.version_string = "Synapse/" + get_version_string(synapse)
if __name__ == "__main__": if __name__ == "__main__":
@ -86,7 +86,7 @@ if __name__ == "__main__":
store = hs.get_datastore() store = hs.get_datastore()
async def run_background_updates(): async def run_background_updates():
await store.db.updates.run_background_updates(sleep=False) await store.db_pool.updates.run_background_updates(sleep=False)
# Stop the reactor to exit the script once every background update is run. # Stop the reactor to exit the script once every background update is run.
reactor.stop() reactor.stop()

View file

@ -35,31 +35,29 @@ from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
run_in_background, run_in_background,
) )
from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.data_stores.main.deviceinbox import ( from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
DeviceInboxBackgroundUpdateStore, from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
) from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore from synapse.storage.databases.main.events_bg_updates import (
from synapse.storage.data_stores.main.events_bg_updates import (
EventsBackgroundUpdatesStore, EventsBackgroundUpdatesStore,
) )
from synapse.storage.data_stores.main.media_repository import ( from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore, MediaRepositoryBackgroundUpdateStore,
) )
from synapse.storage.data_stores.main.registration import ( from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore, RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart, find_max_generated_user_id_localpart,
) )
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore
from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.data_stores.main.user_directory import ( from synapse.storage.databases.main.user_directory import (
UserDirectoryBackgroundUpdateStore, UserDirectoryBackgroundUpdateStore,
) )
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database, make_conn
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
from synapse.util import Clock from synapse.util import Clock
@ -175,14 +173,14 @@ class Store(
StatsStore, StatsStore,
): ):
def execute(self, f, *args, **kwargs): def execute(self, f, *args, **kwargs):
return self.db.runInteraction(f.__name__, f, *args, **kwargs) return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql, *args): def execute_sql(self, sql, *args):
def r(txn): def r(txn):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
return self.db.runInteraction("execute_sql", r) return self.db_pool.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows): def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % ( sql = "INSERT INTO %s (%s) VALUES (%s)" % (
@ -227,7 +225,7 @@ class Porter(object):
async def setup_table(self, table): async def setup_table(self, table):
if table in APPEND_ONLY_TABLES: if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting. # It's safe to just carry on inserting.
row = await self.postgres_store.db.simple_select_one( row = await self.postgres_store.db_pool.simple_select_one(
table="port_from_sqlite3", table="port_from_sqlite3",
keyvalues={"table_name": table}, keyvalues={"table_name": table},
retcols=("forward_rowid", "backward_rowid"), retcols=("forward_rowid", "backward_rowid"),
@ -244,7 +242,7 @@ class Porter(object):
) = await self._setup_sent_transactions() ) = await self._setup_sent_transactions()
backward_chunk = 0 backward_chunk = 0
else: else:
await self.postgres_store.db.simple_insert( await self.postgres_store.db_pool.simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={ values={
"table_name": table, "table_name": table,
@ -274,7 +272,7 @@ class Porter(object):
await self.postgres_store.execute(delete_all) await self.postgres_store.execute(delete_all)
await self.postgres_store.db.simple_insert( await self.postgres_store.db_pool.simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
) )
@ -318,7 +316,7 @@ class Porter(object):
if table == "user_directory_stream_pos": if table == "user_directory_stream_pos":
# We need to make sure there is a single row, `(X, null), as that is # We need to make sure there is a single row, `(X, null), as that is
# what synapse expects to be there. # what synapse expects to be there.
await self.postgres_store.db.simple_insert( await self.postgres_store.db_pool.simple_insert(
table=table, values={"stream_id": None} table=table, values={"stream_id": None}
) )
self.progress.update(table, table_size) # Mark table as done self.progress.update(table, table_size) # Mark table as done
@ -359,7 +357,7 @@ class Porter(object):
return headers, forward_rows, backward_rows return headers, forward_rows, backward_rows
headers, frows, brows = await self.sqlite_store.db.runInteraction( headers, frows, brows = await self.sqlite_store.db_pool.runInteraction(
"select", r "select", r
) )
@ -375,7 +373,7 @@ class Porter(object):
def insert(txn): def insert(txn):
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
self.postgres_store.db.simple_update_one_txn( self.postgres_store.db_pool.simple_update_one_txn(
txn, txn,
table="port_from_sqlite3", table="port_from_sqlite3",
keyvalues={"table_name": table}, keyvalues={"table_name": table},
@ -413,7 +411,7 @@ class Porter(object):
return headers, rows return headers, rows
headers, rows = await self.sqlite_store.db.runInteraction("select", r) headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r)
if rows: if rows:
forward_chunk = rows[-1][0] + 1 forward_chunk = rows[-1][0] + 1
@ -451,7 +449,7 @@ class Porter(object):
], ],
) )
self.postgres_store.db.simple_update_one_txn( self.postgres_store.db_pool.simple_update_one_txn(
txn, txn,
table="port_from_sqlite3", table="port_from_sqlite3",
keyvalues={"table_name": "event_search"}, keyvalues={"table_name": "event_search"},
@ -494,7 +492,7 @@ class Porter(object):
db_conn, allow_outdated_version=allow_outdated_version db_conn, allow_outdated_version=allow_outdated_version
) )
prepare_database(db_conn, engine, config=self.hs_config) prepare_database(db_conn, engine, config=self.hs_config)
store = Store(Database(hs, db_config, engine), db_conn, hs) store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
db_conn.commit() db_conn.commit()
return store return store
@ -502,7 +500,7 @@ class Porter(object):
async def run_background_updates_on_postgres(self): async def run_background_updates_on_postgres(self):
# Manually apply all background updates on the PostgreSQL database. # Manually apply all background updates on the PostgreSQL database.
postgres_ready = ( postgres_ready = (
await self.postgres_store.db.updates.has_completed_background_updates() await self.postgres_store.db_pool.updates.has_completed_background_updates()
) )
if not postgres_ready: if not postgres_ready:
@ -511,9 +509,9 @@ class Porter(object):
self.progress.set_state("Running background updates on PostgreSQL") self.progress.set_state("Running background updates on PostgreSQL")
while not postgres_ready: while not postgres_ready:
await self.postgres_store.db.updates.do_next_background_update(100) await self.postgres_store.db_pool.updates.do_next_background_update(100)
postgres_ready = await ( postgres_ready = await (
self.postgres_store.db.updates.has_completed_background_updates() self.postgres_store.db_pool.updates.has_completed_background_updates()
) )
async def run(self): async def run(self):
@ -534,7 +532,7 @@ class Porter(object):
# Check if all background updates are done, abort if not. # Check if all background updates are done, abort if not.
updates_complete = ( updates_complete = (
await self.sqlite_store.db.updates.has_completed_background_updates() await self.sqlite_store.db_pool.updates.has_completed_background_updates()
) )
if not updates_complete: if not updates_complete:
end_error = ( end_error = (
@ -576,22 +574,24 @@ class Porter(object):
) )
try: try:
await self.postgres_store.db.runInteraction("alter_table", alter_table) await self.postgres_store.db_pool.runInteraction(
"alter_table", alter_table
)
except Exception: except Exception:
# On Error Resume Next # On Error Resume Next
pass pass
await self.postgres_store.db.runInteraction( await self.postgres_store.db_pool.runInteraction(
"create_port_table", create_port_table "create_port_table", create_port_table
) )
# Step 2. Get tables. # Step 2. Get tables.
self.progress.set_state("Fetching tables") self.progress.set_state("Fetching tables")
sqlite_tables = await self.sqlite_store.db.simple_select_onecol( sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol(
table="sqlite_master", keyvalues={"type": "table"}, retcol="name" table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
) )
postgres_tables = await self.postgres_store.db.simple_select_onecol( postgres_tables = await self.postgres_store.db_pool.simple_select_onecol(
table="information_schema.tables", table="information_schema.tables",
keyvalues={}, keyvalues={},
retcol="distinct table_name", retcol="distinct table_name",
@ -692,7 +692,7 @@ class Porter(object):
return headers, [r for r in rows if r[ts_ind] < yesterday] return headers, [r for r in rows if r[ts_ind] < yesterday]
headers, rows = await self.sqlite_store.db.runInteraction("select", r) headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r)
rows = self._convert_rows("sent_transactions", headers, rows) rows = self._convert_rows("sent_transactions", headers, rows)
@ -725,7 +725,7 @@ class Porter(object):
next_chunk = await self.sqlite_store.execute(get_start_id) next_chunk = await self.sqlite_store.execute(get_start_id)
next_chunk = max(max_inserted_rowid + 1, next_chunk) next_chunk = max(max_inserted_rowid + 1, next_chunk)
await self.postgres_store.db.simple_insert( await self.postgres_store.db_pool.simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={ values={
"table_name": "sent_transactions", "table_name": "sent_transactions",
@ -794,14 +794,14 @@ class Porter(object):
next_id = curr_id + 1 next_id = curr_id + 1
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
def _setup_user_id_seq(self): def _setup_user_id_seq(self):
def r(txn): def r(txn):
next_id = find_max_generated_user_id_localpart(txn) + 1 next_id = find_max_generated_user_id_localpart(txn) + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.db.runInteraction("setup_user_id_seq", r) return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
############################################## ##############################################

View file

@ -268,7 +268,7 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# It is now safe to start your Synapse. # It is now safe to start your Synapse.
hs.start_listening(listeners) hs.start_listening(listeners)
hs.get_datastore().db.start_profiling() hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start() hs.get_pusherpool().start()
setup_sentry(hs) setup_sentry(hs)

View file

@ -125,15 +125,15 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.data_stores.main.censor_events import CensorEventsStore from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore
from synapse.storage.data_stores.main.monthly_active_users import ( from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore, MonthlyActiveUsersWorkerStore,
) )
from synapse.storage.data_stores.main.presence import UserPresenceState from synapse.storage.databases.main.presence import UserPresenceState
from synapse.storage.data_stores.main.search import SearchWorkerStore from synapse.storage.databases.main.search import SearchWorkerStore
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree

View file

@ -441,7 +441,7 @@ def setup(config_options):
_base.start(hs, config.listeners) _base.start(hs, config.listeners)
hs.get_datastore().db.updates.start_doing_background_updates() hs.get_datastore().db_pool.updates.start_doing_background_updates()
except Exception: except Exception:
# Print the exception and bail out. # Print the exception and bail out.
print("Error during startup:", file=sys.stderr) print("Error during startup:", file=sys.stderr)
@ -551,8 +551,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
# #
# This only reports info about the *main* database. # This only reports info about the *main* database.
stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
stats["database_server_version"] = hs.get_datastore().db.engine.server_version stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try: try:

View file

@ -100,7 +100,10 @@ class DatabaseConnectionConfig:
self.name = name self.name = name
self.config = db_config self.config = db_config
self.data_stores = data_stores
# The `data_stores` config is actually talking about `databases` (we
# changed the name).
self.databases = data_stores
class DatabaseConfig(Config): class DatabaseConfig(Config):

View file

@ -23,7 +23,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap from synapse.types import StateMap
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.storage.data_stores.main import DataStore from synapse.storage.databases.main import DataStore
@attr.s(slots=True) @attr.s(slots=True)

View file

@ -71,7 +71,7 @@ from synapse.replication.http.federation import (
) )
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room from synapse.util.distributor import user_joined_room

View file

@ -45,7 +45,7 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
Collection, Collection,

View file

@ -38,7 +38,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler from synapse.state import StateHandler
from synapse.storage.data_stores.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -319,7 +319,7 @@ class PresenceHandler(BasePresenceHandler):
is some spurious presence changes that will self-correct. is some spurious presence changes that will self-correct.
""" """
# If the DB pool has already terminated, don't try updating # If the DB pool has already terminated, don't try updating
if not self.store.db.is_running(): if not self.store.db_pool.is_running():
return return
logger.info( logger.info(

View file

@ -219,7 +219,7 @@ class ModuleApi(object):
Returns: Returns:
Deferred[object]: result of func Deferred[object]: result of func
""" """
return self._store.db.runInteraction(desc, func, *args, **kwargs) return self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
def complete_sso_login( def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str

View file

@ -16,8 +16,8 @@
import logging import logging
from typing import Optional from typing import Optional
from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(CacheInvalidationWorkerStore): class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs) super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = MultiWriterIdGenerator( self._cache_id_gen = MultiWriterIdGenerator(

View file

@ -17,13 +17,13 @@
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.data_stores.main.tags import TagsWorkerStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.database import Database from synapse.storage.databases.main.tags import TagsWorkerStore
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker( self._account_data_id_gen = SlavedIdTracker(
db_conn, db_conn,
"account_data", "account_data",

View file

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.appservice import ( from synapse.storage.databases.main.appservice import (
ApplicationServiceTransactionWorkerStore, ApplicationServiceTransactionWorkerStore,
ApplicationServiceWorkerStore, ApplicationServiceWorkerStore,
) )

View file

@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
class SlavedClientIpStore(BaseSlavedStore): class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedClientIpStore, self).__init__(database, db_conn, hs) super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache( self.client_ip_last_seen = Cache(

View file

@ -16,14 +16,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ToDeviceStream from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker( self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_inbox", "stream_id" db_conn, "device_inbox", "stream_id"

View file

@ -16,14 +16,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.data_stores.main.devices import DeviceWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.database import Database from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedDeviceStore, self).__init__(database, db_conn, hs) super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
self.hs = hs self.hs = hs

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.directory import DirectoryWorkerStore from synapse.storage.databases.main.directory import DirectoryWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore

View file

@ -15,18 +15,18 @@
# limitations under the License. # limitations under the License.
import logging import logging
from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.data_stores.main.event_push_actions import ( from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore, EventPushActionsWorkerStore,
) )
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.relations import RelationsWorkerStore from synapse.storage.databases.main.relations import RelationsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.data_stores.main.stream import StreamWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
@ -55,11 +55,11 @@ class SlavedEventStore(
RelationsWorkerStore, RelationsWorkerStore,
BaseSlavedStore, BaseSlavedStore,
): ):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedEventStore, self).__init__(database, db_conn, hs) super(SlavedEventStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token() events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"current_state_delta_stream", "current_state_delta_stream",
entity_column="room_id", entity_column="room_id",

View file

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.filtering import FilteringStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.filtering import FilteringStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
class SlavedFilteringStore(BaseSlavedStore): class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedFilteringStore, self).__init__(database, db_conn, hs) super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired # Filters are immutable so this cache doesn't need to be expired

View file

@ -16,13 +16,13 @@
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream from synapse.replication.tcp.streams import GroupServerStream
from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
self.hs = hs self.hs = hs

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.keys import KeyStore from synapse.storage.databases.main.keys import KeyStore
# KeyStore isn't really safe to use from a worker, but for now we do so and hope that # KeyStore isn't really safe to use from a worker, but for now we do so and hope that
# the races it creates aren't too bad. # the races it creates aren't too bad.

View file

@ -15,8 +15,8 @@
from synapse.replication.tcp.streams import PresenceStream from synapse.replication.tcp.streams import PresenceStream
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.data_stores.main.presence import PresenceStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.presence import PresenceStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore): class SlavedPresenceStore(BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedPresenceStore, self).__init__(database, db_conn, hs) super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.storage.data_stores.main.profile import ProfileWorkerStore from synapse.storage.databases.main.profile import ProfileWorkerStore
class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):

View file

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
from synapse.replication.tcp.streams import PushRulesStream from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from .events import SlavedEventStore from .events import SlavedEventStore

View file

@ -15,15 +15,15 @@
# limitations under the License. # limitations under the License.
from synapse.replication.tcp.streams import PushersStream from synapse.replication.tcp.streams import PushersStream
from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.pusher import PusherWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedPusherStore, self).__init__(database, db_conn, hs) super(SlavedPusherStore, self).__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker( self._pushers_id_gen = SlavedIdTracker(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]

View file

@ -15,15 +15,15 @@
# limitations under the License. # limitations under the License.
from synapse.replication.tcp.streams import ReceiptsStream from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor # We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id # needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker( self._receipts_id_gen = SlavedIdTracker(

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.registration import RegistrationWorkerStore from synapse.storage.databases.main.registration import RegistrationWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore

View file

@ -14,15 +14,15 @@
# limitations under the License. # limitations under the License.
from synapse.replication.tcp.streams import PublicRoomsStream from synapse.replication.tcp.streams import PublicRoomsStream
from synapse.storage.data_stores.main.room import RoomWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.room import RoomWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(RoomWorkerStore, BaseSlavedStore): class RoomStore(RoomWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomStore, self).__init__(database, db_conn, hs) super(RoomStore, self).__init__(database, db_conn, hs)
self._public_room_id_gen = SlavedIdTracker( self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id" db_conn, "public_room_list_stream", "stream_id"

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.data_stores.main.transactions import TransactionStore from synapse.storage.databases.main.transactions import TransactionStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore

View file

@ -31,7 +31,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin, assert_user_is_admin,
historical_admin_path_patterns, historical_admin_path_patterns,
) )
from synapse.storage.data_stores.main.room import RoomSortOrder from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.types import RoomAlias, RoomID, UserID, create_requester
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -586,7 +586,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("Running url preview cache expiry") logger.debug("Running url preview cache expiry")
if not (await self.store.db.updates.has_completed_background_updates()): if not (await self.store.db_pool.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry") logger.info("Still running DB updates; skipping expiry")
return return

View file

@ -105,7 +105,7 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender, WorkerServerNoticesSender,
) )
from synapse.state import StateHandler, StateResolutionHandler from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore, DataStores, Storage from synapse.storage import Databases, DataStore, Storage
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.util import Clock from synapse.util import Clock
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
@ -280,7 +280,7 @@ class HomeServer(object):
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.start_time = int(self.get_clock().time()) self.start_time = int(self.get_clock().time())
self.datastores = DataStores(self.DATASTORE_CLASS, self) self.datastores = Databases(self.DATASTORE_CLASS, self)
logger.info("Finished setting up.") logger.info("Finished setting up.")
def setup_master(self): def setup_master(self):

View file

@ -28,7 +28,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.state import v1, v2 from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap from synapse.types import StateMap
from synapse.util import Clock from synapse.util import Clock

View file

@ -17,18 +17,19 @@
""" """
The storage layer is split up into multiple parts to allow Synapse to run The storage layer is split up into multiple parts to allow Synapse to run
against different configurations of databases (e.g. single or multiple against different configurations of databases (e.g. single or multiple
databases). The `Database` class represents a single physical database. The databases). The `DatabasePool` class represents connections to a single physical
`data_stores` are classes that talk directly to a `Database` instance and have database. The `databases` are classes that talk directly to a `DatabasePool`
associated schemas, background updates, etc. On top of those there are classes instance and have associated schemas, background updates, etc. On top of those
that provide high level interfaces that combine calls to multiple `data_stores`. there are classes that provide high level interfaces that combine calls to
multiple `databases`.
There are also schemas that get applied to every database, regardless of the There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`. stored in `synapse.storage.schema`.
""" """
from synapse.storage.data_stores import DataStores from synapse.storage.databases import Databases
from synapse.storage.data_stores.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.storage.persist_events import EventsPersistenceStorage from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage from synapse.storage.state import StateGroupStorage
@ -40,7 +41,7 @@ class Storage(object):
"""The high level interfaces for talking to various storage layers. """The high level interfaces for talking to various storage layers.
""" """
def __init__(self, hs, stores: DataStores): def __init__(self, hs, stores: Databases):
# We include the main data store here mainly so that we don't have to # We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level # rewrite all the existing code to split it into high vs low level
# interfaces. # interfaces.

View file

@ -23,7 +23,7 @@ from canonicaljson import json
from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id from synapse.types import Collection, get_domain_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,11 +37,11 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database). per data store (and not one per physical database).
""" """
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = database.engine self.database_engine = database.engine
self.db = database self.db_pool = database
self.rand = random.SystemRandom() self.rand = random.SystemRandom()
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):

View file

@ -88,7 +88,7 @@ class BackgroundUpdater(object):
def __init__(self, hs, database): def __init__(self, hs, database):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.db = database self.db_pool = database
# if a background update is currently running, its name. # if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str] self._current_background_update = None # type: Optional[str]
@ -139,7 +139,7 @@ class BackgroundUpdater(object):
# otherwise, check if there are updates to be run. This is important, # otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates # as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen. # itself, but still wants to wait for them to happen.
updates = await self.db.simple_select_onecol( updates = await self.db_pool.simple_select_onecol(
"background_updates", "background_updates",
keyvalues=None, keyvalues=None,
retcol="1", retcol="1",
@ -160,7 +160,7 @@ class BackgroundUpdater(object):
if update_name == self._current_background_update: if update_name == self._current_background_update:
return False return False
update_exists = await self.db.simple_select_one_onecol( update_exists = await self.db_pool.simple_select_one_onecol(
"background_updates", "background_updates",
keyvalues={"update_name": update_name}, keyvalues={"update_name": update_name},
retcol="1", retcol="1",
@ -189,10 +189,10 @@ class BackgroundUpdater(object):
ORDER BY ordering, update_name ORDER BY ordering, update_name
""" """
) )
return self.db.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
if not self._current_background_update: if not self._current_background_update:
all_pending_updates = await self.db.runInteraction( all_pending_updates = await self.db_pool.runInteraction(
"background_updates", get_background_updates_txn, "background_updates", get_background_updates_txn,
) )
if not all_pending_updates: if not all_pending_updates:
@ -243,7 +243,7 @@ class BackgroundUpdater(object):
else: else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
progress_json = await self.db.simple_select_one_onecol( progress_json = await self.db_pool.simple_select_one_onecol(
"background_updates", "background_updates",
keyvalues={"update_name": update_name}, keyvalues={"update_name": update_name},
retcol="progress_json", retcol="progress_json",
@ -402,7 +402,7 @@ class BackgroundUpdater(object):
logger.debug("[SQL] %s", sql) logger.debug("[SQL] %s", sql)
c.execute(sql) c.execute(sql)
if isinstance(self.db.engine, engines.PostgresEngine): if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner = create_index_psql runner = create_index_psql
elif psql_only: elif psql_only:
runner = None runner = None
@ -413,7 +413,7 @@ class BackgroundUpdater(object):
def updater(progress, batch_size): def updater(progress, batch_size):
if runner is not None: if runner is not None:
logger.info("Adding index %s to %s", index_name, table) logger.info("Adding index %s to %s", index_name, table)
yield self.db.runWithConnection(runner) yield self.db_pool.runWithConnection(runner)
yield self._end_background_update(update_name) yield self._end_background_update(update_name)
return 1 return 1
@ -433,7 +433,7 @@ class BackgroundUpdater(object):
% update_name % update_name
) )
self._current_background_update = None self._current_background_update = None
return self.db.simple_delete_one( return self.db_pool.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name} "background_updates", keyvalues={"update_name": update_name}
) )
@ -445,7 +445,7 @@ class BackgroundUpdater(object):
progress: The progress of the update. progress: The progress of the update.
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"background_update_progress", "background_update_progress",
self._background_update_progress_txn, self._background_update_progress_txn,
update_name, update_name,
@ -463,7 +463,7 @@ class BackgroundUpdater(object):
progress_json = json.dumps(progress) progress_json = json.dumps(progress)
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
"background_updates", "background_updates",
keyvalues={"update_name": update_name}, keyvalues={"update_name": update_name},

View file

@ -279,7 +279,7 @@ class PerformanceCounters(object):
return top_n_counters return top_n_counters
class Database(object): class DatabasePool(object):
"""Wraps a single physical database and connection pool. """Wraps a single physical database and connection pool.
A single database may be used by multiple data stores. A single database may be used by multiple data stores.

View file

@ -15,17 +15,17 @@
import logging import logging
from synapse.storage.data_stores.main.events import PersistEventsStore from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.data_stores.state import StateGroupDataStore from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.database import Database, make_conn from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DataStores(object): class Databases(object):
"""The various data stores. """The various databases.
These are low level interfaces to physical databases. These are low level interfaces to physical databases.
@ -51,12 +51,12 @@ class DataStores(object):
engine.check_database(db_conn) engine.check_database(db_conn)
prepare_database( prepare_database(
db_conn, engine, hs.config, data_stores=database_config.data_stores, db_conn, engine, hs.config, databases=database_config.databases,
) )
database = Database(hs, database_config, engine) database = DatabasePool(hs, database_config, engine)
if "main" in database_config.data_stores: if "main" in database_config.databases:
logger.info("Starting 'main' data store") logger.info("Starting 'main' data store")
# Sanity check we don't try and configure the main store on # Sanity check we don't try and configure the main store on
@ -73,7 +73,7 @@ class DataStores(object):
hs, database, self.main hs, database, self.main
) )
if "state" in database_config.data_stores: if "state" in database_config.databases:
logger.info("Starting 'state' data store") logger.info("Starting 'state' data store")
# Sanity check we don't try and configure the state store on # Sanity check we don't try and configure the state store on

View file

@ -21,7 +21,7 @@ import time
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import (
IdGenerator, IdGenerator,
@ -119,7 +119,7 @@ class DataStore(
CacheInvalidationWorkerStore, CacheInvalidationWorkerStore,
ServerMetricsStore, ServerMetricsStore,
): ):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = database.engine self.database_engine = database.engine
@ -174,7 +174,7 @@ class DataStore(
self._presence_on_startup = self._get_active_presence(db_conn) self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self.db.get_cache_dict( presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
db_conn, db_conn,
"presence_stream", "presence_stream",
entity_column="user_id", entity_column="user_id",
@ -188,7 +188,7 @@ class DataStore(
) )
max_device_inbox_id = self._device_inbox_id_gen.get_current_token() max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"device_inbox", "device_inbox",
entity_column="user_id", entity_column="user_id",
@ -203,7 +203,7 @@ class DataStore(
) )
# The federation outbox and the local device inbox uses the same # The federation outbox and the local device inbox uses the same
# stream_id generator. # stream_id generator.
device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"device_federation_outbox", "device_federation_outbox",
entity_column="destination", entity_column="destination",
@ -229,7 +229,7 @@ class DataStore(
) )
events_max = self._stream_id_gen.get_current_token() events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"current_state_delta_stream", "current_state_delta_stream",
entity_column="room_id", entity_column="room_id",
@ -243,7 +243,7 @@ class DataStore(
prefilled_cache=curr_state_delta_prefill, prefilled_cache=curr_state_delta_prefill,
) )
_group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"local_group_updates", "local_group_updates",
entity_column="user_id", entity_column="user_id",
@ -282,7 +282,7 @@ class DataStore(
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,)) txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
txn.close() txn.close()
for row in rows: for row in rows:
@ -295,7 +295,9 @@ class DataStore(
Counts the number of users who used this homeserver in the last 24 hours. Counts the number of users who used this homeserver in the last 24 hours.
""" """
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
return self.db.runInteraction("count_daily_users", self._count_users, yesterday) return self.db_pool.runInteraction(
"count_daily_users", self._count_users, yesterday
)
def count_monthly_users(self): def count_monthly_users(self):
""" """
@ -305,7 +307,7 @@ class DataStore(
amongst other things, includes a 3 day grace period before a user counts. amongst other things, includes a 3 day grace period before a user counts.
""" """
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
return self.db.runInteraction( return self.db_pool.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago "count_monthly_users", self._count_users, thirty_days_ago
) )
@ -405,7 +407,7 @@ class DataStore(
return results return results
return self.db.runInteraction("count_r30_users", _count_r30_users) return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self): def _get_start_of_day(self):
""" """
@ -470,7 +472,7 @@ class DataStore(
# frequently # frequently
self._last_user_visit_update = now self._last_user_visit_update = now
return self.db.runInteraction( return self.db_pool.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits "generate_user_daily_visits", _generate_user_daily_visits
) )
@ -481,7 +483,7 @@ class DataStore(
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
return self.db.simple_select_list( return self.db_pool.simple_select_list(
table="users", table="users",
keyvalues={}, keyvalues={},
retcols=[ retcols=[
@ -543,10 +545,12 @@ class DataStore(
where_clause where_clause
) )
txn.execute(sql, args) txn.execute(sql, args)
users = self.db.cursor_to_dict(txn) users = self.db_pool.cursor_to_dict(txn)
return users, count return users, count
return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn) return self.db_pool.runInteraction(
"get_users_paginate_txn", get_users_paginate_txn
)
def search_users(self, term): def search_users(self, term):
"""Function to search users list for one or more users with """Function to search users list for one or more users with
@ -558,7 +562,7 @@ class DataStore(
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
return self.db.simple_search_list( return self.db_pool.simple_search_list(
table="users", table="users",
term=term, term=term,
col="name", col="name",

View file

@ -23,7 +23,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -40,7 +40,7 @@ class AccountDataWorkerStore(SQLBaseStore):
# the abstract methods being implemented. # the abstract methods being implemented.
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
account_max = self.get_max_account_data_stream_id() account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache( self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max "AccountDataAndTagsChangeCache", account_max
@ -69,7 +69,7 @@ class AccountDataWorkerStore(SQLBaseStore):
""" """
def get_account_data_for_user_txn(txn): def get_account_data_for_user_txn(txn):
rows = self.db.simple_select_list_txn( rows = self.db_pool.simple_select_list_txn(
txn, txn,
"account_data", "account_data",
{"user_id": user_id}, {"user_id": user_id},
@ -80,7 +80,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows row["account_data_type"]: db_to_json(row["content"]) for row in rows
} }
rows = self.db.simple_select_list_txn( rows = self.db_pool.simple_select_list_txn(
txn, txn,
"room_account_data", "room_account_data",
{"user_id": user_id}, {"user_id": user_id},
@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room return global_account_data, by_room
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn "get_account_data_for_user", get_account_data_for_user_txn
) )
@ -104,7 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred: A dict Deferred: A dict
""" """
result = yield self.db.simple_select_one_onecol( result = yield self.db_pool.simple_select_one_onecol(
table="account_data", table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type}, keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content", retcol="content",
@ -129,7 +129,7 @@ class AccountDataWorkerStore(SQLBaseStore):
""" """
def get_account_data_for_room_txn(txn): def get_account_data_for_room_txn(txn):
rows = self.db.simple_select_list_txn( rows = self.db_pool.simple_select_list_txn(
txn, txn,
"room_account_data", "room_account_data",
{"user_id": user_id, "room_id": room_id}, {"user_id": user_id, "room_id": room_id},
@ -140,7 +140,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows row["account_data_type"]: db_to_json(row["content"]) for row in rows
} }
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn "get_account_data_for_room", get_account_data_for_room_txn
) )
@ -158,7 +158,7 @@ class AccountDataWorkerStore(SQLBaseStore):
""" """
def get_account_data_for_room_and_type_txn(txn): def get_account_data_for_room_and_type_txn(txn):
content_json = self.db.simple_select_one_onecol_txn( content_json = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="room_account_data", table="room_account_data",
keyvalues={ keyvalues={
@ -172,7 +172,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return db_to_json(content_json) if content_json else None return db_to_json(content_json) if content_json else None
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
) )
@ -202,7 +202,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall() return txn.fetchall()
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_updated_global_account_data", get_updated_global_account_data_txn "get_updated_global_account_data", get_updated_global_account_data_txn
) )
@ -232,7 +232,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall() return txn.fetchall()
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_updated_room_account_data", get_updated_room_account_data_txn "get_updated_room_account_data", get_updated_room_account_data_txn
) )
@ -277,7 +277,7 @@ class AccountDataWorkerStore(SQLBaseStore):
if not changed: if not changed:
return defer.succeed(({}, {})) return defer.succeed(({}, {}))
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
) )
@ -295,7 +295,7 @@ class AccountDataWorkerStore(SQLBaseStore):
class AccountDataStore(AccountDataWorkerStore): class AccountDataStore(AccountDataWorkerStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator( self._account_data_id_gen = StreamIdGenerator(
db_conn, db_conn,
"account_data_max_stream_id", "account_data_max_stream_id",
@ -333,7 +333,7 @@ class AccountDataStore(AccountDataWorkerStore):
# no need to lock here as room_account_data has a unique constraint # no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will # on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict. # retry if there is a conflict.
yield self.db.simple_upsert( yield self.db_pool.simple_upsert(
desc="add_room_account_data", desc="add_room_account_data",
table="room_account_data", table="room_account_data",
keyvalues={ keyvalues={
@ -379,7 +379,7 @@ class AccountDataStore(AccountDataWorkerStore):
# no need to lock here as account_data has a unique constraint on # no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if # (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict. # there is a conflict.
yield self.db.simple_upsert( yield self.db_pool.simple_upsert(
desc="add_user_account_data", desc="add_user_account_data",
table="account_data", table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type}, keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@ -427,4 +427,4 @@ class AccountDataStore(AccountDataWorkerStore):
) )
txn.execute(update_max_id_sql, (next_id, next_id)) txn.execute(update_max_id_sql, (next_id, next_id))
return self.db.runInteraction("update_account_data_max_stream_id", _update) return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)

View file

@ -23,8 +23,8 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.events_worker import EventsWorkerStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,7 +49,7 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore): class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.services_cache = load_appservices( self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files hs.hostname, hs.config.app_service_config_files
) )
@ -134,7 +134,7 @@ class ApplicationServiceTransactionWorkerStore(
A Deferred which resolves to a list of ApplicationServices, which A Deferred which resolves to a list of ApplicationServices, which
may be empty. may be empty.
""" """
results = yield self.db.simple_select_list( results = yield self.db_pool.simple_select_list(
"application_services_state", {"state": state}, ["as_id"] "application_services_state", {"state": state}, ["as_id"]
) )
# NB: This assumes this class is linked with ApplicationServiceStore # NB: This assumes this class is linked with ApplicationServiceStore
@ -156,7 +156,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns: Returns:
A Deferred which resolves to ApplicationServiceState. A Deferred which resolves to ApplicationServiceState.
""" """
result = yield self.db.simple_select_one( result = yield self.db_pool.simple_select_one(
"application_services_state", "application_services_state",
{"as_id": service.id}, {"as_id": service.id},
["state"], ["state"],
@ -176,7 +176,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns: Returns:
A Deferred which resolves when the state was set successfully. A Deferred which resolves when the state was set successfully.
""" """
return self.db.simple_upsert( return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state} "application_services_state", {"as_id": service.id}, {"state": state}
) )
@ -217,7 +217,9 @@ class ApplicationServiceTransactionWorkerStore(
) )
return AppServiceTransaction(service=service, id=new_txn_id, events=events) return AppServiceTransaction(service=service, id=new_txn_id, events=events)
return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) return self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn
)
def complete_appservice_txn(self, txn_id, service): def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction. """Completes an application service transaction.
@ -250,7 +252,7 @@ class ApplicationServiceTransactionWorkerStore(
) )
# Set current txn_id for AS to 'txn_id' # Set current txn_id for AS to 'txn_id'
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
"application_services_state", "application_services_state",
{"as_id": service.id}, {"as_id": service.id},
@ -258,13 +260,13 @@ class ApplicationServiceTransactionWorkerStore(
) )
# Delete txn # Delete txn
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
"application_services_txns", "application_services_txns",
{"txn_id": txn_id, "as_id": service.id}, {"txn_id": txn_id, "as_id": service.id},
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"complete_appservice_txn", _complete_appservice_txn "complete_appservice_txn", _complete_appservice_txn
) )
@ -288,7 +290,7 @@ class ApplicationServiceTransactionWorkerStore(
" ORDER BY txn_id ASC LIMIT 1", " ORDER BY txn_id ASC LIMIT 1",
(service.id,), (service.id,),
) )
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if not rows: if not rows:
return None return None
@ -296,7 +298,7 @@ class ApplicationServiceTransactionWorkerStore(
return entry return entry
entry = yield self.db.runInteraction( entry = yield self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
) )
@ -326,7 +328,7 @@ class ApplicationServiceTransactionWorkerStore(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn "set_appservice_last_pos", set_appservice_last_pos_txn
) )
@ -355,7 +357,7 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, [row[1] for row in rows] return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.db.runInteraction( upper_bound, event_ids = yield self.db_pool.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn "get_new_events_for_appservice", get_new_events_for_appservice_txn
) )

View file

@ -26,7 +26,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow, EventsStreamEventRow,
) )
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -39,7 +39,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore): class CacheInvalidationWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
@ -92,7 +92,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return updates, upto_token, limited return updates, upto_token, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn "get_all_updated_caches", get_all_updated_caches_txn
) )
@ -203,7 +203,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return return
cache_func.invalidate(keys) cache_func.invalidate(keys)
await self.db.runInteraction( await self.db_pool.runInteraction(
"invalidate_cache_and_stream", "invalidate_cache_and_stream",
self._send_invalidation_to_replication, self._send_invalidation_to_replication,
cache_func.__name__, cache_func.__name__,
@ -288,7 +288,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if keys is not None: if keys is not None:
keys = list(keys) keys = list(keys)
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="cache_invalidation_stream_by_instance", table="cache_invalidation_stream_by_instance",
values={ values={

View file

@ -21,10 +21,10 @@ from twisted.internet import defer
from synapse.events.utils import prune_event_dict from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.data_stores.main.events import encode_json from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events import encode_json
from synapse.storage.database import Database from synapse.storage.databases.main.events_worker import EventsWorkerStore
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
def __init__(self, database: Database, db_conn, hs: "HomeServer"): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
def _censor_redactions(): def _censor_redactions():
@ -56,7 +56,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return return
if not ( if not (
await self.db.updates.has_completed_background_update( await self.db_pool.updates.has_completed_background_update(
"redactions_have_censored_ts_idx" "redactions_have_censored_ts_idx"
) )
): ):
@ -85,7 +85,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
LIMIT ? LIMIT ?
""" """
rows = await self.db.execute( rows = await self.db_pool.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100 "_censor_redactions_fetch", None, sql, before_ts, 100
) )
@ -123,14 +123,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
if pruned_json: if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json) self._censor_event_txn(txn, event_id, pruned_json)
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
table="redactions", table="redactions",
keyvalues={"event_id": redaction_id}, keyvalues={"event_id": redaction_id},
updatevalues={"have_censored": True}, updatevalues={"have_censored": True},
) )
await self.db.runInteraction("_update_censor_txn", _update_censor_txn) await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
def _censor_event_txn(self, txn, event_id, pruned_json): def _censor_event_txn(self, txn, event_id, pruned_json):
"""Censor an event by replacing its JSON in the event_json table with the """Censor an event by replacing its JSON in the event_json table with the
@ -141,7 +141,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
event_id (str): The ID of the event to censor. event_id (str): The ID of the event to censor.
pruned_json (str): The pruned JSON pruned_json (str): The pruned JSON
""" """
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
table="event_json", table="event_json",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
@ -193,7 +193,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
txn, "_get_event_cache", (event.event_id,) txn, "_get_event_cache", (event.event_id,)
) )
yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) yield self.db_pool.runInteraction(
"delete_expired_event", delete_expired_event_txn
)
def _delete_event_expiry_txn(self, txn, event_id): def _delete_event_expiry_txn(self, txn, event_id):
"""Delete the expiry timestamp associated with an event ID without deleting the """Delete the expiry timestamp associated with an event ID without deleting the
@ -203,6 +205,6 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
txn (LoggingTransaction): The transaction to use to perform the deletion. txn (LoggingTransaction): The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of. event_id (str): The event ID to delete the associated expiry timestamp of.
""" """
return self.db.simple_delete_txn( return self.db_pool.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id} txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
) )

View file

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database, make_tuple_comparison_clause from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,40 +31,40 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"user_ips_device_index", "user_ips_device_index",
index_name="user_ips_device_id", index_name="user_ips_device_id",
table="user_ips", table="user_ips",
columns=["user_id", "device_id", "last_seen"], columns=["user_id", "device_id", "last_seen"],
) )
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"user_ips_last_seen_index", "user_ips_last_seen_index",
index_name="user_ips_last_seen", index_name="user_ips_last_seen",
table="user_ips", table="user_ips",
columns=["user_id", "last_seen"], columns=["user_id", "last_seen"],
) )
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"user_ips_last_seen_only_index", "user_ips_last_seen_only_index",
index_name="user_ips_last_seen_only", index_name="user_ips_last_seen_only",
table="user_ips", table="user_ips",
columns=["last_seen"], columns=["last_seen"],
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"user_ips_analyze", self._analyze_user_ip "user_ips_analyze", self._analyze_user_ip
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"user_ips_remove_dupes", self._remove_user_ip_dupes "user_ips_remove_dupes", self._remove_user_ip_dupes
) )
# Register a unique index # Register a unique index
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"user_ips_device_unique_index", "user_ips_device_unique_index",
index_name="user_ips_user_token_ip_unique_index", index_name="user_ips_user_token_ip_unique_index",
table="user_ips", table="user_ips",
@ -73,12 +73,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
) )
# Drop the old non-unique index # Drop the old non-unique index
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
) )
# Update the last seen info in devices. # Update the last seen info in devices.
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"devices_last_seen", self._devices_last_seen_update "devices_last_seen", self._devices_last_seen_update
) )
@ -89,8 +89,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close() txn.close()
yield self.db.runWithConnection(f) yield self.db_pool.runWithConnection(f)
yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") yield self.db_pool.updates._end_background_update(
"user_ips_drop_nonunique_index"
)
return 1 return 1
@defer.inlineCallbacks @defer.inlineCallbacks
@ -104,9 +106,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
def user_ips_analyze(txn): def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips") txn.execute("ANALYZE user_ips")
yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
yield self.db.updates._end_background_update("user_ips_analyze") yield self.db_pool.updates._end_background_update("user_ips_analyze")
return 1 return 1
@ -138,7 +140,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return None return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen` # Get a last seen that has roughly `batch_size` since `begin_last_seen`
end_last_seen = yield self.db.runInteraction( end_last_seen = yield self.db_pool.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen "user_ips_dups_get_last_seen", get_last_seen
) )
@ -269,14 +271,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
(user_id, access_token, ip, device_id, user_agent, last_seen), (user_id, access_token, ip, device_id, user_agent, last_seen),
) )
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
) )
yield self.db.runInteraction("user_ips_dups_remove", remove) yield self.db_pool.runInteraction("user_ips_dups_remove", remove)
if last: if last:
yield self.db.updates._end_background_update("user_ips_remove_dupes") yield self.db_pool.updates._end_background_update("user_ips_remove_dupes")
return batch_size return batch_size
@ -336,7 +338,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
txn.execute_batch(sql, rows) txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1] _, _, _, user_id, device_id = rows[-1]
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, txn,
"devices_last_seen", "devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id}, {"last_user_id": user_id, "last_device_id": device_id},
@ -344,18 +346,18 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return len(rows) return len(rows)
updated = yield self.db.runInteraction( updated = yield self.db_pool.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn "_devices_last_seen_update", _devices_last_seen_update_txn
) )
if not updated: if not updated:
yield self.db.updates._end_background_update("devices_last_seen") yield self.db_pool.updates._end_background_update("devices_last_seen")
return updated return updated
class ClientIpStore(ClientIpBackgroundUpdateStore): class ClientIpStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = Cache( self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 name="client_ip_last_seen", keylen=4, max_entries=50000
@ -403,18 +405,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def _update_client_ips_batch(self): def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating # If the DB pool has already terminated, don't try updating
if not self.db.is_running(): if not self.db_pool.is_running():
return return
to_update = self._batch_row_update to_update = self._batch_row_update
self._batch_row_update = {} self._batch_row_update = {}
return self.db.runInteraction( return self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
) )
def _update_client_ips_batch_txn(self, txn, to_update): def _update_client_ips_batch_txn(self, txn, to_update):
if "user_ips" in self.db._unsafe_to_upsert_tables or ( if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert not self.database_engine.can_native_upsert
): ):
self.database_engine.lock_table(txn, "user_ips") self.database_engine.lock_table(txn, "user_ips")
@ -423,7 +425,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try: try:
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="user_ips", table="user_ips",
keyvalues={ keyvalues={
@ -445,7 +447,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# this is always an update rather than an upsert: the row should # this is always an update rather than an upsert: the row should
# already exist, and if it doesn't, that may be because it has been # already exist, and if it doesn't, that may be because it has been
# deleted, and we don't want to re-create it. # deleted, and we don't want to re-create it.
self.db.simple_update_txn( self.db_pool.simple_update_txn(
txn, txn,
table="devices", table="devices",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
@ -477,7 +479,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if device_id is not None: if device_id is not None:
keyvalues["device_id"] = device_id keyvalues["device_id"] = device_id
res = yield self.db.simple_select_list( res = yield self.db_pool.simple_select_list(
table="devices", table="devices",
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@ -510,7 +512,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key] user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen) results[(access_token, ip)] = (user_agent, last_seen)
rows = yield self.db.simple_select_list( rows = yield self.db_pool.simple_select_list(
table="user_ips", table="user_ips",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"], retcols=["access_token", "ip", "user_agent", "last_seen"],
@ -540,7 +542,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Nothing to do # Nothing to do
return return
if not await self.db.updates.has_completed_background_update( if not await self.db_pool.updates.has_completed_background_update(
"devices_last_seen" "devices_last_seen"
): ):
# Only start pruning if we have finished populating the devices # Only start pruning if we have finished populating the devices
@ -573,4 +575,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def _prune_old_user_ips_txn(txn): def _prune_old_user_ips_txn(txn):
txn.execute(sql, (timestamp,)) txn.execute(sql, (timestamp,))
await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) await self.db_pool.runInteraction(
"_prune_old_user_ips", _prune_old_user_ips_txn
)

View file

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -70,7 +70,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id stream_pos = current_stream_id
return messages, stream_pos return messages, stream_pos
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn "get_new_messages_for_device", get_new_messages_for_device_txn
) )
@ -110,7 +110,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id)) txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount return txn.rowcount
count = yield self.db.runInteraction( count = yield self.db_pool.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn "delete_messages_for_device", delete_messages_for_device_txn
) )
@ -179,7 +179,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id stream_pos = current_stream_id
return messages, stream_pos return messages, stream_pos
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_new_device_msgs_for_remote", "get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn, get_new_messages_for_remote_destination_txn,
) )
@ -204,7 +204,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
) )
txn.execute(sql, (destination, up_to_stream_id)) txn.execute(sql, (destination, up_to_stream_id))
return self.db.runInteraction( return self.db_pool.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
) )
@ -269,7 +269,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return updates, upto_token, limited return updates, upto_token, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn "get_all_new_device_messages", get_all_new_device_messages_txn
) )
@ -277,17 +277,17 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"device_inbox_stream_index", "device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id", index_name="device_inbox_stream_id_user_id",
table="device_inbox", table="device_inbox",
columns=["stream_id", "user_id"], columns=["stream_id", "user_id"],
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
) )
@ -298,9 +298,9 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close() txn.close()
yield self.db.runWithConnection(reindex_txn) yield self.db_pool.runWithConnection(reindex_txn)
yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1 return 1
@ -308,7 +308,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceInboxStore, self).__init__(database, db_conn, hs) super(DeviceInboxStore, self).__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been # Map of (user_id, device_id) to the last stream_id that has been
@ -360,7 +360,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id: with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec() now_ms = self.clock.time_msec()
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
) )
for user_id in local_messages_by_user_then_device.keys(): for user_id in local_messages_by_user_then_device.keys():
@ -380,7 +380,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Check if we've already inserted a matching message_id for that # Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our # origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message. # acknowledgement from the first time we received the message.
already_inserted = self.db.simple_select_one_txn( already_inserted = self.db_pool.simple_select_one_txn(
txn, txn,
table="device_federation_inbox", table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id}, keyvalues={"origin": origin, "message_id": message_id},
@ -392,7 +392,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Add an entry for this message_id so that we know we've processed # Add an entry for this message_id so that we know we've processed
# it. # it.
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="device_federation_inbox", table="device_federation_inbox",
values={ values={
@ -410,7 +410,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id: with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec() now_ms = self.clock.time_msec()
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox", "add_messages_from_remote_to_device_inbox",
add_messages_txn, add_messages_txn,
now_ms, now_ms,

View file

@ -31,7 +31,7 @@ from synapse.logging.opentracing import (
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import ( from synapse.storage.database import (
Database, DatabasePool,
LoggingTransaction, LoggingTransaction,
make_tuple_comparison_clause, make_tuple_comparison_clause,
) )
@ -67,7 +67,7 @@ class DeviceWorkerStore(SQLBaseStore):
Raises: Raises:
StoreError: if the device is not found StoreError: if the device is not found
""" """
return self.db.simple_select_one( return self.db_pool.simple_select_one(
table="devices", table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"), retcols=("user_id", "device_id", "display_name"),
@ -86,7 +86,7 @@ class DeviceWorkerStore(SQLBaseStore):
containing "device_id", "user_id" and "display_name" for each containing "device_id", "user_id" and "display_name" for each
device. device.
""" """
devices = yield self.db.simple_select_list( devices = yield self.db_pool.simple_select_list(
table="devices", table="devices",
keyvalues={"user_id": user_id, "hidden": False}, keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"), retcols=("user_id", "device_id", "display_name"),
@ -118,7 +118,7 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed: if not has_changed:
return now_stream_id, [] return now_stream_id, []
updates = yield self.db.runInteraction( updates = yield self.db_pool.runInteraction(
"get_device_updates_by_remote", "get_device_updates_by_remote",
self._get_device_updates_by_remote_txn, self._get_device_updates_by_remote_txn,
destination, destination,
@ -255,7 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
""" """
devices = ( devices = (
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"_get_e2e_device_keys_txn", "_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn, self._get_e2e_device_keys_txn,
query_map.keys(), query_map.keys(),
@ -326,12 +326,12 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall() rows = txn.fetchall()
return rows[0][0] return rows[0][0]
return self.db.runInteraction("get_last_device_update_for_remote_user", f) return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
def mark_as_sent_devices_by_remote(self, destination, stream_id): def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination. """Mark that updates have successfully been sent to the destination.
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote", "mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn, self._mark_as_sent_devices_by_remote_txn,
destination, destination,
@ -350,7 +350,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn.execute(sql, (destination, stream_id)) txn.execute(sql, (destination, stream_id))
rows = txn.fetchall() rows = txn.fetchall()
self.db.simple_upsert_many_txn( self.db_pool.simple_upsert_many_txn(
txn=txn, txn=txn,
table="device_lists_outbound_last_success", table="device_lists_outbound_last_success",
key_names=("destination", "user_id"), key_names=("destination", "user_id"),
@ -376,7 +376,7 @@ class DeviceWorkerStore(SQLBaseStore):
""" """
with self._device_list_id_gen.get_next() as stream_id: with self._device_list_id_gen.get_next() as stream_id:
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"add_user_sig_change_to_streams", "add_user_sig_change_to_streams",
self._add_user_signature_change_txn, self._add_user_signature_change_txn,
from_user_id, from_user_id,
@ -391,7 +391,7 @@ class DeviceWorkerStore(SQLBaseStore):
from_user_id, from_user_id,
stream_id, stream_id,
) )
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
"user_signature_stream", "user_signature_stream",
values={ values={
@ -449,7 +449,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2, tree=True) @cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id): def _get_cached_user_device(self, user_id, device_id):
content = yield self.db.simple_select_one_onecol( content = yield self.db_pool.simple_select_one_onecol(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content", retcol="content",
@ -459,7 +459,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedInlineCallbacks() @cachedInlineCallbacks()
def get_cached_devices_for_user(self, user_id): def get_cached_devices_for_user(self, user_id):
devices = yield self.db.simple_select_list( devices = yield self.db_pool.simple_select_list(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("device_id", "content"), retcols=("device_id", "content"),
@ -475,7 +475,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns: Returns:
(stream_id, devices) (stream_id, devices)
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_devices_with_keys_by_user", "get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, self._get_devices_with_keys_by_user_txn,
user_id, user_id,
@ -555,7 +555,7 @@ class DeviceWorkerStore(SQLBaseStore):
return changes return changes
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
) )
@ -574,7 +574,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT DISTINCT user_ids FROM user_signature_stream SELECT DISTINCT user_ids FROM user_signature_stream
WHERE from_user_id = ? AND stream_id > ? WHERE from_user_id = ? AND stream_id > ?
""" """
rows = yield self.db.execute( rows = yield self.db_pool.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key "get_users_whose_signatures_changed", None, sql, user_id, from_key
) )
return {user for row in rows for user in db_to_json(row[0])} return {user for row in rows for user in db_to_json(row[0])}
@ -631,7 +631,7 @@ class DeviceWorkerStore(SQLBaseStore):
return updates, upto_token, limited return updates, upto_token, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_device_list_changes_for_remotes", "get_all_device_list_changes_for_remotes",
_get_all_device_list_changes_for_remotes, _get_all_device_list_changes_for_remotes,
) )
@ -641,7 +641,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""Get the last stream_id we got for a user. May be None if we haven't """Get the last stream_id we got for a user. May be None if we haven't
got any information for them. got any information for them.
""" """
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="stream_id", retcol="stream_id",
@ -655,7 +655,7 @@ class DeviceWorkerStore(SQLBaseStore):
inlineCallbacks=True, inlineCallbacks=True,
) )
def get_device_list_last_stream_id_for_remotes(self, user_ids): def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,
@ -680,7 +680,7 @@ class DeviceWorkerStore(SQLBaseStore):
The IDs of users whose device lists need resync. The IDs of users whose device lists need resync.
""" """
if user_ids: if user_ids:
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync", table="device_lists_remote_resync",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,
@ -688,7 +688,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable", desc="get_user_ids_requiring_device_list_resync_with_iterable",
) )
else: else:
rows = yield self.db.simple_select_list( rows = yield self.db_pool.simple_select_list(
table="device_lists_remote_resync", table="device_lists_remote_resync",
keyvalues=None, keyvalues=None,
retcols=("user_id",), retcols=("user_id",),
@ -701,7 +701,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""Records that the server has reason to believe the cache of the devices """Records that the server has reason to believe the cache of the devices
for the remote users is out of date. for the remote users is out of date.
""" """
return self.db.simple_upsert( return self.db_pool.simple_upsert(
table="device_lists_remote_resync", table="device_lists_remote_resync",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
values={}, values={},
@ -714,7 +714,7 @@ class DeviceWorkerStore(SQLBaseStore):
""" """
def _mark_remote_user_device_list_as_unsubscribed_txn(txn): def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
@ -723,17 +723,17 @@ class DeviceWorkerStore(SQLBaseStore):
txn, self.get_device_list_last_stream_id_for_remote, (user_id,) txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed", "mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn, _mark_remote_user_device_list_as_unsubscribed_txn,
) )
class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx", "device_lists_stream_idx",
index_name="device_lists_stream_user_id", index_name="device_lists_stream_user_id",
table="device_lists_stream", table="device_lists_stream",
@ -741,7 +741,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
) )
# create a unique index on device_lists_remote_cache # create a unique index on device_lists_remote_cache
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"device_lists_remote_cache_unique_idx", "device_lists_remote_cache_unique_idx",
index_name="device_lists_remote_cache_unique_id", index_name="device_lists_remote_cache_unique_id",
table="device_lists_remote_cache", table="device_lists_remote_cache",
@ -750,7 +750,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
) )
# And one on device_lists_remote_extremeties # And one on device_lists_remote_extremeties
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"device_lists_remote_extremeties_unique_idx", "device_lists_remote_extremeties_unique_idx",
index_name="device_lists_remote_extremeties_unique_idx", index_name="device_lists_remote_extremeties_unique_idx",
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
@ -759,22 +759,22 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
) )
# once they complete, we can remove the old non-unique indexes. # once they complete, we can remove the old non-unique indexes.
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
self._drop_device_list_streams_non_unique_indexes, self._drop_device_list_streams_non_unique_indexes,
) )
# clear out duplicate device list outbound pokes # clear out duplicate device list outbound pokes
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
) )
# a pair of background updates that were added during the 1.14 release cycle, # a pair of background updates that were added during the 1.14 release cycle,
# but replaced with 58/06dlols_unique_idx.py # but replaced with 58/06dlols_unique_idx.py
self.db.updates.register_noop_background_update( self.db_pool.updates.register_noop_background_update(
"device_lists_outbound_last_success_unique_idx", "device_lists_outbound_last_success_unique_idx",
) )
self.db.updates.register_noop_background_update( self.db_pool.updates.register_noop_background_update(
"drop_device_lists_outbound_last_success_non_unique_idx", "drop_device_lists_outbound_last_success_non_unique_idx",
) )
@ -786,8 +786,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close() txn.close()
yield self.db.runWithConnection(f) yield self.db_pool.runWithConnection(f)
yield self.db.updates._end_background_update( yield self.db_pool.updates._end_background_update(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
) )
return 1 return 1
@ -807,7 +807,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
def _txn(txn): def _txn(txn):
clause, args = make_tuple_comparison_clause( clause, args = make_tuple_comparison_clause(
self.db.engine, [(x, last_row[x]) for x in KEY_COLS] self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
) )
sql = """ sql = """
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
@ -823,30 +823,32 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
",".join(KEY_COLS), # ORDER BY ",".join(KEY_COLS), # ORDER BY
) )
txn.execute(sql, args + [batch_size]) txn.execute(sql, args + [batch_size])
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
row = None row = None
for row in rows: for row in rows:
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
) )
row["sent"] = False row["sent"] = False
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, "device_lists_outbound_pokes", row, txn, "device_lists_outbound_pokes", row,
) )
if row: if row:
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
) )
return len(rows) return len(rows)
rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn) rows = await self.db_pool.runInteraction(
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn
)
if not rows: if not rows:
await self.db.updates._end_background_update( await self.db_pool.updates._end_background_update(
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
) )
@ -854,7 +856,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceStore, self).__init__(database, db_conn, hs) super(DeviceStore, self).__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies # Map of (user_id, device_id) -> bool. If there is an entry that implies
@ -885,7 +887,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return False return False
try: try:
inserted = yield self.db.simple_insert( inserted = yield self.db_pool.simple_insert(
"devices", "devices",
values={ values={
"user_id": user_id, "user_id": user_id,
@ -899,7 +901,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not inserted: if not inserted:
# if the device already exists, check if it's a real device, or # if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else # if the device ID is reserved by something else
hidden = yield self.db.simple_select_one_onecol( hidden = yield self.db_pool.simple_select_one_onecol(
"devices", "devices",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
retcol="hidden", retcol="hidden",
@ -934,7 +936,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns: Returns:
defer.Deferred defer.Deferred
""" """
yield self.db.simple_delete_one( yield self.db_pool.simple_delete_one(
table="devices", table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device", desc="delete_device",
@ -952,7 +954,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns: Returns:
defer.Deferred defer.Deferred
""" """
yield self.db.simple_delete_many( yield self.db_pool.simple_delete_many(
table="devices", table="devices",
column="device_id", column="device_id",
iterable=device_ids, iterable=device_ids,
@ -981,7 +983,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
updates["display_name"] = new_display_name updates["display_name"] = new_display_name
if not updates: if not updates:
return defer.succeed(None) return defer.succeed(None)
return self.db.simple_update_one( return self.db_pool.simple_update_one(
table="devices", table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates, updatevalues=updates,
@ -1005,7 +1007,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns: Returns:
Deferred[None] Deferred[None]
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"update_remote_device_list_cache_entry", "update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn, self._update_remote_device_list_cache_entry_txn,
user_id, user_id,
@ -1018,7 +1020,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self, txn, user_id, device_id, content, stream_id self, txn, user_id, device_id, content, stream_id
): ):
if content.get("deleted"): if content.get("deleted"):
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
@ -1026,7 +1028,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else: else:
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
@ -1042,7 +1044,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
) )
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
@ -1066,7 +1068,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns: Returns:
Deferred[None] Deferred[None]
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"update_remote_device_list_cache", "update_remote_device_list_cache",
self._update_remote_device_list_cache_txn, self._update_remote_device_list_cache_txn,
user_id, user_id,
@ -1075,11 +1077,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
) )
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
) )
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="device_lists_remote_cache", table="device_lists_remote_cache",
values=[ values=[
@ -1098,7 +1100,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
) )
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
@ -1111,7 +1113,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# If we're replacing the remote user's device list cache presumably # If we're replacing the remote user's device list cache presumably
# we've done a full resync, so we remove the entry that says we need # we've done a full resync, so we remove the entry that says we need
# to resync # to resync
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
) )
@ -1124,7 +1126,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return return
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"add_device_change_to_stream", "add_device_change_to_stream",
self._add_device_change_to_stream_txn, self._add_device_change_to_stream_txn,
user_id, user_id,
@ -1139,7 +1141,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
with self._device_list_id_gen.get_next_mult( with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids) len(hosts) * len(device_ids)
) as stream_ids: ) as stream_ids:
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"add_device_outbound_poke_to_stream", "add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn, self._add_device_outbound_poke_to_stream_txn,
user_id, user_id,
@ -1174,7 +1176,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[(user_id, device_id, min_stream_id) for device_id in device_ids], [(user_id, device_id, min_stream_id) for device_id in device_ids],
) )
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="device_lists_stream", table="device_lists_stream",
values=[ values=[
@ -1196,7 +1198,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
now = self._clock.time_msec() now = self._clock.time_msec()
next_stream_id = iter(stream_ids) next_stream_id = iter(stream_ids)
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="device_lists_outbound_pokes", table="device_lists_outbound_pokes",
values=[ values=[
@ -1303,7 +1305,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return run_as_background_process( return run_as_background_process(
"prune_old_outbound_device_pokes", "prune_old_outbound_device_pokes",
self.db.runInteraction, self.db_pool.runInteraction,
"_prune_old_outbound_device_pokes", "_prune_old_outbound_device_pokes",
_prune_txn, _prune_txn,
) )

View file

@ -37,7 +37,7 @@ class DirectoryWorkerStore(SQLBaseStore):
Deferred: results in namedtuple with keys "room_id" and Deferred: results in namedtuple with keys "room_id" and
"servers" or None if no association can be found "servers" or None if no association can be found
""" """
room_id = yield self.db.simple_select_one_onecol( room_id = yield self.db_pool.simple_select_one_onecol(
"room_aliases", "room_aliases",
{"room_alias": room_alias.to_string()}, {"room_alias": room_alias.to_string()},
"room_id", "room_id",
@ -48,7 +48,7 @@ class DirectoryWorkerStore(SQLBaseStore):
if not room_id: if not room_id:
return None return None
servers = yield self.db.simple_select_onecol( servers = yield self.db_pool.simple_select_onecol(
"room_alias_servers", "room_alias_servers",
{"room_alias": room_alias.to_string()}, {"room_alias": room_alias.to_string()},
"server", "server",
@ -61,7 +61,7 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers) return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias): def get_room_alias_creator(self, room_alias):
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="room_aliases", table="room_aliases",
keyvalues={"room_alias": room_alias}, keyvalues={"room_alias": room_alias},
retcol="creator", retcol="creator",
@ -70,7 +70,7 @@ class DirectoryWorkerStore(SQLBaseStore):
@cached(max_entries=5000) @cached(max_entries=5000)
def get_aliases_for_room(self, room_id): def get_aliases_for_room(self, room_id):
return self.db.simple_select_onecol( return self.db_pool.simple_select_onecol(
"room_aliases", "room_aliases",
{"room_id": room_id}, {"room_id": room_id},
"room_alias", "room_alias",
@ -94,7 +94,7 @@ class DirectoryStore(DirectoryWorkerStore):
""" """
def alias_txn(txn): def alias_txn(txn):
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
"room_aliases", "room_aliases",
{ {
@ -104,7 +104,7 @@ class DirectoryStore(DirectoryWorkerStore):
}, },
) )
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="room_alias_servers", table="room_alias_servers",
values=[ values=[
@ -118,7 +118,7 @@ class DirectoryStore(DirectoryWorkerStore):
) )
try: try:
ret = yield self.db.runInteraction( ret = yield self.db_pool.runInteraction(
"create_room_alias_association", alias_txn "create_room_alias_association", alias_txn
) )
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
@ -129,7 +129,7 @@ class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
room_id = yield self.db.runInteraction( room_id = yield self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias "delete_room_alias", self._delete_room_alias_txn, room_alias
) )
@ -190,6 +190,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,) txn, self.get_aliases_for_room, (new_room_id,)
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn "_update_aliases_for_room_txn", _update_aliases_for_room_txn
) )

View file

@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError StoreError
""" """
yield self.db.simple_update_one( yield self.db_pool.simple_update_one(
table="e2e_room_keys", table="e2e_room_keys",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
@ -89,7 +89,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
} }
) )
yield self.db.simple_insert_many( yield self.db_pool.simple_insert_many(
table="e2e_room_keys", values=values, desc="add_e2e_room_keys" table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
) )
@ -125,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id: if session_id:
keyvalues["session_id"] = session_id keyvalues["session_id"] = session_id
rows = yield self.db.simple_select_list( rows = yield self.db_pool.simple_select_list(
table="e2e_room_keys", table="e2e_room_keys",
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=( retcols=(
@ -171,7 +171,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_e2e_room_keys_multi", "get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn, self._get_e2e_room_keys_multi_txn,
user_id, user_id,
@ -235,7 +235,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version (str): the version ID of the backup we're querying about version (str): the version ID of the backup we're querying about
""" """
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="e2e_room_keys", table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version}, keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)", retcol="COUNT(*)",
@ -268,7 +268,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id: if session_id:
keyvalues["session_id"] = session_id keyvalues["session_id"] = session_id
yield self.db.simple_delete( yield self.db_pool.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
) )
@ -313,7 +313,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
# it isn't there. # it isn't there.
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
result = self.db.simple_select_one_txn( result = self.db_pool.simple_select_one_txn(
txn, txn,
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
@ -325,7 +325,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result["etag"] = 0 result["etag"] = 0
return result return result
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
) )
@ -353,7 +353,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
new_version = str(int(current_version) + 1) new_version = str(int(current_version) + 1)
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
values={ values={
@ -366,7 +366,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return new_version return new_version
return self.db.runInteraction( return self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
) )
@ -392,7 +392,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues["etag"] = version_etag updatevalues["etag"] = version_etag
if updatevalues: if updatevalues:
return self.db.simple_update( return self.db_pool.simple_update(
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version}, keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues, updatevalues=updatevalues,
@ -421,19 +421,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
else: else:
this_version = version this_version = version
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="e2e_room_keys", table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": this_version}, keyvalues={"user_id": user_id, "version": this_version},
) )
return self.db.simple_update_one_txn( return self.db_pool.simple_update_one_txn(
txn, txn,
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version}, keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1}, updatevalues={"deleted": 1},
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
) )

View file

@ -51,7 +51,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list: if not query_list:
return {} return {}
results = yield self.db.runInteraction( results = yield self.db_pool.runInteraction(
"get_e2e_device_keys", "get_e2e_device_keys",
self._get_e2e_device_keys_txn, self._get_e2e_device_keys_txn,
query_list, query_list,
@ -128,7 +128,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
) )
txn.execute(sql, query_params) txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
result = {} result = {}
for row in rows: for row in rows:
@ -146,7 +146,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
) )
txn.execute(signature_sql, signature_query_params) txn.execute(signature_sql, signature_query_params)
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
# add each cross-signing signature to the correct device in the result dict. # add each cross-signing signature to the correct device in the result dict.
for row in rows: for row in rows:
@ -189,7 +189,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
key_id) to json string for key key_id) to json string for key
""" """
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="e2e_one_time_keys_json", table="e2e_one_time_keys_json",
column="key_id", column="key_id",
iterable=key_ids, iterable=key_ids,
@ -222,7 +222,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# a unique constraint. If there is a race of two calls to # a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only # `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set. # insert one set.
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="e2e_one_time_keys_json", table="e2e_one_time_keys_json",
values=[ values=[
@ -241,7 +241,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id) txn, self.count_e2e_one_time_keys, (user_id, device_id)
) )
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
) )
@ -264,7 +264,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count result[algorithm] = key_count
return result return result
return self.db.runInteraction( return self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys "count_e2e_one_time_keys", _count_e2e_one_time_keys
) )
@ -318,7 +318,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
to None. to None.
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk", "get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn, self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids, user_ids,
@ -361,7 +361,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
) )
txn.execute(sql, params) txn.execute(sql, params)
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
for row in rows: for row in rows:
user_id = row["user_id"] user_id = row["user_id"]
@ -420,7 +420,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
query_params.extend(item) query_params.extend(item)
txn.execute(sql, query_params) txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
# and add the signatures to the appropriate keys # and add the signatures to the appropriate keys
for row in rows: for row in rows:
@ -470,7 +470,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id: if from_user_id:
result = yield self.db.runInteraction( result = yield self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures", "get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn, self._get_e2e_cross_signing_signatures_txn,
result, result,
@ -531,7 +531,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return updates, upto_token, limited return updates, upto_token, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_user_signature_changes_for_remotes", "get_all_user_signature_changes_for_remotes",
_get_all_user_signature_changes_for_remotes_txn, _get_all_user_signature_changes_for_remotes_txn,
) )
@ -549,7 +549,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("time_now", time_now) set_tag("time_now", time_now)
set_tag("device_keys", device_keys) set_tag("device_keys", device_keys)
old_key_json = self.db.simple_select_one_onecol_txn( old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="e2e_device_keys_json", table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
@ -565,7 +565,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"Message": "Device key already stored."}) log_kv({"Message": "Device key already stored."})
return False return False
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="e2e_device_keys_json", table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
@ -574,7 +574,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."}) log_kv({"message": "Device keys stored."})
return True return True
return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) return self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
def claim_e2e_one_time_keys(self, query_list): def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database""" """Take a list of one time keys out of the database"""
@ -613,7 +615,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
) )
return result return result
return self.db.runInteraction( return self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
) )
@ -626,12 +628,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"user_id": user_id, "user_id": user_id,
} }
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="e2e_device_keys_json", table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="e2e_one_time_keys_json", table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
@ -640,7 +642,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id) txn, self.count_e2e_one_time_keys, (user_id, device_id)
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
) )
@ -679,7 +681,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# We only need to do this for local users, since remote servers should be # We only need to do this for local users, since remote servers should be
# responsible for checking this for their own users. # responsible for checking this for their own users.
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
"devices", "devices",
values={ values={
@ -692,7 +694,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# and finally, store the key itself # and finally, store the key itself
with self._cross_signing_id_gen.get_next() as stream_id: with self._cross_signing_id_gen.get_next() as stream_id:
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
"e2e_cross_signing_keys", "e2e_cross_signing_keys",
values={ values={
@ -715,7 +717,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set key_type (str): the type of cross-signing key to set
key (dict): the key data key (dict): the key data
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"add_e2e_cross_signing_key", "add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn, self._set_e2e_cross_signing_key_txn,
user_id, user_id,
@ -730,7 +732,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
user_id (str): the user who made the signatures user_id (str): the user who made the signatures
signatures (iterable[SignatureListItem]): signatures to add signatures (iterable[SignatureListItem]): signatures to add
""" """
return self.db.simple_insert_many( return self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures", "e2e_cross_signing_signatures",
[ [
{ {

View file

@ -22,9 +22,9 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.database import Database from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -65,7 +65,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns: Returns:
list of event_ids list of event_ids
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_auth_chain_ids", "get_auth_chain_ids",
self._get_auth_chain_ids_txn, self._get_auth_chain_ids_txn,
event_ids, event_ids,
@ -114,7 +114,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Deferred[Set[str]] Deferred[Set[str]]
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_auth_chain_difference", "get_auth_chain_difference",
self._get_auth_chain_difference_txn, self._get_auth_chain_difference_txn,
state_sets, state_sets,
@ -260,12 +260,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return {eid for eid, n in event_to_missing_sets.items() if n} return {eid for eid, n in event_to_missing_sets.items() if n}
def get_oldest_events_in_room(self, room_id): def get_oldest_events_in_room(self, room_id):
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
) )
def get_oldest_events_with_depth_in_room(self, room_id): def get_oldest_events_with_depth_in_room(self, room_id):
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room", "get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn, self.get_oldest_events_with_depth_in_room_txn,
room_id, room_id,
@ -296,7 +296,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns Returns
Deferred[int] Deferred[int]
""" """
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="events", table="events",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
@ -310,7 +310,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return max(row["depth"] for row in rows) return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id): def _get_oldest_events_in_room_txn(self, txn, room_id):
return self.db.simple_select_onecol_txn( return self.db_pool.simple_select_onecol_txn(
txn, txn,
table="event_backward_extremities", table="event_backward_extremities",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
@ -332,7 +332,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
) )
@ -387,13 +387,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, query_args) txn.execute(sql, query_args)
return [room_id for room_id, in txn] return [room_id for room_id, in txn]
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
) )
@cached(max_entries=5000, iterable=True) @cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id): def get_latest_event_ids_in_room(self, room_id):
return self.db.simple_select_onecol( return self.db_pool.simple_select_onecol(
table="event_forward_extremities", table="event_forward_extremities",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcol="event_id", retcol="event_id",
@ -403,12 +403,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
def get_min_depth(self, room_id): def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it. """ For hte given room, get the minimum depth we have seen for it.
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id "get_min_depth", self._get_min_depth_interaction, room_id
) )
def _get_min_depth_interaction(self, txn, room_id): def _get_min_depth_interaction(self, txn, room_id):
min_depth = self.db.simple_select_one_onecol_txn( min_depth = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="room_depth", table="room_depth",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
@ -474,7 +474,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id)) txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn] return [event_id for event_id, in txn]
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
) )
@ -489,7 +489,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit (int) limit (int)
""" """
return ( return (
self.db.runInteraction( self.db_pool.runInteraction(
"get_backfill_events", "get_backfill_events",
self._get_backfill_events, self._get_backfill_events,
room_id, room_id,
@ -520,7 +520,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
queue = PriorityQueue() queue = PriorityQueue()
for event_id in event_list: for event_id in event_list:
depth = self.db.simple_select_one_onecol_txn( depth = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="events", table="events",
keyvalues={"event_id": event_id, "room_id": room_id}, keyvalues={"event_id": event_id, "room_id": room_id},
@ -552,7 +552,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
@defer.inlineCallbacks @defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events, limit): def get_missing_events(self, room_id, earliest_events, latest_events, limit):
ids = yield self.db.runInteraction( ids = yield self.db_pool.runInteraction(
"get_missing_events", "get_missing_events",
self._get_missing_events, self._get_missing_events,
room_id, room_id,
@ -605,7 +605,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns: Returns:
Deferred[list[str]] Deferred[list[str]]
""" """
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="event_edges", table="event_edges",
column="prev_event_id", column="prev_event_id",
iterable=event_ids, iterable=event_ids,
@ -628,10 +628,10 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only" EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(EventFederationStore, self).__init__(database, db_conn, hs) super(EventFederationStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
) )
@ -658,13 +658,13 @@ class EventFederationStore(EventFederationWorkerStore):
return run_as_background_process( return run_as_background_process(
"delete_old_forward_extrem_cache", "delete_old_forward_extrem_cache",
self.db.runInteraction, self.db_pool.runInteraction,
"_delete_old_forward_extrem_cache", "_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn, _delete_old_forward_extrem_cache_txn,
) )
def clean_room_for_join(self, room_id): def clean_room_for_join(self, room_id):
return self.db.runInteraction( return self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id "clean_room_for_join", self._clean_room_for_join_txn, room_id
) )
@ -708,17 +708,19 @@ class EventFederationStore(EventFederationWorkerStore):
"max_stream_id_exclusive": min_stream_id, "max_stream_id_exclusive": min_stream_id,
} }
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_AUTH_STATE_ONLY, new_progress txn, self.EVENT_AUTH_STATE_ONLY, new_progress
) )
return min_stream_id >= target_min_stream_id return min_stream_id >= target_min_stream_id
result = yield self.db.runInteraction( result = yield self.db_pool.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth self.EVENT_AUTH_STATE_ONLY, delete_event_auth
) )
if not result: if not result:
yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) yield self.db_pool.updates._end_background_update(
self.EVENT_AUTH_STATE_ONLY
)
return batch_size return batch_size

View file

@ -21,7 +21,7 @@ from canonicaljson import json
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,7 +66,7 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn # These get correctly set by _find_stream_orderings_for_times_txn
@ -91,7 +91,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def get_unread_event_push_actions_by_room_for_user( def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id self, room_id, user_id, last_read_event_id
): ):
ret = yield self.db.runInteraction( ret = yield self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room", "get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn, self._get_unread_counts_by_receipt_txn,
room_id, room_id,
@ -176,7 +176,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering)) txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn] return [r[0] for r in txn]
ret = await self.db.runInteraction("get_push_action_users_in_range", f) ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f)
return ret return ret
async def get_unread_push_actions_for_user_in_range_for_http( async def get_unread_push_actions_for_user_in_range_for_http(
@ -230,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
after_read_receipt = await self.db.runInteraction( after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
) )
@ -258,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
no_read_receipt = await self.db.runInteraction( no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
) )
@ -332,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
after_read_receipt = await self.db.runInteraction( after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
) )
@ -360,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
no_read_receipt = await self.db.runInteraction( no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
) )
@ -410,7 +410,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, min_stream_ordering)) txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone()) return bool(txn.fetchone())
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_if_maybe_push_in_range_for_user", "get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn, _get_if_maybe_push_in_range_for_user_txn,
) )
@ -461,7 +461,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
), ),
) )
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn "add_push_actions_to_staging", _add_push_actions_to_staging_txn
) )
@ -471,7 +471,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
""" """
try: try:
res = await self.db.simple_delete( res = await self.db_pool.simple_delete(
table="event_push_actions_staging", table="event_push_actions_staging",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging", desc="remove_push_actions_from_staging",
@ -488,7 +488,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _find_stream_orderings_for_times(self): def _find_stream_orderings_for_times(self):
return run_as_background_process( return run_as_background_process(
"event_push_action_stream_orderings", "event_push_action_stream_orderings",
self.db.runInteraction, self.db_pool.runInteraction,
"_find_stream_orderings_for_times", "_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn, self._find_stream_orderings_for_times_txn,
) )
@ -524,7 +524,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
Deferred[int]: stream ordering of the first event received on/after Deferred[int]: stream ordering of the first event received on/after
the timestamp the timestamp
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"_find_first_stream_ordering_after_ts_txn", "_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn, self._find_first_stream_ordering_after_ts_txn,
ts, ts,
@ -619,24 +619,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (stream_ordering,)) txn.execute(sql, (stream_ordering,))
return txn.fetchone() return txn.fetchone()
result = await self.db.runInteraction("get_time_of_last_push_action_before", f) result = await self.db_pool.runInteraction(
"get_time_of_last_push_action_before", f
)
return result[0] if result else None return result[0] if result else None
class EventPushActionsStore(EventPushActionsWorkerStore): class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index" EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(EventPushActionsStore, self).__init__(database, db_conn, hs) super(EventPushActionsStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX, self.EPA_HIGHLIGHT_INDEX,
index_name="event_push_actions_u_highlight", index_name="event_push_actions_u_highlight",
table="event_push_actions", table="event_push_actions",
columns=["user_id", "stream_ordering"], columns=["user_id", "stream_ordering"],
) )
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"event_push_actions_highlights_index", "event_push_actions_highlights_index",
index_name="event_push_actions_highlights_index", index_name="event_push_actions_highlights_index",
table="event_push_actions", table="event_push_actions",
@ -678,9 +680,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" LIMIT ?" % (before_clause,) " LIMIT ?" % (before_clause,)
) )
txn.execute(sql, args) txn.execute(sql, args)
return self.db.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
push_actions = await self.db.runInteraction("get_push_actions_for_user", f) push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
for pa in push_actions: for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions return push_actions
@ -690,7 +692,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone() return txn.fetchone()
result = await self.db.runInteraction( result = await self.db_pool.runInteraction(
"get_latest_push_action_stream_ordering", f "get_latest_push_action_stream_ordering", f
) )
return result[0] or 0 return result[0] or 0
@ -753,7 +755,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True: while True:
logger.info("Rotating notifications") logger.info("Rotating notifications")
caught_up = await self.db.runInteraction( caught_up = await self.db_pool.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn "_rotate_notifs", self._rotate_notifs_txn
) )
if caught_up: if caught_up:
@ -767,7 +769,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
the archiving process has caught up or not. the archiving process has caught up or not.
""" """
old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="event_push_summary_stream_ordering", table="event_push_summary_stream_ordering",
keyvalues={}, keyvalues={},
@ -803,7 +805,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return caught_up return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="event_push_summary_stream_ordering", table="event_push_summary_stream_ordering",
keyvalues={}, keyvalues={},
@ -835,7 +837,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# If the `old.user_id` above is NULL then we know there isn't already an # If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the # entry in the table, so we simply insert it. Otherwise we update the
# existing table. # existing table.
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="event_push_summary", table="event_push_summary",
values=[ values=[

View file

@ -32,8 +32,8 @@ from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.data_stores.main.search import SearchEntry from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import Database, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import StateMap, get_domain_from_id from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
@ -41,7 +41,7 @@ from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.data_stores.main import DataStore from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -132,9 +132,11 @@ class PersistEventsStore:
Note: This is not part of the `DataStore` mixin. Note: This is not part of the `DataStore` mixin.
""" """
def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"): def __init__(
self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore"
):
self.hs = hs self.hs = hs
self.db = db self.db_pool = db
self.store = main_data_store self.store = main_data_store
self.database_engine = db.engine self.database_engine = db.engine
self._clock = hs.get_clock() self._clock = hs.get_clock()
@ -207,7 +209,7 @@ class PersistEventsStore:
for (event, context), stream in zip(events_and_contexts, stream_orderings): for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"persist_events", "persist_events",
self._persist_events_txn, self._persist_events_txn,
events_and_contexts=events_and_contexts, events_and_contexts=events_and_contexts,
@ -283,7 +285,7 @@ class PersistEventsStore:
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100): for chunk in batch_iter(event_ids, 100):
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
) )
@ -347,7 +349,7 @@ class PersistEventsStore:
existing_prevs.add(prev_event_id) existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100): for chunk in batch_iter(event_ids, 100):
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
) )
@ -421,7 +423,7 @@ class PersistEventsStore:
# event's auth chain, but its easier for now just to store them (and # event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event # it doesn't take much storage compared to storing the entire event
# anyway). # anyway).
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="event_auth", table="event_auth",
values=[ values=[
@ -484,7 +486,7 @@ class PersistEventsStore:
""" """
txn.execute(sql, (stream_id, room_id)) txn.execute(sql, (stream_id, room_id))
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table="current_state_events", keyvalues={"room_id": room_id}, txn, table="current_state_events", keyvalues={"room_id": room_id},
) )
else: else:
@ -632,7 +634,7 @@ class PersistEventsStore:
creator = content.get("creator") creator = content.get("creator")
room_version_id = content.get("room_version", RoomVersions.V1.identifier) room_version_id = content.get("room_version", RoomVersions.V1.identifier)
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="rooms", table="rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
@ -644,14 +646,14 @@ class PersistEventsStore:
self, txn, new_forward_extremities, max_stream_order self, txn, new_forward_extremities, max_stream_order
): ):
for room_id, new_extrem in new_forward_extremities.items(): for room_id, new_extrem in new_forward_extremities.items():
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id} txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
) )
txn.call_after( txn.call_after(
self.store.get_latest_event_ids_in_room.invalidate, (room_id,) self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
) )
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="event_forward_extremities", table="event_forward_extremities",
values=[ values=[
@ -664,7 +666,7 @@ class PersistEventsStore:
# new stream_ordering to new forward extremeties in the room. # new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties # This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering # for a room before a given stream_ordering
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="stream_ordering_to_exterm", table="stream_ordering_to_exterm",
values=[ values=[
@ -788,7 +790,7 @@ class PersistEventsStore:
# change in outlier status to our workers. # change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group state_group_id = context.state_group
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="ex_outlier_stream", table="ex_outlier_stream",
values={ values={
@ -826,7 +828,7 @@ class PersistEventsStore:
d.pop("redacted_because", None) d.pop("redacted_because", None)
return d return d
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="event_json", table="event_json",
values=[ values=[
@ -843,7 +845,7 @@ class PersistEventsStore:
], ],
) )
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="events", table="events",
values=[ values=[
@ -873,7 +875,7 @@ class PersistEventsStore:
# If we're persisting an unredacted event we go and ensure # If we're persisting an unredacted event we go and ensure
# that we mark any redactions that reference this event as # that we mark any redactions that reference this event as
# requiring censoring. # requiring censoring.
self.db.simple_update_txn( self.db_pool.simple_update_txn(
txn, txn,
table="redactions", table="redactions",
keyvalues={"redacts": event.event_id}, keyvalues={"redacts": event.event_id},
@ -1015,7 +1017,9 @@ class PersistEventsStore:
state_values.append(vals) state_values.append(vals)
self.db.simple_insert_many_txn(txn, table="state_events", values=state_values) self.db_pool.simple_insert_many_txn(
txn, table="state_events", values=state_values
)
# Prefill the event cache # Prefill the event cache
self._add_to_cache(txn, events_and_contexts) self._add_to_cache(txn, events_and_contexts)
@ -1046,7 +1050,7 @@ class PersistEventsStore:
) )
txn.execute(sql + clause, args) txn.execute(sql + clause, args)
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
for row in rows: for row in rows:
event = ev_map[row["event_id"]] event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]: if not row["rejects"] and not row["redacts"]:
@ -1066,7 +1070,7 @@ class PersistEventsStore:
# invalidate the cache for the redacted event # invalidate the cache for the redacted event
txn.call_after(self.store._invalidate_get_event_cache, event.redacts) txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="redactions", table="redactions",
values={ values={
@ -1089,7 +1093,7 @@ class PersistEventsStore:
room_id (str): The ID of the room the event was sent to. room_id (str): The ID of the room the event was sent to.
topological_ordering (int): The position of the event in the room's topology. topological_ordering (int): The position of the event in the room's topology.
""" """
return self.db.simple_insert_many_txn( return self.db_pool.simple_insert_many_txn(
txn=txn, txn=txn,
table="event_labels", table="event_labels",
values=[ values=[
@ -1111,7 +1115,7 @@ class PersistEventsStore:
event_id (str): The event ID the expiry timestamp is associated with. event_id (str): The event ID the expiry timestamp is associated with.
expiry_ts (int): The timestamp at which to expire (delete) the event. expiry_ts (int): The timestamp at which to expire (delete) the event.
""" """
return self.db.simple_insert_txn( return self.db_pool.simple_insert_txn(
txn=txn, txn=txn,
table="event_expiry", table="event_expiry",
values={"event_id": event_id, "expiry_ts": expiry_ts}, values={"event_id": event_id, "expiry_ts": expiry_ts},
@ -1135,12 +1139,14 @@ class PersistEventsStore:
} }
) )
self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) self.db_pool.simple_insert_many_txn(
txn, table="event_reference_hashes", values=vals
)
def _store_room_members_txn(self, txn, events, backfilled): def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database. """Store a room member in the database.
""" """
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="room_memberships", table="room_memberships",
values=[ values=[
@ -1180,7 +1186,7 @@ class PersistEventsStore:
and event.internal_metadata.is_outlier() and event.internal_metadata.is_outlier()
and event.internal_metadata.is_out_of_band_membership() and event.internal_metadata.is_out_of_band_membership()
): ):
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="local_current_membership", table="local_current_membership",
keyvalues={"room_id": event.room_id, "user_id": event.state_key}, keyvalues={"room_id": event.room_id, "user_id": event.state_key},
@ -1218,7 +1224,7 @@ class PersistEventsStore:
aggregation_key = relation.get("key") aggregation_key = relation.get("key")
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="event_relations", table="event_relations",
values={ values={
@ -1246,7 +1252,7 @@ class PersistEventsStore:
redacted_event_id (str): The event that was redacted. redacted_event_id (str): The event that was redacted.
""" """
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id} txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
) )
@ -1282,7 +1288,7 @@ class PersistEventsStore:
# Ignore the event if one of the value isn't an integer. # Ignore the event if one of the value isn't an integer.
return return
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn=txn, txn=txn,
table="room_retention", table="room_retention",
values={ values={
@ -1363,7 +1369,7 @@ class PersistEventsStore:
) )
for event, _ in events_and_contexts: for event, _ in events_and_contexts:
user_ids = self.db.simple_select_onecol_txn( user_ids = self.db_pool.simple_select_onecol_txn(
txn, txn,
table="event_push_actions_staging", table="event_push_actions_staging",
keyvalues={"event_id": event.event_id}, keyvalues={"event_id": event.event_id},
@ -1395,7 +1401,7 @@ class PersistEventsStore:
) )
def _store_rejections_txn(self, txn, event_id, reason): def _store_rejections_txn(self, txn, event_id, reason):
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="rejections", table="rejections",
values={ values={
@ -1421,7 +1427,7 @@ class PersistEventsStore:
state_groups[event.event_id] = context.state_group state_groups[event.event_id] = context.state_group
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="event_to_state_groups", table="event_to_state_groups",
values=[ values=[
@ -1443,7 +1449,7 @@ class PersistEventsStore:
if min_depth is not None and depth >= min_depth: if min_depth is not None and depth >= min_depth:
return return
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="room_depth", table="room_depth",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
@ -1455,7 +1461,7 @@ class PersistEventsStore:
For the given event, update the event edges table and forward and For the given event, update the event edges table and forward and
backward extremities tables. backward extremities tables.
""" """
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="event_edges", table="event_edges",
values=[ values=[

View file

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import EventContentFields from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,18 +30,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender, self._background_reindex_fields_sender,
) )
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"event_contains_url_index", "event_contains_url_index",
index_name="event_contains_url_index", index_name="event_contains_url_index",
table="events", table="events",
@ -52,7 +52,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# an event_id index on event_search is useful for the purge_history # an event_id index on event_search is useful for the purge_history
# api. Plus it means we get to enforce some integrity with a UNIQUE # api. Plus it means we get to enforce some integrity with a UNIQUE
# clause # clause
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"event_search_event_id_idx", "event_search_event_id_idx",
index_name="event_search_event_id_idx", index_name="event_search_event_id_idx",
table="event_search", table="event_search",
@ -61,16 +61,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
psql_only=True, psql_only=True,
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"redactions_received_ts", self._redactions_received_ts "redactions_received_ts", self._redactions_received_ts
) )
# This index gets deleted in `event_fix_redactions_bytes` update # This index gets deleted in `event_fix_redactions_bytes` update
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"event_fix_redactions_bytes_create_index", "event_fix_redactions_bytes_create_index",
index_name="redactions_censored_redacts", index_name="redactions_censored_redacts",
table="redactions", table="redactions",
@ -78,15 +78,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="have_censored", where_clause="have_censored",
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"event_fix_redactions_bytes", self._event_fix_redactions_bytes "event_fix_redactions_bytes", self._event_fix_redactions_bytes
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"event_store_labels", self._event_store_labels "event_store_labels", self._event_store_labels
) )
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"redactions_have_censored_ts_idx", "redactions_have_censored_ts_idx",
index_name="redactions_have_censored_ts", index_name="redactions_have_censored_ts",
table="redactions", table="redactions",
@ -149,18 +149,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"rows_inserted": rows_inserted + len(rows), "rows_inserted": rows_inserted + len(rows),
} }
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
) )
return len(rows) return len(rows)
result = yield self.db.runInteraction( result = yield self.db_pool.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
) )
if not result: if not result:
yield self.db.updates._end_background_update( yield self.db_pool.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
) )
@ -195,7 +195,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks: for chunk in chunks:
ev_rows = self.db.simple_select_many_txn( ev_rows = self.db_pool.simple_select_many_txn(
txn, txn,
table="event_json", table="event_json",
column="event_id", column="event_id",
@ -228,18 +228,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"rows_inserted": rows_inserted + len(rows_to_update), "rows_inserted": rows_inserted + len(rows_to_update),
} }
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
) )
return len(rows_to_update) return len(rows_to_update)
result = yield self.db.runInteraction( result = yield self.db_pool.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
) )
if not result: if not result:
yield self.db.updates._end_background_update( yield self.db_pool.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME self.EVENT_ORIGIN_SERVER_TS_NAME
) )
@ -374,7 +374,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
to_delete.intersection_update(original_set) to_delete.intersection_update(original_set)
deleted = self.db.simple_delete_many_txn( deleted = self.db_pool.simple_delete_many_txn(
txn=txn, txn=txn,
table="event_forward_extremities", table="event_forward_extremities",
column="event_id", column="event_id",
@ -390,7 +390,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if deleted: if deleted:
# We now need to invalidate the caches of these rooms # We now need to invalidate the caches of these rooms
rows = self.db.simple_select_many_txn( rows = self.db_pool.simple_select_many_txn(
txn, txn,
table="events", table="events",
column="event_id", column="event_id",
@ -404,7 +404,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate, (room_id,) self.get_latest_event_ids_in_room.invalidate, (room_id,)
) )
self.db.simple_delete_many_txn( self.db_pool.simple_delete_many_txn(
txn=txn, txn=txn,
table="_extremities_to_check", table="_extremities_to_check",
column="event_id", column="event_id",
@ -414,19 +414,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(original_set) return len(original_set)
num_handled = yield self.db.runInteraction( num_handled = yield self.db_pool.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
) )
if not num_handled: if not num_handled:
yield self.db.updates._end_background_update( yield self.db_pool.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES self.DELETE_SOFT_FAILED_EXTREMITIES
) )
def _drop_table_txn(txn): def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check") txn.execute("DROP TABLE _extremities_to_check")
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
) )
@ -474,18 +474,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, "redactions_received_ts", {"last_event_id": upper_event_id} txn, "redactions_received_ts", {"last_event_id": upper_event_id}
) )
return len(rows) return len(rows)
count = yield self.db.runInteraction( count = yield self.db_pool.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn "_redactions_received_ts", _redactions_received_ts_txn
) )
if not count: if not count:
yield self.db.updates._end_background_update("redactions_received_ts") yield self.db_pool.updates._end_background_update("redactions_received_ts")
return count return count
@ -511,11 +511,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute("DROP INDEX redactions_censored_redacts") txn.execute("DROP INDEX redactions_censored_redacts")
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
) )
yield self.db.updates._end_background_update("event_fix_redactions_bytes") yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
return 1 return 1
@ -543,7 +543,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
try: try:
event_json = db_to_json(event_json_raw) event_json = db_to_json(event_json_raw)
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn=txn, txn=txn,
table="event_labels", table="event_labels",
values=[ values=[
@ -569,17 +569,17 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
nbrows += 1 nbrows += 1
last_row_event_id = event_id last_row_event_id = event_id
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, "event_store_labels", {"last_event_id": last_row_event_id} txn, "event_store_labels", {"last_event_id": last_row_event_id}
) )
return nbrows return nbrows
num_rows = yield self.db.runInteraction( num_rows = yield self.db_pool.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn desc="event_store_labels", func=_event_store_labels_txn
) )
if not num_rows: if not num_rows:
yield self.db.updates._end_background_update("event_store_labels") yield self.db_pool.updates._end_background_update("event_store_labels")
return num_rows return num_rows

View file

@ -40,7 +40,7 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -80,7 +80,7 @@ class EventRedactBehaviour(Names):
class EventsWorkerStore(SQLBaseStore): class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs) super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.writers.events == hs.get_instance_name(): if hs.config.worker.writers.events == hs.get_instance_name():
@ -136,7 +136,7 @@ class EventsWorkerStore(SQLBaseStore):
Deferred[int|None]: Timestamp in milliseconds, or None for events Deferred[int|None]: Timestamp in milliseconds, or None for events
that were persisted before received_ts was implemented. that were persisted before received_ts was implemented.
""" """
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="events", table="events",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
retcol="received_ts", retcol="received_ts",
@ -175,7 +175,7 @@ class EventsWorkerStore(SQLBaseStore):
return ts return ts
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_approximate_received_ts", _get_approximate_received_ts_txn "get_approximate_received_ts", _get_approximate_received_ts_txn
) )
@ -543,7 +543,7 @@ class EventsWorkerStore(SQLBaseStore):
event_id for events, _ in event_list for event_id in events event_id for events, _ in event_list for event_id in events
} }
row_dict = self.db.new_transaction( row_dict = self.db_pool.new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
) )
@ -720,7 +720,7 @@ class EventsWorkerStore(SQLBaseStore):
if should_start: if should_start:
run_as_background_process( run_as_background_process(
"fetch_events", self.db.runWithConnection, self._do_fetch "fetch_events", self.db_pool.runWithConnection, self._do_fetch
) )
logger.debug("Loading %d events: %s", len(events), events) logger.debug("Loading %d events: %s", len(events), events)
@ -889,7 +889,7 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and """Given a list of event ids, check if we have already processed and
stored them as non outliers. stored them as non outliers.
""" """
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="events", table="events",
retcols=("event_id",), retcols=("event_id",),
column="event_id", column="event_id",
@ -924,7 +924,7 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100 # break the input up into chunks of 100
input_iterator = iter(event_ids) input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk "have_seen_events", have_seen_events_txn, chunk
) )
return results return results
@ -953,7 +953,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred[int] Deferred[int]
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_total_state_event_counts", "get_total_state_event_counts",
self._get_total_state_event_counts_txn, self._get_total_state_event_counts_txn,
room_id, room_id,
@ -978,7 +978,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred[int] Deferred[int]
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_current_state_event_counts", "get_current_state_event_counts",
self._get_current_state_event_counts_txn, self._get_current_state_event_counts_txn,
room_id, room_id,
@ -1043,7 +1043,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall() return txn.fetchall()
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows "get_all_new_forward_event_rows", get_all_new_forward_event_rows
) )
@ -1077,7 +1077,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id)) txn.execute(sql, (last_id, current_id))
return txn.fetchall() return txn.fetchall()
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
) )
@ -1151,7 +1151,7 @@ class EventsWorkerStore(SQLBaseStore):
return new_event_updates, upper_bound, limited return new_event_updates, upper_bound, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
) )
@ -1199,7 +1199,7 @@ class EventsWorkerStore(SQLBaseStore):
# we need to make sure that, for every stream id in the results, we get *all* # we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id. # the rows with that stream id.
rows = await self.db.runInteraction( rows = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas", "get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn, get_all_updated_current_state_deltas_txn,
) # type: List[Tuple] ) # type: List[Tuple]
@ -1222,7 +1222,7 @@ class EventsWorkerStore(SQLBaseStore):
# stream id. let's run the query again, without a row limit, but for # stream id. let's run the query again, without a row limit, but for
# just one stream id. # just one stream id.
to_token += 1 to_token += 1
rows = await self.db.runInteraction( rows = await self.db_pool.runInteraction(
"get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
) )
@ -1317,7 +1317,7 @@ class EventsWorkerStore(SQLBaseStore):
backward_ex_outliers, backward_ex_outliers,
) )
return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
async def is_event_after(self, event_id1, event_id2): async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream """Returns True if event_id1 is after event_id2 in the stream
@ -1328,7 +1328,7 @@ class EventsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(max_entries=5000) @cachedInlineCallbacks(max_entries=5000)
def get_event_ordering(self, event_id): def get_event_ordering(self, event_id):
res = yield self.db.simple_select_one( res = yield self.db_pool.simple_select_one(
table="events", table="events",
retcols=["topological_ordering", "stream_ordering"], retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
@ -1360,7 +1360,7 @@ class EventsWorkerStore(SQLBaseStore):
return txn.fetchone() return txn.fetchone()
return self.db.runInteraction( return self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
) )
@ -1385,7 +1385,7 @@ class EventsWorkerStore(SQLBaseStore):
on_invalidate=cache_context.invalidate, on_invalidate=cache_context.invalidate,
) )
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_unread_message_count_for_user", "get_unread_message_count_for_user",
self._get_unread_message_count_for_user_txn, self._get_unread_message_count_for_user_txn,
user_id, user_id,
@ -1402,7 +1402,7 @@ class EventsWorkerStore(SQLBaseStore):
) -> int: ) -> int:
if last_read_event_id: if last_read_event_id:
# Get the stream ordering for the last read event. # Get the stream ordering for the last read event.
stream_ordering = self.db.simple_select_one_onecol_txn( stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn=txn, txn=txn,
table="events", table="events",
keyvalues={"room_id": room_id, "event_id": last_read_event_id}, keyvalues={"room_id": room_id, "event_id": last_read_event_id},

View file

@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore):
except ValueError: except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
def_json = yield self.db.simple_select_one_onecol( def_json = yield self.db_pool.simple_select_one_onecol(
table="user_filters", table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id}, keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json", retcol="filter_json",
@ -71,4 +71,4 @@ class FilteringStore(SQLBaseStore):
return filter_id return filter_id
return self.db.runInteraction("add_user_filter", _do_txn) return self.db_pool.runInteraction("add_user_filter", _do_txn)

View file

@ -31,7 +31,7 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore): class GroupServerWorkerStore(SQLBaseStore):
def get_group(self, group_id): def get_group(self, group_id):
return self.db.simple_select_one( return self.db_pool.simple_select_one(
table="groups", table="groups",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
retcols=( retcols=(
@ -53,7 +53,7 @@ class GroupServerWorkerStore(SQLBaseStore):
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
return self.db.simple_select_list( return self.db_pool.simple_select_list(
table="group_users", table="group_users",
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"), retcols=("user_id", "is_public", "is_admin"),
@ -63,7 +63,7 @@ class GroupServerWorkerStore(SQLBaseStore):
def get_invited_users_in_group(self, group_id): def get_invited_users_in_group(self, group_id):
# TODO: Pagination # TODO: Pagination
return self.db.simple_select_onecol( return self.db_pool.simple_select_onecol(
table="group_invites", table="group_invites",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
retcol="user_id", retcol="user_id",
@ -117,7 +117,9 @@ class GroupServerWorkerStore(SQLBaseStore):
for room_id, is_public in txn for room_id, is_public in txn
] ]
return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn) return self.db_pool.runInteraction(
"get_rooms_in_group", _get_rooms_in_group_txn
)
def get_rooms_for_summary_by_category( def get_rooms_for_summary_by_category(
self, group_id: str, include_private: bool = False, self, group_id: str, include_private: bool = False,
@ -205,13 +207,13 @@ class GroupServerWorkerStore(SQLBaseStore):
return rooms, categories return rooms, categories
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_rooms_for_summary", _get_rooms_for_summary_txn "get_rooms_for_summary", _get_rooms_for_summary_txn
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_group_categories(self, group_id): def get_group_categories(self, group_id):
rows = yield self.db.simple_select_list( rows = yield self.db_pool.simple_select_list(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
retcols=("category_id", "is_public", "profile"), retcols=("category_id", "is_public", "profile"),
@ -228,7 +230,7 @@ class GroupServerWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_group_category(self, group_id, category_id): def get_group_category(self, group_id, category_id):
category = yield self.db.simple_select_one( category = yield self.db_pool.simple_select_one(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id}, keyvalues={"group_id": group_id, "category_id": category_id},
retcols=("is_public", "profile"), retcols=("is_public", "profile"),
@ -241,7 +243,7 @@ class GroupServerWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_group_roles(self, group_id): def get_group_roles(self, group_id):
rows = yield self.db.simple_select_list( rows = yield self.db_pool.simple_select_list(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
retcols=("role_id", "is_public", "profile"), retcols=("role_id", "is_public", "profile"),
@ -258,7 +260,7 @@ class GroupServerWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_group_role(self, group_id, role_id): def get_group_role(self, group_id, role_id):
role = yield self.db.simple_select_one( role = yield self.db_pool.simple_select_one(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id}, keyvalues={"group_id": group_id, "role_id": role_id},
retcols=("is_public", "profile"), retcols=("is_public", "profile"),
@ -277,7 +279,7 @@ class GroupServerWorkerStore(SQLBaseStore):
Deferred[list[str]]: A twisted.Deferred containing a list of group ids Deferred[list[str]]: A twisted.Deferred containing a list of group ids
containing this room containing this room
""" """
return self.db.simple_select_onecol( return self.db_pool.simple_select_onecol(
table="group_rooms", table="group_rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcol="group_id", retcol="group_id",
@ -341,12 +343,12 @@ class GroupServerWorkerStore(SQLBaseStore):
return users, roles return users, roles
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn "get_users_for_summary_by_role", _get_users_for_summary_txn
) )
def is_user_in_group(self, user_id, group_id): def is_user_in_group(self, user_id, group_id):
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="group_users", table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id", retcol="user_id",
@ -355,7 +357,7 @@ class GroupServerWorkerStore(SQLBaseStore):
).addCallback(lambda r: bool(r)) ).addCallback(lambda r: bool(r))
def is_user_admin_in_group(self, group_id, user_id): def is_user_admin_in_group(self, group_id, user_id):
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="group_users", table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin", retcol="is_admin",
@ -366,7 +368,7 @@ class GroupServerWorkerStore(SQLBaseStore):
def is_user_invited_to_local_group(self, group_id, user_id): def is_user_invited_to_local_group(self, group_id, user_id):
"""Has the group server invited a user? """Has the group server invited a user?
""" """
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="group_invites", table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id", retcol="user_id",
@ -389,7 +391,7 @@ class GroupServerWorkerStore(SQLBaseStore):
""" """
def _get_users_membership_in_group_txn(txn): def _get_users_membership_in_group_txn(txn):
row = self.db.simple_select_one_txn( row = self.db_pool.simple_select_one_txn(
txn, txn,
table="group_users", table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
@ -404,7 +406,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"is_privileged": row["is_admin"], "is_privileged": row["is_admin"],
} }
row = self.db.simple_select_one_onecol_txn( row = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="group_invites", table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
@ -417,14 +419,14 @@ class GroupServerWorkerStore(SQLBaseStore):
return {} return {}
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_users_membership_info_in_group", _get_users_membership_in_group_txn "get_users_membership_info_in_group", _get_users_membership_in_group_txn
) )
def get_publicised_groups_for_user(self, user_id): def get_publicised_groups_for_user(self, user_id):
"""Get all groups a user is publicising """Get all groups a user is publicising
""" """
return self.db.simple_select_onecol( return self.db_pool.simple_select_onecol(
table="local_group_membership", table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id", retcol="group_id",
@ -441,9 +443,9 @@ class GroupServerWorkerStore(SQLBaseStore):
WHERE valid_until_ms <= ? WHERE valid_until_ms <= ?
""" """
txn.execute(sql, (valid_until_ms,)) txn.execute(sql, (valid_until_ms,))
return self.db.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn "get_attestations_need_renewals", _get_attestations_need_renewals_txn
) )
@ -452,7 +454,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"""Get the attestation that proves the remote agrees that the user is """Get the attestation that proves the remote agrees that the user is
in the group. in the group.
""" """
row = yield self.db.simple_select_one( row = yield self.db_pool.simple_select_one(
table="group_attestations_remote", table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("valid_until_ms", "attestation_json"), retcols=("valid_until_ms", "attestation_json"),
@ -467,7 +469,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return None return None
def get_joined_groups(self, user_id): def get_joined_groups(self, user_id):
return self.db.simple_select_onecol( return self.db_pool.simple_select_onecol(
table="local_group_membership", table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"}, keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id", retcol="group_id",
@ -494,7 +496,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in txn for row in txn
] ]
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn "get_all_groups_for_user", _get_all_groups_for_user_txn
) )
@ -524,7 +526,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for group_id, membership, gtype, content_json in txn for group_id, membership, gtype, content_json in txn
] ]
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn "get_groups_changes_for_user", _get_groups_changes_for_user_txn
) )
@ -579,7 +581,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return updates, upto_token, limited return updates, upto_token, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn "get_all_groups_changes", _get_all_groups_changes_txn
) )
@ -592,7 +594,7 @@ class GroupServerStore(GroupServerWorkerStore):
* "invite" * "invite"
* "open" * "open"
""" """
return self.db.simple_update_one( return self.db_pool.simple_update_one(
table="groups", table="groups",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy}, updatevalues={"join_policy": join_policy},
@ -600,7 +602,7 @@ class GroupServerStore(GroupServerWorkerStore):
) )
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
return self.db.runInteraction( return self.db_pool.runInteraction(
"add_room_to_summary", "add_room_to_summary",
self._add_room_to_summary_txn, self._add_room_to_summary_txn,
group_id, group_id,
@ -624,7 +626,7 @@ class GroupServerStore(GroupServerWorkerStore):
an order of 1 will put the room first. Otherwise, the room gets an order of 1 will put the room first. Otherwise, the room gets
added to the end. added to the end.
""" """
room_in_group = self.db.simple_select_one_onecol_txn( room_in_group = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="group_rooms", table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id}, keyvalues={"group_id": group_id, "room_id": room_id},
@ -637,7 +639,7 @@ class GroupServerStore(GroupServerWorkerStore):
if category_id is None: if category_id is None:
category_id = _DEFAULT_CATEGORY_ID category_id = _DEFAULT_CATEGORY_ID
else: else:
cat_exists = self.db.simple_select_one_onecol_txn( cat_exists = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id}, keyvalues={"group_id": group_id, "category_id": category_id},
@ -648,7 +650,7 @@ class GroupServerStore(GroupServerWorkerStore):
raise SynapseError(400, "Category doesn't exist") raise SynapseError(400, "Category doesn't exist")
# TODO: Check category is part of summary already # TODO: Check category is part of summary already
cat_exists = self.db.simple_select_one_onecol_txn( cat_exists = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="group_summary_room_categories", table="group_summary_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id}, keyvalues={"group_id": group_id, "category_id": category_id},
@ -668,7 +670,7 @@ class GroupServerStore(GroupServerWorkerStore):
(group_id, category_id, group_id, category_id), (group_id, category_id, group_id, category_id),
) )
existing = self.db.simple_select_one_txn( existing = self.db_pool.simple_select_one_txn(
txn, txn,
table="group_summary_rooms", table="group_summary_rooms",
keyvalues={ keyvalues={
@ -701,7 +703,7 @@ class GroupServerStore(GroupServerWorkerStore):
to_update["room_order"] = order to_update["room_order"] = order
if is_public is not None: if is_public is not None:
to_update["is_public"] = is_public to_update["is_public"] = is_public
self.db.simple_update_txn( self.db_pool.simple_update_txn(
txn, txn,
table="group_summary_rooms", table="group_summary_rooms",
keyvalues={ keyvalues={
@ -715,7 +717,7 @@ class GroupServerStore(GroupServerWorkerStore):
if is_public is None: if is_public is None:
is_public = True is_public = True
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="group_summary_rooms", table="group_summary_rooms",
values={ values={
@ -731,7 +733,7 @@ class GroupServerStore(GroupServerWorkerStore):
if category_id is None: if category_id is None:
category_id = _DEFAULT_CATEGORY_ID category_id = _DEFAULT_CATEGORY_ID
return self.db.simple_delete( return self.db_pool.simple_delete(
table="group_summary_rooms", table="group_summary_rooms",
keyvalues={ keyvalues={
"group_id": group_id, "group_id": group_id,
@ -757,7 +759,7 @@ class GroupServerStore(GroupServerWorkerStore):
else: else:
update_values["is_public"] = is_public update_values["is_public"] = is_public
return self.db.simple_upsert( return self.db_pool.simple_upsert(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id}, keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values, values=update_values,
@ -766,7 +768,7 @@ class GroupServerStore(GroupServerWorkerStore):
) )
def remove_group_category(self, group_id, category_id): def remove_group_category(self, group_id, category_id):
return self.db.simple_delete( return self.db_pool.simple_delete(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id}, keyvalues={"group_id": group_id, "category_id": category_id},
desc="remove_group_category", desc="remove_group_category",
@ -788,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
else: else:
update_values["is_public"] = is_public update_values["is_public"] = is_public
return self.db.simple_upsert( return self.db_pool.simple_upsert(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id}, keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values, values=update_values,
@ -797,14 +799,14 @@ class GroupServerStore(GroupServerWorkerStore):
) )
def remove_group_role(self, group_id, role_id): def remove_group_role(self, group_id, role_id):
return self.db.simple_delete( return self.db_pool.simple_delete(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id}, keyvalues={"group_id": group_id, "role_id": role_id},
desc="remove_group_role", desc="remove_group_role",
) )
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
return self.db.runInteraction( return self.db_pool.runInteraction(
"add_user_to_summary", "add_user_to_summary",
self._add_user_to_summary_txn, self._add_user_to_summary_txn,
group_id, group_id,
@ -828,7 +830,7 @@ class GroupServerStore(GroupServerWorkerStore):
an order of 1 will put the user first. Otherwise, the user gets an order of 1 will put the user first. Otherwise, the user gets
added to the end. added to the end.
""" """
user_in_group = self.db.simple_select_one_onecol_txn( user_in_group = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="group_users", table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
@ -841,7 +843,7 @@ class GroupServerStore(GroupServerWorkerStore):
if role_id is None: if role_id is None:
role_id = _DEFAULT_ROLE_ID role_id = _DEFAULT_ROLE_ID
else: else:
role_exists = self.db.simple_select_one_onecol_txn( role_exists = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id}, keyvalues={"group_id": group_id, "role_id": role_id},
@ -852,7 +854,7 @@ class GroupServerStore(GroupServerWorkerStore):
raise SynapseError(400, "Role doesn't exist") raise SynapseError(400, "Role doesn't exist")
# TODO: Check role is part of the summary already # TODO: Check role is part of the summary already
role_exists = self.db.simple_select_one_onecol_txn( role_exists = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="group_summary_roles", table="group_summary_roles",
keyvalues={"group_id": group_id, "role_id": role_id}, keyvalues={"group_id": group_id, "role_id": role_id},
@ -872,7 +874,7 @@ class GroupServerStore(GroupServerWorkerStore):
(group_id, role_id, group_id, role_id), (group_id, role_id, group_id, role_id),
) )
existing = self.db.simple_select_one_txn( existing = self.db_pool.simple_select_one_txn(
txn, txn,
table="group_summary_users", table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@ -901,7 +903,7 @@ class GroupServerStore(GroupServerWorkerStore):
to_update["user_order"] = order to_update["user_order"] = order
if is_public is not None: if is_public is not None:
to_update["is_public"] = is_public to_update["is_public"] = is_public
self.db.simple_update_txn( self.db_pool.simple_update_txn(
txn, txn,
table="group_summary_users", table="group_summary_users",
keyvalues={ keyvalues={
@ -915,7 +917,7 @@ class GroupServerStore(GroupServerWorkerStore):
if is_public is None: if is_public is None:
is_public = True is_public = True
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="group_summary_users", table="group_summary_users",
values={ values={
@ -931,7 +933,7 @@ class GroupServerStore(GroupServerWorkerStore):
if role_id is None: if role_id is None:
role_id = _DEFAULT_ROLE_ID role_id = _DEFAULT_ROLE_ID
return self.db.simple_delete( return self.db_pool.simple_delete(
table="group_summary_users", table="group_summary_users",
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
desc="remove_user_from_summary", desc="remove_user_from_summary",
@ -940,7 +942,7 @@ class GroupServerStore(GroupServerWorkerStore):
def add_group_invite(self, group_id, user_id): def add_group_invite(self, group_id, user_id):
"""Record that the group server has invited a user """Record that the group server has invited a user
""" """
return self.db.simple_insert( return self.db_pool.simple_insert(
table="group_invites", table="group_invites",
values={"group_id": group_id, "user_id": user_id}, values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite", desc="add_group_invite",
@ -970,7 +972,7 @@ class GroupServerStore(GroupServerWorkerStore):
""" """
def _add_user_to_group_txn(txn): def _add_user_to_group_txn(txn):
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="group_users", table="group_users",
values={ values={
@ -981,14 +983,14 @@ class GroupServerStore(GroupServerWorkerStore):
}, },
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_invites", table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
) )
if local_attestation: if local_attestation:
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="group_attestations_renewals", table="group_attestations_renewals",
values={ values={
@ -998,7 +1000,7 @@ class GroupServerStore(GroupServerWorkerStore):
}, },
) )
if remote_attestation: if remote_attestation:
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="group_attestations_remote", table="group_attestations_remote",
values={ values={
@ -1009,49 +1011,49 @@ class GroupServerStore(GroupServerWorkerStore):
}, },
) )
return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn) return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
def remove_user_from_group(self, group_id, user_id): def remove_user_from_group(self, group_id, user_id):
def _remove_user_from_group_txn(txn): def _remove_user_from_group_txn(txn):
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_users", table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_invites", table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_attestations_renewals", table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_attestations_remote", table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_summary_users", table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn "remove_user_from_group", _remove_user_from_group_txn
) )
def add_room_to_group(self, group_id, room_id, is_public): def add_room_to_group(self, group_id, room_id, is_public):
return self.db.simple_insert( return self.db_pool.simple_insert(
table="group_rooms", table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group", desc="add_room_to_group",
) )
def update_room_in_group_visibility(self, group_id, room_id, is_public): def update_room_in_group_visibility(self, group_id, room_id, is_public):
return self.db.simple_update( return self.db_pool.simple_update(
table="group_rooms", table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id}, keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public}, updatevalues={"is_public": is_public},
@ -1060,26 +1062,26 @@ class GroupServerStore(GroupServerWorkerStore):
def remove_room_from_group(self, group_id, room_id): def remove_room_from_group(self, group_id, room_id):
def _remove_room_from_group_txn(txn): def _remove_room_from_group_txn(txn):
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_rooms", table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id}, keyvalues={"group_id": group_id, "room_id": room_id},
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_summary_rooms", table="group_summary_rooms",
keyvalues={"group_id": group_id, "room_id": room_id}, keyvalues={"group_id": group_id, "room_id": room_id},
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn "remove_room_from_group", _remove_room_from_group_txn
) )
def update_group_publicity(self, group_id, user_id, publicise): def update_group_publicity(self, group_id, user_id, publicise):
"""Update whether the user is publicising their membership of the group """Update whether the user is publicising their membership of the group
""" """
return self.db.simple_update_one( return self.db_pool.simple_update_one(
table="local_group_membership", table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise}, updatevalues={"is_publicised": publicise},
@ -1115,12 +1117,12 @@ class GroupServerStore(GroupServerWorkerStore):
def _register_user_group_membership_txn(txn, next_id): def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert? # TODO: Upsert?
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="local_group_membership", table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
) )
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="local_group_membership", table="local_group_membership",
values={ values={
@ -1133,7 +1135,7 @@ class GroupServerStore(GroupServerWorkerStore):
}, },
) )
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="local_group_updates", table="local_group_updates",
values={ values={
@ -1152,7 +1154,7 @@ class GroupServerStore(GroupServerWorkerStore):
if membership == "join": if membership == "join":
if local_attestation: if local_attestation:
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="group_attestations_renewals", table="group_attestations_renewals",
values={ values={
@ -1162,7 +1164,7 @@ class GroupServerStore(GroupServerWorkerStore):
}, },
) )
if remote_attestation: if remote_attestation:
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="group_attestations_remote", table="group_attestations_remote",
values={ values={
@ -1173,12 +1175,12 @@ class GroupServerStore(GroupServerWorkerStore):
}, },
) )
else: else:
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_attestations_renewals", table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_attestations_remote", table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
@ -1187,7 +1189,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id return next_id
with self._group_updates_id_gen.get_next() as next_id: with self._group_updates_id_gen.get_next() as next_id:
res = yield self.db.runInteraction( res = yield self.db_pool.runInteraction(
"register_user_group_membership", "register_user_group_membership",
_register_user_group_membership_txn, _register_user_group_membership_txn,
next_id, next_id,
@ -1198,7 +1200,7 @@ class GroupServerStore(GroupServerWorkerStore):
def create_group( def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description self, group_id, user_id, name, avatar_url, short_description, long_description
): ):
yield self.db.simple_insert( yield self.db_pool.simple_insert(
table="groups", table="groups",
values={ values={
"group_id": group_id, "group_id": group_id,
@ -1213,7 +1215,7 @@ class GroupServerStore(GroupServerWorkerStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def update_group_profile(self, group_id, profile): def update_group_profile(self, group_id, profile):
yield self.db.simple_update_one( yield self.db_pool.simple_update_one(
table="groups", table="groups",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
updatevalues=profile, updatevalues=profile,
@ -1223,7 +1225,7 @@ class GroupServerStore(GroupServerWorkerStore):
def update_attestation_renewal(self, group_id, user_id, attestation): def update_attestation_renewal(self, group_id, user_id, attestation):
"""Update an attestation that we have renewed """Update an attestation that we have renewed
""" """
return self.db.simple_update_one( return self.db_pool.simple_update_one(
table="group_attestations_renewals", table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
@ -1233,7 +1235,7 @@ class GroupServerStore(GroupServerWorkerStore):
def update_remote_attestion(self, group_id, user_id, attestation): def update_remote_attestion(self, group_id, user_id, attestation):
"""Update an attestation that a remote has renewed """Update an attestation that a remote has renewed
""" """
return self.db.simple_update_one( return self.db_pool.simple_update_one(
table="group_attestations_remote", table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={ updatevalues={
@ -1252,7 +1254,7 @@ class GroupServerStore(GroupServerWorkerStore):
group_id (str) group_id (str)
user_id (str) user_id (str)
""" """
return self.db.simple_delete( return self.db_pool.simple_delete(
table="group_attestations_renewals", table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
desc="remove_attestation_renewal", desc="remove_attestation_renewal",
@ -1288,8 +1290,8 @@ class GroupServerStore(GroupServerWorkerStore):
] ]
for table in tables: for table in tables:
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table=table, keyvalues={"group_id": group_id} txn, table=table, keyvalues={"group_id": group_id}
) )
return self.db.runInteraction("delete_group", _delete_group_txn) return self.db_pool.runInteraction("delete_group", _delete_group_txn)

View file

@ -86,7 +86,7 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch) _get_keys(txn, batch)
return keys return keys
return self.db.runInteraction("get_server_verify_keys", _txn) return self.db_pool.runInteraction("get_server_verify_keys", _txn)
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
"""Stores NACL verification keys for remote servers. """Stores NACL verification keys for remote servers.
@ -121,9 +121,9 @@ class KeyStore(SQLBaseStore):
f((i,)) f((i,))
return res return res
return self.db.runInteraction( return self.db_pool.runInteraction(
"store_server_verify_keys", "store_server_verify_keys",
self.db.simple_upsert_many_txn, self.db_pool.simple_upsert_many_txn,
table="server_signature_keys", table="server_signature_keys",
key_names=("server_name", "key_id"), key_names=("server_name", "key_id"),
key_values=key_values, key_values=key_values,
@ -151,7 +151,7 @@ class KeyStore(SQLBaseStore):
ts_valid_until_ms (int): The time when this json stops being valid. ts_valid_until_ms (int): The time when this json stops being valid.
key_json (bytes): The encoded JSON. key_json (bytes): The encoded JSON.
""" """
return self.db.simple_upsert( return self.db_pool.simple_upsert(
table="server_keys_json", table="server_keys_json",
keyvalues={ keyvalues={
"server_name": server_name, "server_name": server_name,
@ -190,7 +190,7 @@ class KeyStore(SQLBaseStore):
keyvalues["key_id"] = key_id keyvalues["key_id"] = key_id
if from_server is not None: if from_server is not None:
keyvalues["from_server"] = from_server keyvalues["from_server"] = from_server
rows = self.db.simple_select_list_txn( rows = self.db_pool.simple_select_list_txn(
txn, txn,
"server_keys_json", "server_keys_json",
keyvalues=keyvalues, keyvalues=keyvalues,
@ -205,4 +205,6 @@ class KeyStore(SQLBaseStore):
results[(server_name, key_id, from_server)] = rows results[(server_name, key_id, from_server)] = rows
return results return results
return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn) return self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
)

View file

@ -13,16 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryBackgroundUpdateStore, self).__init__( super(MediaRepositoryBackgroundUpdateStore, self).__init__(
database, db_conn, hs database, db_conn, hs
) )
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
update_name="local_media_repository_url_idx", update_name="local_media_repository_url_idx",
index_name="local_media_repository_url_idx", index_name="local_media_repository_url_idx",
table="local_media_repository", table="local_media_repository",
@ -34,7 +34,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars""" """Persistence for attachments and avatars"""
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs) super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
def get_local_media(self, media_id): def get_local_media(self, media_id):
@ -42,7 +42,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
Returns: Returns:
None if the media_id doesn't exist. None if the media_id doesn't exist.
""" """
return self.db.simple_select_one( return self.db_pool.simple_select_one(
"local_media_repository", "local_media_repository",
{"media_id": media_id}, {"media_id": media_id},
( (
@ -67,7 +67,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id, user_id,
url_cache=None, url_cache=None,
): ):
return self.db.simple_insert( return self.db_pool.simple_insert(
"local_media_repository", "local_media_repository",
{ {
"media_id": media_id, "media_id": media_id,
@ -83,7 +83,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def mark_local_media_as_safe(self, media_id: str): def mark_local_media_as_safe(self, media_id: str):
"""Mark a local media as safe from quarantining.""" """Mark a local media as safe from quarantining."""
return self.db.simple_update_one( return self.db_pool.simple_update_one(
table="local_media_repository", table="local_media_repository",
keyvalues={"media_id": media_id}, keyvalues={"media_id": media_id},
updatevalues={"safe_from_quarantine": True}, updatevalues={"safe_from_quarantine": True},
@ -136,12 +136,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
) )
) )
return self.db.runInteraction("get_url_cache", get_url_cache_txn) return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
def store_url_cache( def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts self, url, response_code, etag, expires_ts, og, media_id, download_ts
): ):
return self.db.simple_insert( return self.db_pool.simple_insert(
"local_media_repository_url_cache", "local_media_repository_url_cache",
{ {
"url": url, "url": url,
@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
) )
def get_local_media_thumbnails(self, media_id): def get_local_media_thumbnails(self, media_id):
return self.db.simple_select_list( return self.db_pool.simple_select_list(
"local_media_repository_thumbnails", "local_media_repository_thumbnails",
{"media_id": media_id}, {"media_id": media_id},
( (
@ -178,7 +178,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method, thumbnail_method,
thumbnail_length, thumbnail_length,
): ):
return self.db.simple_insert( return self.db_pool.simple_insert(
"local_media_repository_thumbnails", "local_media_repository_thumbnails",
{ {
"media_id": media_id, "media_id": media_id,
@ -192,7 +192,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
) )
def get_cached_remote_media(self, origin, media_id): def get_cached_remote_media(self, origin, media_id):
return self.db.simple_select_one( return self.db_pool.simple_select_one(
"remote_media_cache", "remote_media_cache",
{"media_origin": origin, "media_id": media_id}, {"media_origin": origin, "media_id": media_id},
( (
@ -217,7 +217,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name, upload_name,
filesystem_id, filesystem_id,
): ):
return self.db.simple_insert( return self.db_pool.simple_insert(
"remote_media_cache", "remote_media_cache",
{ {
"media_origin": origin, "media_origin": origin,
@ -262,12 +262,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
return self.db.runInteraction( return self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn "update_cached_last_access_time", update_cache_txn
) )
def get_remote_media_thumbnails(self, origin, media_id): def get_remote_media_thumbnails(self, origin, media_id):
return self.db.simple_select_list( return self.db_pool.simple_select_list(
"remote_media_cache_thumbnails", "remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id}, {"media_origin": origin, "media_id": media_id},
( (
@ -292,7 +292,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method, thumbnail_method,
thumbnail_length, thumbnail_length,
): ):
return self.db.simple_insert( return self.db_pool.simple_insert(
"remote_media_cache_thumbnails", "remote_media_cache_thumbnails",
{ {
"media_origin": origin, "media_origin": origin,
@ -314,24 +314,26 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE last_access_ts < ?" " WHERE last_access_ts < ?"
) )
return self.db.execute( return self.db_pool.execute(
"get_remote_media_before", self.db.cursor_to_dict, sql, before_ts "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
) )
def delete_remote_media(self, media_origin, media_id): def delete_remote_media(self, media_origin, media_id):
def delete_remote_media_txn(txn): def delete_remote_media_txn(txn):
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
"remote_media_cache", "remote_media_cache",
keyvalues={"media_origin": media_origin, "media_id": media_id}, keyvalues={"media_origin": media_origin, "media_id": media_id},
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
"remote_media_cache_thumbnails", "remote_media_cache_thumbnails",
keyvalues={"media_origin": media_origin, "media_id": media_id}, keyvalues={"media_origin": media_origin, "media_id": media_id},
) )
return self.db.runInteraction("delete_remote_media", delete_remote_media_txn) return self.db_pool.runInteraction(
"delete_remote_media", delete_remote_media_txn
)
def get_expired_url_cache(self, now_ts): def get_expired_url_cache(self, now_ts):
sql = ( sql = (
@ -345,7 +347,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,)) txn.execute(sql, (now_ts,))
return [row[0] for row in txn] return [row[0] for row in txn]
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn "get_expired_url_cache", _get_expired_url_cache_txn
) )
@ -358,7 +360,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_txn(txn): def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids]) txn.executemany(sql, [(media_id,) for media_id in media_ids])
return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
)
def get_url_cache_media_before(self, before_ts): def get_url_cache_media_before(self, before_ts):
sql = ( sql = (
@ -372,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,)) txn.execute(sql, (before_ts,))
return [row[0] for row in txn] return [row[0] for row in txn]
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn "get_url_cache_media_before", _get_url_cache_media_before_txn
) )
@ -389,6 +393,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, [(media_id,) for media_id in media_ids]) txn.executemany(sql, [(media_id,) for media_id in media_ids])
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn "delete_url_cache_media", _delete_url_cache_media_txn
) )

View file

@ -20,10 +20,10 @@ from twisted.internet import defer
from synapse.metrics import BucketCollector from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.event_push_actions import ( from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore, EventPushActionsWorkerStore,
) )
from synapse.storage.database import Database
class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
@ -31,7 +31,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics. stats and prometheus metrics.
""" """
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Collect metrics on the number of forward extremities that exist. # Collect metrics on the number of forward extremities that exist.
@ -66,7 +66,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) )
return txn.fetchall() return txn.fetchall()
res = await self.db.runInteraction("read_forward_extremities", fetch) res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = Counter([x[0] for x in res]) self._current_forward_extremities_amount = Counter([x[0] for x in res])
@defer.inlineCallbacks @defer.inlineCallbacks
@ -88,7 +88,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count return count
ret = yield self.db.runInteraction("count_messages", _count_messages) ret = yield self.db_pool.runInteraction("count_messages", _count_messages)
return ret return ret
@defer.inlineCallbacks @defer.inlineCallbacks
@ -109,7 +109,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count return count
ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) ret = yield self.db_pool.runInteraction(
"count_daily_sent_messages", _count_messages
)
return ret return ret
@defer.inlineCallbacks @defer.inlineCallbacks
@ -124,5 +126,5 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count return count
ret = yield self.db.runInteraction("count_daily_active_rooms", _count) ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count)
return ret return ret

View file

@ -18,7 +18,7 @@ from typing import List
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database, make_in_list_sql_clause from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.hs = hs self.hs = hs
@ -48,7 +48,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count return count
return self.db.runInteraction("count_users", _count_users) return self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0) @cached(num_args=0)
def get_monthly_active_count_by_service(self): def get_monthly_active_count_by_service(self):
@ -76,7 +76,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
result = txn.fetchall() result = txn.fetchall()
return dict(result) return dict(result)
return self.db.runInteraction("count_users_by_service", _count_users_by_service) return self.db_pool.runInteraction(
"count_users_by_service", _count_users_by_service
)
async def get_registered_reserved_users(self) -> List[str]: async def get_registered_reserved_users(self) -> List[str]:
"""Of the reserved threepids defined in config, retrieve those that are associated """Of the reserved threepids defined in config, retrieve those that are associated
@ -109,7 +111,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
""" """
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="monthly_active_users", table="monthly_active_users",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="timestamp", retcol="timestamp",
@ -119,7 +121,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._limit_usage_by_mau = hs.config.limit_usage_by_mau
@ -128,7 +130,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# Do not add more reserved users than the total allowable number # Do not add more reserved users than the total allowable number
# cur = LoggingTransaction( # cur = LoggingTransaction(
self.db.new_transaction( self.db_pool.new_transaction(
db_conn, db_conn,
"initialise_mau_threepids", "initialise_mau_threepids",
[], [],
@ -162,7 +164,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
is_support = self.is_support_user_txn(txn, user_id) is_support = self.is_support_user_txn(txn, user_id)
if not is_support: if not is_support:
# We do this manually here to avoid hitting #6791 # We do this manually here to avoid hitting #6791
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="monthly_active_users", table="monthly_active_users",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
@ -246,7 +248,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
reserved_users = await self.get_registered_reserved_users() reserved_users = await self.get_registered_reserved_users()
await self.db.runInteraction( await self.db_pool.runInteraction(
"reap_monthly_active_users", _reap_users, reserved_users "reap_monthly_active_users", _reap_users, reserved_users
) )
@ -273,7 +275,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
if is_support: if is_support:
return return
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
) )
@ -303,7 +305,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# never be a big table and alternative approaches (batching multiple # never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity. # upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more # See https://github.com/matrix-org/synapse/issues/3854 for more
is_insert = self.db.simple_upsert_txn( is_insert = self.db_pool.simple_upsert_txn(
txn, txn,
table="monthly_active_users", table="monthly_active_users",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},

View file

@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore): class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id): def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
return self.db.simple_insert( return self.db_pool.simple_insert(
table="open_id_tokens", table="open_id_tokens",
values={ values={
"token": token, "token": token,
@ -28,6 +28,6 @@ class OpenIdStore(SQLBaseStore):
else: else:
return rows[0][0] return rows[0][0]
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn "get_user_id_for_token", get_user_id_for_token_txn
) )

View file

@ -31,7 +31,7 @@ class PresenceStore(SQLBaseStore):
) )
with stream_ordering_manager as stream_orderings: with stream_ordering_manager as stream_orderings:
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"update_presence", "update_presence",
self._update_presence_txn, self._update_presence_txn,
stream_orderings, stream_orderings,
@ -48,7 +48,7 @@ class PresenceStore(SQLBaseStore):
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Actually insert new rows # Actually insert new rows
self.db.simple_insert_many_txn( self.db_pool.simple_insert_many_txn(
txn, txn,
table="presence_stream", table="presence_stream",
values=[ values=[
@ -124,7 +124,7 @@ class PresenceStore(SQLBaseStore):
return updates, upper_bound, limited return updates, upper_bound, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_presence_updates", get_all_presence_updates_txn "get_all_presence_updates", get_all_presence_updates_txn
) )
@ -139,7 +139,7 @@ class PresenceStore(SQLBaseStore):
inlineCallbacks=True, inlineCallbacks=True,
) )
def get_presence_for_users(self, user_ids): def get_presence_for_users(self, user_ids):
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="presence_stream", table="presence_stream",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,
@ -165,7 +165,7 @@ class PresenceStore(SQLBaseStore):
return self._presence_id_gen.get_current_token() return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid): def allow_presence_visible(self, observed_localpart, observer_userid):
return self.db.simple_insert( return self.db_pool.simple_insert(
table="presence_allow_inbound", table="presence_allow_inbound",
values={ values={
"observed_user_id": observed_localpart, "observed_user_id": observed_localpart,
@ -176,7 +176,7 @@ class PresenceStore(SQLBaseStore):
) )
def disallow_presence_visible(self, observed_localpart, observer_userid): def disallow_presence_visible(self, observed_localpart, observer_userid):
return self.db.simple_delete_one( return self.db_pool.simple_delete_one(
table="presence_allow_inbound", table="presence_allow_inbound",
keyvalues={ keyvalues={
"observed_user_id": observed_localpart, "observed_user_id": observed_localpart,

View file

@ -17,14 +17,14 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.roommember import ProfileInfo from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore): class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_profileinfo(self, user_localpart): def get_profileinfo(self, user_localpart):
try: try:
profile = yield self.db.simple_select_one( profile = yield self.db_pool.simple_select_one(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"), retcols=("displayname", "avatar_url"),
@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore):
) )
def get_profile_displayname(self, user_localpart): def get_profile_displayname(self, user_localpart):
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcol="displayname", retcol="displayname",
@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore):
) )
def get_profile_avatar_url(self, user_localpart): def get_profile_avatar_url(self, user_localpart):
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcol="avatar_url", retcol="avatar_url",
@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore):
) )
def get_from_remote_profile_cache(self, user_id): def get_from_remote_profile_cache(self, user_id):
return self.db.simple_select_one( return self.db_pool.simple_select_one(
table="remote_profile_cache", table="remote_profile_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"), retcols=("displayname", "avatar_url"),
@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore):
) )
def create_profile(self, user_localpart): def create_profile(self, user_localpart):
return self.db.simple_insert( return self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile" table="profiles", values={"user_id": user_localpart}, desc="create_profile"
) )
def set_profile_displayname(self, user_localpart, new_displayname): def set_profile_displayname(self, user_localpart, new_displayname):
return self.db.simple_update_one( return self.db_pool.simple_update_one(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname}, updatevalues={"displayname": new_displayname},
@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore):
) )
def set_profile_avatar_url(self, user_localpart, new_avatar_url): def set_profile_avatar_url(self, user_localpart, new_avatar_url):
return self.db.simple_update_one( return self.db_pool.simple_update_one(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url}, updatevalues={"avatar_url": new_avatar_url},
@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore):
This should only be called when `is_subscribed_remote_profile_for_user` This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user. would return true for the user.
""" """
return self.db.simple_upsert( return self.db_pool.simple_upsert(
table="remote_profile_cache", table="remote_profile_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
values={ values={
@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore):
) )
def update_remote_profile_cache(self, user_id, displayname, avatar_url): def update_remote_profile_cache(self, user_id, displayname, avatar_url):
return self.db.simple_update( return self.db_pool.simple_update(
table="remote_profile_cache", table="remote_profile_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
updatevalues={ updatevalues={
@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore):
""" """
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed: if not subscribed:
yield self.db.simple_delete( yield self.db_pool.simple_delete(
table="remote_profile_cache", table="remote_profile_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache", desc="delete_remote_profile_cache",
@ -144,9 +144,9 @@ class ProfileStore(ProfileWorkerStore):
txn.execute(sql, (last_checked,)) txn.execute(sql, (last_checked,))
return self.db.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_remote_profile_cache_entries_that_expire", "get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn, _get_remote_profile_cache_entries_that_expire_txn,
) )
@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
def is_subscribed_remote_profile_for_user(self, user_id): def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile. """Check whether we are interested in a remote user's profile.
""" """
res = yield self.db.simple_select_one_onecol( res = yield self.db_pool.simple_select_one_onecol(
table="group_users", table="group_users",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="user_id", retcol="user_id",
@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore):
if res: if res:
return True return True
res = yield self.db.simple_select_one_onecol( res = yield self.db_pool.simple_select_one_onecol(
table="group_invites", table="group_invites",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="user_id", retcol="user_id",

View file

@ -18,7 +18,7 @@ from typing import Any, Tuple
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,7 +43,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
deleted events. deleted events.
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"purge_history", "purge_history",
self._purge_history_txn, self._purge_history_txn,
room_id, room_id,
@ -293,7 +293,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
Deferred[List[int]]: The list of state groups to delete. Deferred[List[int]]: The list of state groups to delete.
""" """
return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
def _purge_room_txn(self, txn, room_id): def _purge_room_txn(self, txn, room_id):
# First we fetch all the state groups that should be deleted, before # First we fetch all the state groups that should be deleted, before

View file

@ -25,12 +25,12 @@ from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.database import Database from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import ChainedIdGenerator from synapse.storage.util.id_generators import ChainedIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
@ -79,7 +79,7 @@ class PushRulesWorkerStore(
# the abstract methods being implemented. # the abstract methods being implemented.
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None: if hs.config.worker.worker_app is None:
@ -91,7 +91,7 @@ class PushRulesWorkerStore(
db_conn, "push_rules_stream", "stream_id" db_conn, "push_rules_stream", "stream_id"
) )
push_rules_prefill, push_rules_id = self.db.get_cache_dict( push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"push_rules_stream", "push_rules_stream",
entity_column="user_id", entity_column="user_id",
@ -116,7 +116,7 @@ class PushRulesWorkerStore(
@cachedInlineCallbacks(max_entries=5000) @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
rows = yield self.db.simple_select_list( rows = yield self.db_pool.simple_select_list(
table="push_rules", table="push_rules",
keyvalues={"user_name": user_id}, keyvalues={"user_name": user_id},
retcols=( retcols=(
@ -140,7 +140,7 @@ class PushRulesWorkerStore(
@cachedInlineCallbacks(max_entries=5000) @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id): def get_push_rules_enabled_for_user(self, user_id):
results = yield self.db.simple_select_list( results = yield self.db_pool.simple_select_list(
table="push_rules_enable", table="push_rules_enable",
keyvalues={"user_name": user_id}, keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"), retcols=("user_name", "rule_id", "enabled"),
@ -162,7 +162,7 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone() (count,) = txn.fetchone()
return bool(count) return bool(count)
return self.db.runInteraction( return self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn "have_push_rules_changed", have_push_rules_changed_txn
) )
@ -178,7 +178,7 @@ class PushRulesWorkerStore(
results = {user_id: [] for user_id in user_ids} results = {user_id: [] for user_id in user_ids}
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="push_rules", table="push_rules",
column="user_name", column="user_name",
iterable=user_ids, iterable=user_ids,
@ -336,7 +336,7 @@ class PushRulesWorkerStore(
results = {user_id: {} for user_id in user_ids} results = {user_id: {} for user_id in user_ids}
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="push_rules_enable", table="push_rules_enable",
column="user_name", column="user_name",
iterable=user_ids, iterable=user_ids,
@ -394,7 +394,7 @@ class PushRulesWorkerStore(
return updates, upper_bound, limited return updates, upper_bound, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn "get_all_push_rule_updates", get_all_push_rule_updates_txn
) )
@ -416,7 +416,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids stream_id, event_stream_ordering = ids
if before or after: if before or after:
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"_add_push_rule_relative_txn", "_add_push_rule_relative_txn",
self._add_push_rule_relative_txn, self._add_push_rule_relative_txn,
stream_id, stream_id,
@ -430,7 +430,7 @@ class PushRuleStore(PushRulesWorkerStore):
after, after,
) )
else: else:
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"_add_push_rule_highest_priority_txn", "_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn, self._add_push_rule_highest_priority_txn,
stream_id, stream_id,
@ -461,7 +461,7 @@ class PushRuleStore(PushRulesWorkerStore):
relative_to_rule = before or after relative_to_rule = before or after
res = self.db.simple_select_one_txn( res = self.db_pool.simple_select_one_txn(
txn, txn,
table="push_rules", table="push_rules",
keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
@ -584,7 +584,7 @@ class PushRuleStore(PushRulesWorkerStore):
# We didn't update a row with the given rule_id so insert one # We didn't update a row with the given rule_id so insert one
push_rule_id = self._push_rule_id_gen.get_next() push_rule_id = self._push_rule_id_gen.get_next()
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="push_rules", table="push_rules",
values={ values={
@ -627,7 +627,7 @@ class PushRuleStore(PushRulesWorkerStore):
""" """
def delete_push_rule_txn(txn, stream_id, event_stream_ordering): def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
self.db.simple_delete_one_txn( self.db_pool.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
) )
@ -637,7 +637,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids stream_id, event_stream_ordering = ids
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"delete_push_rule", "delete_push_rule",
delete_push_rule_txn, delete_push_rule_txn,
stream_id, stream_id,
@ -648,7 +648,7 @@ class PushRuleStore(PushRulesWorkerStore):
def set_push_rule_enabled(self, user_id, rule_id, enabled): def set_push_rule_enabled(self, user_id, rule_id, enabled):
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids stream_id, event_stream_ordering = ids
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"_set_push_rule_enabled_txn", "_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn, self._set_push_rule_enabled_txn,
stream_id, stream_id,
@ -662,7 +662,7 @@ class PushRuleStore(PushRulesWorkerStore):
self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
): ):
new_id = self._push_rules_enable_id_gen.get_next() new_id = self._push_rules_enable_id_gen.get_next()
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
"push_rules_enable", "push_rules_enable",
{"user_name": user_id, "rule_id": rule_id}, {"user_name": user_id, "rule_id": rule_id},
@ -702,7 +702,7 @@ class PushRuleStore(PushRulesWorkerStore):
update_stream=False, update_stream=False,
) )
else: else:
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
"push_rules", "push_rules",
{"user_name": user_id, "rule_id": rule_id}, {"user_name": user_id, "rule_id": rule_id},
@ -721,7 +721,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids: with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids stream_id, event_stream_ordering = ids
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"set_push_rule_actions", "set_push_rule_actions",
set_push_rule_actions_txn, set_push_rule_actions_txn,
stream_id, stream_id,
@ -741,7 +741,7 @@ class PushRuleStore(PushRulesWorkerStore):
if data is not None: if data is not None:
values.update(data) values.update(data)
self.db.simple_insert_txn(txn, "push_rules_stream", values=values) self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))

View file

@ -50,7 +50,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_has_pusher(self, user_id): def user_has_pusher(self, user_id):
ret = yield self.db.simple_select_one_onecol( ret = yield self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True "pushers", {"user_name": user_id}, "id", allow_none=True
) )
return ret is not None return ret is not None
@ -63,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pushers_by(self, keyvalues): def get_pushers_by(self, keyvalues):
ret = yield self.db.simple_select_list( ret = yield self.db_pool.simple_select_list(
"pushers", "pushers",
keyvalues, keyvalues,
[ [
@ -91,11 +91,11 @@ class PusherWorkerStore(SQLBaseStore):
def get_all_pushers(self): def get_all_pushers(self):
def get_pushers(txn): def get_pushers(txn):
txn.execute("SELECT * FROM pushers") txn.execute("SELECT * FROM pushers")
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows) return self._decode_pushers_rows(rows)
rows = yield self.db.runInteraction("get_all_pushers", get_pushers) rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
return rows return rows
async def get_all_updated_pushers_rows( async def get_all_updated_pushers_rows(
@ -160,7 +160,7 @@ class PusherWorkerStore(SQLBaseStore):
return updates, upper_bound, limited return updates, upper_bound, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
) )
@ -176,7 +176,7 @@ class PusherWorkerStore(SQLBaseStore):
inlineCallbacks=True, inlineCallbacks=True,
) )
def get_if_users_have_pushers(self, user_ids): def get_if_users_have_pushers(self, user_ids):
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="pushers", table="pushers",
column="user_name", column="user_name",
iterable=user_ids, iterable=user_ids,
@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
def update_pusher_last_stream_ordering( def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering self, app_id, pushkey, user_id, last_stream_ordering
): ):
yield self.db.simple_update_one( yield self.db_pool.simple_update_one(
"pushers", "pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering}, {"last_stream_ordering": last_stream_ordering},
@ -216,7 +216,7 @@ class PusherWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred[bool]: True if the pusher still exists; False if it has been deleted. Deferred[bool]: True if the pusher still exists; False if it has been deleted.
""" """
updated = yield self.db.simple_update( updated = yield self.db_pool.simple_update(
table="pushers", table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={ updatevalues={
@ -230,7 +230,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
yield self.db.simple_update( yield self.db_pool.simple_update(
table="pushers", table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since}, updatevalues={"failing_since": failing_since},
@ -239,7 +239,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_throttle_params_by_room(self, pusher_id): def get_throttle_params_by_room(self, pusher_id):
res = yield self.db.simple_select_list( res = yield self.db_pool.simple_select_list(
"pusher_throttle", "pusher_throttle",
{"pusher": pusher_id}, {"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"], ["room_id", "last_sent_ts", "throttle_ms"],
@ -259,7 +259,7 @@ class PusherWorkerStore(SQLBaseStore):
def set_throttle_params(self, pusher_id, room_id, params): def set_throttle_params(self, pusher_id, room_id, params):
# no need to lock because `pusher_throttle` has a primary key on # no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry # (pusher, room_id) so simple_upsert will retry
yield self.db.simple_upsert( yield self.db_pool.simple_upsert(
"pusher_throttle", "pusher_throttle",
{"pusher": pusher_id, "room_id": room_id}, {"pusher": pusher_id, "room_id": room_id},
params, params,
@ -291,7 +291,7 @@ class PusherStore(PusherWorkerStore):
with self._pushers_id_gen.get_next() as stream_id: with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on # no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry # (app_id, pushkey, user_name) so simple_upsert will retry
yield self.db.simple_upsert( yield self.db_pool.simple_upsert(
table="pushers", table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={ values={
@ -316,7 +316,7 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True: if user_has_pusher is not True:
# invalidate, since we the user might not have had a pusher before # invalidate, since we the user might not have had a pusher before
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"add_pusher", "add_pusher",
self._invalidate_cache_and_stream, self._invalidate_cache_and_stream,
self.get_if_user_has_pusher, self.get_if_user_has_pusher,
@ -330,7 +330,7 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,) txn, self.get_if_user_has_pusher, (user_id,)
) )
self.db.simple_delete_one_txn( self.db_pool.simple_delete_one_txn(
txn, txn,
"pushers", "pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@ -339,7 +339,7 @@ class PusherStore(PusherWorkerStore):
# it's possible for us to end up with duplicate rows for # it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that # (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter. # doesn't really matter.
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="deleted_pushers", table="deleted_pushers",
values={ values={
@ -351,4 +351,6 @@ class PusherStore(PusherWorkerStore):
) )
with self._pushers_id_gen.get_next() as stream_id: with self._pushers_id_gen.get_next() as stream_id:
yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) yield self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)

View file

@ -23,7 +23,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@ -41,7 +41,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# the abstract methods being implemented. # the abstract methods being implemented.
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache( self._receipts_stream_cache = StreamChangeCache(
@ -64,7 +64,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=2) @cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type): def get_receipts_for_room(self, room_id, receipt_type):
return self.db.simple_select_list( return self.db_pool.simple_select_list(
table="receipts_linearized", table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type}, keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"), retcols=("user_id", "event_id"),
@ -73,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3) @cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="receipts_linearized", table="receipts_linearized",
keyvalues={ keyvalues={
"room_id": room_id, "room_id": room_id,
@ -87,7 +87,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2) @cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type): def get_receipts_for_user(self, user_id, receipt_type):
rows = yield self.db.simple_select_list( rows = yield self.db_pool.simple_select_list(
table="receipts_linearized", table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type}, keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"), retcols=("room_id", "event_id"),
@ -111,7 +111,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return txn.fetchall() return txn.fetchall()
rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f) rows = yield self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
)
return { return {
row[0]: { row[0]: {
"event_id": row[1], "event_id": row[1],
@ -190,11 +192,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (room_id, to_key)) txn.execute(sql, (room_id, to_key))
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
return rows return rows
rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f) rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows: if not rows:
return [] return []
@ -240,9 +242,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key] + list(args)) txn.execute(sql + clause, [to_key] + list(args))
return self.db.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
txn_results = yield self.db.runInteraction( txn_results = yield self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f "_get_linearized_receipts_for_rooms", f
) )
@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [r[0] for r in txn] return [r[0] for r in txn]
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_users_sent_receipts_between", _get_users_sent_receipts_between_txn "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
) )
@ -340,7 +342,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return updates, upper_bound, limited return updates, upper_bound, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn "get_all_updated_receipts", get_all_updated_receipts_txn
) )
@ -371,7 +373,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
class ReceiptsStore(ReceiptsWorkerStore): class ReceiptsStore(ReceiptsWorkerStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor # We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id # needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator( self._receipts_id_gen = StreamIdGenerator(
@ -393,7 +395,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown) (or 0 if the event is unknown)
""" """
res = self.db.simple_select_one_txn( res = self.db_pool.simple_select_one_txn(
txn, txn,
table="events", table="events",
retcols=["stream_ordering", "received_ts"], retcols=["stream_ordering", "received_ts"],
@ -446,7 +448,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type), (user_id, room_id, receipt_type),
) )
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="receipts_linearized", table="receipts_linearized",
keyvalues={ keyvalues={
@ -506,13 +508,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
else: else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = yield self.db.runInteraction( linearized_event_id = yield self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear "insert_receipt_conv", graph_to_linear
) )
stream_id_manager = self._receipts_id_gen.get_next() stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id: with stream_id_manager as stream_id:
event_ts = yield self.db.runInteraction( event_ts = yield self.db_pool.runInteraction(
"insert_linearized_receipt", "insert_linearized_receipt",
self.insert_linearized_receipt_txn, self.insert_linearized_receipt_txn,
room_id, room_id,
@ -541,7 +543,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
return stream_id, max_persisted_id return stream_id, max_persisted_id
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
return self.db.runInteraction( return self.db_pool.runInteraction(
"insert_graph_receipt", "insert_graph_receipt",
self.insert_graph_receipt_txn, self.insert_graph_receipt_txn,
room_id, room_id,
@ -567,7 +569,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self._get_linearized_receipts_for_room.invalidate_many, (room_id,) self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="receipts_graph", table="receipts_graph",
keyvalues={ keyvalues={
@ -576,7 +578,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id, "user_id": user_id,
}, },
) )
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="receipts_graph", table="receipts_graph",
values={ values={

View file

@ -26,7 +26,7 @@ from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID from synapse.types import UserID
@ -38,7 +38,7 @@ logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore): class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
@ -50,7 +50,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@cached() @cached()
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
return self.db.simple_select_one( return self.db_pool.simple_select_one(
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
retcols=[ retcols=[
@ -101,7 +101,7 @@ class RegistrationWorkerStore(SQLBaseStore):
including the keys `name`, `is_guest`, `device_id`, `token_id`, including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`. `valid_until_ms`.
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token "get_user_by_access_token", self._query_for_auth, token
) )
@ -116,7 +116,7 @@ class RegistrationWorkerStore(SQLBaseStore):
otherwise int representation of the timestamp (as a number of otherwise int representation of the timestamp (as a number of
milliseconds since epoch). milliseconds since epoch).
""" """
res = yield self.db.simple_select_one_onecol( res = yield self.db_pool.simple_select_one_onecol(
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="expiration_ts_ms", retcol="expiration_ts_ms",
@ -144,7 +144,7 @@ class RegistrationWorkerStore(SQLBaseStore):
""" """
def set_account_validity_for_user_txn(txn): def set_account_validity_for_user_txn(txn):
self.db.simple_update_txn( self.db_pool.simple_update_txn(
txn=txn, txn=txn,
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
@ -158,7 +158,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_expiration_ts_for_user, (user_id,) txn, self.get_expiration_ts_for_user, (user_id,)
) )
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"set_account_validity_for_user", set_account_validity_for_user_txn "set_account_validity_for_user", set_account_validity_for_user_txn
) )
@ -174,7 +174,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Raises: Raises:
StoreError: The provided token is already set for another user. StoreError: The provided token is already set for another user.
""" """
yield self.db.simple_update_one( yield self.db_pool.simple_update_one(
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token}, updatevalues={"renewal_token": renewal_token},
@ -191,7 +191,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns: Returns:
defer.Deferred[str]: The ID of the user to which the token belongs. defer.Deferred[str]: The ID of the user to which the token belongs.
""" """
res = yield self.db.simple_select_one_onecol( res = yield self.db_pool.simple_select_one_onecol(
table="account_validity", table="account_validity",
keyvalues={"renewal_token": renewal_token}, keyvalues={"renewal_token": renewal_token},
retcol="user_id", retcol="user_id",
@ -210,7 +210,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns: Returns:
defer.Deferred[str]: The renewal token associated with this user ID. defer.Deferred[str]: The renewal token associated with this user ID.
""" """
res = yield self.db.simple_select_one_onecol( res = yield self.db_pool.simple_select_one_onecol(
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="renewal_token", retcol="renewal_token",
@ -236,9 +236,9 @@ class RegistrationWorkerStore(SQLBaseStore):
) )
values = [False, now_ms, renew_at] values = [False, now_ms, renew_at]
txn.execute(sql, values) txn.execute(sql, values)
return self.db.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
res = yield self.db.runInteraction( res = yield self.db_pool.runInteraction(
"get_users_expiring_soon", "get_users_expiring_soon",
select_users_txn, select_users_txn,
self.clock.time_msec(), self.clock.time_msec(),
@ -257,7 +257,7 @@ class RegistrationWorkerStore(SQLBaseStore):
email_sent (bool): Flag which indicates whether a renewal email has been sent email_sent (bool): Flag which indicates whether a renewal email has been sent
to this user. to this user.
""" """
yield self.db.simple_update_one( yield self.db_pool.simple_update_one(
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent}, updatevalues={"email_sent": email_sent},
@ -272,7 +272,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Args: Args:
user_id (str): ID of the user to remove from the account validity table. user_id (str): ID of the user to remove from the account validity table.
""" """
yield self.db.simple_delete_one( yield self.db_pool.simple_delete_one(
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user", desc="delete_account_validity_for_user",
@ -287,7 +287,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool): Returns (bool):
true iff the user is a server admin, false otherwise. true iff the user is a server admin, false otherwise.
""" """
res = await self.db.simple_select_one_onecol( res = await self.db_pool.simple_select_one_onecol(
table="users", table="users",
keyvalues={"name": user.to_string()}, keyvalues={"name": user.to_string()},
retcol="admin", retcol="admin",
@ -307,14 +307,14 @@ class RegistrationWorkerStore(SQLBaseStore):
""" """
def set_server_admin_txn(txn): def set_server_admin_txn(txn):
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user.to_string(),) txn, self.get_user_by_id, (user.to_string(),)
) )
return self.db.runInteraction("set_server_admin", set_server_admin_txn) return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = (
@ -326,7 +326,7 @@ class RegistrationWorkerStore(SQLBaseStore):
) )
txn.execute(sql, (token,)) txn.execute(sql, (token,))
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if rows: if rows:
return rows[0] return rows[0]
@ -342,7 +342,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred[bool]: True if user 'user_type' is null or empty string Deferred[bool]: True if user 'user_type' is null or empty string
""" """
res = yield self.db.runInteraction( res = yield self.db_pool.runInteraction(
"is_real_user", self.is_real_user_txn, user_id "is_real_user", self.is_real_user_txn, user_id
) )
return res return res
@ -357,12 +357,12 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT Deferred[bool]: True if user is of type UserTypes.SUPPORT
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"is_support_user", self.is_support_user_txn, user_id "is_support_user", self.is_support_user_txn, user_id
) )
def is_real_user_txn(self, txn, user_id): def is_real_user_txn(self, txn, user_id):
res = self.db.simple_select_one_onecol_txn( res = self.db_pool.simple_select_one_onecol_txn(
txn=txn, txn=txn,
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
@ -372,7 +372,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return res is None return res is None
def is_support_user_txn(self, txn, user_id): def is_support_user_txn(self, txn, user_id):
res = self.db.simple_select_one_onecol_txn( res = self.db_pool.simple_select_one_onecol_txn(
txn=txn, txn=txn,
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
@ -391,7 +391,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return dict(txn) return dict(txn)
return self.db.runInteraction("get_users_by_id_case_insensitive", f) return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id( async def get_user_by_external_id(
self, auth_provider: str, external_id: str self, auth_provider: str, external_id: str
@ -405,7 +405,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns: Returns:
str|None: the mxid of the user, or None if they are not known str|None: the mxid of the user, or None if they are not known
""" """
return await self.db.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="user_external_ids", table="user_external_ids",
keyvalues={"auth_provider": auth_provider, "external_id": external_id}, keyvalues={"auth_provider": auth_provider, "external_id": external_id},
retcol="user_id", retcol="user_id",
@ -419,12 +419,12 @@ class RegistrationWorkerStore(SQLBaseStore):
def _count_users(txn): def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users") txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if rows: if rows:
return rows[0]["users"] return rows[0]["users"]
return 0 return 0
ret = yield self.db.runInteraction("count_users", _count_users) ret = yield self.db_pool.runInteraction("count_users", _count_users)
return ret return ret
def count_daily_user_type(self): def count_daily_user_type(self):
@ -456,7 +456,9 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1] results[row[0]] = row[1]
return results return results
return self.db.runInteraction("count_daily_user_type", _count_daily_user_type) return self.db_pool.runInteraction(
"count_daily_user_type", _count_daily_user_type
)
@defer.inlineCallbacks @defer.inlineCallbacks
def count_nonbridged_users(self): def count_nonbridged_users(self):
@ -470,7 +472,7 @@ class RegistrationWorkerStore(SQLBaseStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count return count
ret = yield self.db.runInteraction("count_users", _count_users) ret = yield self.db_pool.runInteraction("count_users", _count_users)
return ret return ret
@defer.inlineCallbacks @defer.inlineCallbacks
@ -479,12 +481,12 @@ class RegistrationWorkerStore(SQLBaseStore):
def _count_users(txn): def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if rows: if rows:
return rows[0]["users"] return rows[0]["users"]
return 0 return 0
ret = yield self.db.runInteraction("count_real_users", _count_users) ret = yield self.db_pool.runInteraction("count_real_users", _count_users)
return ret return ret
async def generate_user_id(self) -> str: async def generate_user_id(self) -> str:
@ -492,7 +494,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns: a (hopefully) free localpart Returns: a (hopefully) free localpart
""" """
next_id = await self.db.runInteraction( next_id = await self.db_pool.runInteraction(
"generate_user_id", self._user_id_seq.get_next_id_txn "generate_user_id", self._user_id_seq.get_next_id_txn
) )
@ -508,7 +510,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns: Returns:
The user ID or None if no user id/threepid mapping exists The user ID or None if no user id/threepid mapping exists
""" """
user_id = await self.db.runInteraction( user_id = await self.db_pool.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
) )
return user_id return user_id
@ -524,7 +526,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns: Returns:
str|None: user id or None if no user id/threepid mapping exists str|None: user id or None if no user id/threepid mapping exists
""" """
ret = self.db.simple_select_one_txn( ret = self.db_pool.simple_select_one_txn(
txn, txn,
"user_threepids", "user_threepids",
{"medium": medium, "address": address}, {"medium": medium, "address": address},
@ -537,7 +539,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at): def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self.db.simple_upsert( yield self.db_pool.simple_upsert(
"user_threepids", "user_threepids",
{"medium": medium, "address": address}, {"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
@ -545,7 +547,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_get_threepids(self, user_id): def user_get_threepids(self, user_id):
ret = yield self.db.simple_select_list( ret = yield self.db_pool.simple_select_list(
"user_threepids", "user_threepids",
{"user_id": user_id}, {"user_id": user_id},
["medium", "address", "validated_at", "added_at"], ["medium", "address", "validated_at", "added_at"],
@ -554,7 +556,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret return ret
def user_delete_threepid(self, user_id, medium, address): def user_delete_threepid(self, user_id, medium, address):
return self.db.simple_delete( return self.db_pool.simple_delete(
"user_threepids", "user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address}, keyvalues={"user_id": user_id, "medium": medium, "address": address},
desc="user_delete_threepid", desc="user_delete_threepid",
@ -567,7 +569,7 @@ class RegistrationWorkerStore(SQLBaseStore):
user_id: The user id to delete all threepids of user_id: The user id to delete all threepids of
""" """
return self.db.simple_delete( return self.db_pool.simple_delete(
"user_threepids", "user_threepids",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
desc="user_delete_threepids", desc="user_delete_threepids",
@ -589,7 +591,7 @@ class RegistrationWorkerStore(SQLBaseStore):
""" """
# We need to use an upsert, in case they user had already bound the # We need to use an upsert, in case they user had already bound the
# threepid # threepid
return self.db.simple_upsert( return self.db_pool.simple_upsert(
table="user_threepid_id_server", table="user_threepid_id_server",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
@ -615,7 +617,7 @@ class RegistrationWorkerStore(SQLBaseStore):
medium (str): The medium of the threepid (e.g "email") medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com") address (str): The address of the threepid (e.g "bob@example.com")
""" """
return self.db.simple_select_list( return self.db_pool.simple_select_list(
table="user_threepid_id_server", table="user_threepid_id_server",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=["medium", "address"], retcols=["medium", "address"],
@ -636,7 +638,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred Deferred
""" """
return self.db.simple_delete( return self.db_pool.simple_delete(
table="user_threepid_id_server", table="user_threepid_id_server",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
@ -659,7 +661,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred[list[str]]: Resolves to a list of identity servers Deferred[list[str]]: Resolves to a list of identity servers
""" """
return self.db.simple_select_onecol( return self.db_pool.simple_select_onecol(
table="user_threepid_id_server", table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address}, keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server", retcol="id_server",
@ -677,7 +679,7 @@ class RegistrationWorkerStore(SQLBaseStore):
defer.Deferred(bool): The requested value. defer.Deferred(bool): The requested value.
""" """
res = yield self.db.simple_select_one_onecol( res = yield self.db_pool.simple_select_one_onecol(
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
retcol="deactivated", retcol="deactivated",
@ -744,13 +746,13 @@ class RegistrationWorkerStore(SQLBaseStore):
sql += " LIMIT 1" sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if not rows: if not rows:
return None return None
return rows[0] return rows[0]
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn "get_threepid_validation_session", get_threepid_validation_session_txn
) )
@ -764,37 +766,37 @@ class RegistrationWorkerStore(SQLBaseStore):
""" """
def delete_threepid_session_txn(txn): def delete_threepid_session_txn(txn):
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="threepid_validation_token", table="threepid_validation_token",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
) )
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="threepid_validation_session", table="threepid_validation_session",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"delete_threepid_session", delete_threepid_session_txn "delete_threepid_session", delete_threepid_session_txn
) )
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.config = hs.config self.config = hs.config
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"access_tokens_device_index", "access_tokens_device_index",
index_name="access_tokens_device_id", index_name="access_tokens_device_id",
table="access_tokens", table="access_tokens",
columns=["user_id", "device_id"], columns=["user_id", "device_id"],
) )
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"users_creation_ts", "users_creation_ts",
index_name="users_creation_ts", index_name="users_creation_ts",
table="users", table="users",
@ -804,13 +806,15 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
# we no longer use refresh tokens, but it's possible that some people # we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just # might have a background update queued to build this index. Just
# clear the background update. # clear the background update.
self.db.updates.register_noop_background_update("refresh_tokens_device_index") self.db_pool.updates.register_noop_background_update(
"refresh_tokens_device_index"
)
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"user_threepids_grandfather", self._bg_user_threepids_grandfather "user_threepids_grandfather", self._bg_user_threepids_grandfather
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag "users_set_deactivated_flag", self._background_update_set_deactivated_flag
) )
@ -843,7 +847,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size), (last_user, batch_size),
) )
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if not rows: if not rows:
return True, 0 return True, 0
@ -857,7 +861,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
logger.info("Marked %d rows as deactivated", rows_processed_nb) logger.info("Marked %d rows as deactivated", rows_processed_nb)
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
) )
@ -866,12 +870,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
else: else:
return False, len(rows) return False, len(rows)
end, nb_processed = yield self.db.runInteraction( end, nb_processed = yield self.db_pool.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
) )
if end: if end:
yield self.db.updates._end_background_update("users_set_deactivated_flag") yield self.db_pool.updates._end_background_update(
"users_set_deactivated_flag"
)
return nb_processed return nb_processed
@ -897,17 +903,17 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
txn.executemany(sql, [(id_server,) for id_server in id_servers]) txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers: if id_servers:
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
) )
yield self.db.updates._end_background_update("user_threepids_grandfather") yield self.db_pool.updates._end_background_update("user_threepids_grandfather")
return 1 return 1
class RegistrationStore(RegistrationBackgroundUpdateStore): class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RegistrationStore, self).__init__(database, db_conn, hs) super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity self._account_validity = hs.config.account_validity
@ -947,7 +953,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
""" """
next_id = self._access_tokens_id_gen.get_next() next_id = self._access_tokens_id_gen.get_next()
yield self.db.simple_insert( yield self.db_pool.simple_insert(
"access_tokens", "access_tokens",
{ {
"id": next_id, "id": next_id,
@ -992,7 +998,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Returns: Returns:
Deferred Deferred
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"register_user", "register_user",
self._register_user, self._register_user,
user_id, user_id,
@ -1026,7 +1032,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Ensure that the guest user actually exists # Ensure that the guest user actually exists
# ``allow_none=False`` makes this raise an exception # ``allow_none=False`` makes this raise an exception
# if the row isn't in the database. # if the row isn't in the database.
self.db.simple_select_one_txn( self.db_pool.simple_select_one_txn(
txn, txn,
"users", "users",
keyvalues={"name": user_id, "is_guest": 1}, keyvalues={"name": user_id, "is_guest": 1},
@ -1034,7 +1040,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
allow_none=False, allow_none=False,
) )
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
"users", "users",
keyvalues={"name": user_id, "is_guest": 1}, keyvalues={"name": user_id, "is_guest": 1},
@ -1048,7 +1054,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
}, },
) )
else: else:
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
"users", "users",
values={ values={
@ -1103,7 +1109,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
external_id: id on that system external_id: id on that system
user_id: complete mxid that it is mapped to user_id: complete mxid that it is mapped to
""" """
return self.db.simple_insert( return self.db_pool.simple_insert(
table="user_external_ids", table="user_external_ids",
values={ values={
"auth_provider": auth_provider, "auth_provider": auth_provider,
@ -1121,12 +1127,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
""" """
def user_set_password_hash_txn(txn): def user_set_password_hash_txn(txn):
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash} txn, "users", {"name": user_id}, {"password_hash": password_hash}
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db.runInteraction( return self.db_pool.runInteraction(
"user_set_password_hash", user_set_password_hash_txn "user_set_password_hash", user_set_password_hash_txn
) )
@ -1143,7 +1149,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
""" """
def f(txn): def f(txn):
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
@ -1151,7 +1157,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db.runInteraction("user_set_consent_version", f) return self.db_pool.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version): def user_set_consent_server_notice_sent(self, user_id, consent_version):
"""Updates the user table to record that we have sent the user a server """Updates the user table to record that we have sent the user a server
@ -1167,7 +1173,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
""" """
def f(txn): def f(txn):
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
@ -1175,7 +1181,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db.runInteraction("user_set_consent_server_notice_sent", f) return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
""" """
@ -1221,11 +1227,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices return tokens_and_devices
return self.db.runInteraction("user_delete_access_tokens", f) return self.db_pool.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
def f(txn): def f(txn):
self.db.simple_delete_one_txn( self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token} txn, table="access_tokens", keyvalues={"token": access_token}
) )
@ -1233,11 +1239,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,) txn, self.get_user_by_access_token, (access_token,)
) )
return self.db.runInteraction("delete_access_token", f) return self.db_pool.runInteraction("delete_access_token", f)
@cachedInlineCallbacks() @cachedInlineCallbacks()
def is_guest(self, user_id): def is_guest(self, user_id):
res = yield self.db.simple_select_one_onecol( res = yield self.db_pool.simple_select_one_onecol(
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
retcol="is_guest", retcol="is_guest",
@ -1252,7 +1258,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Adds a user to the table of users who need to be parted from all the rooms they're Adds a user to the table of users who need to be parted from all the rooms they're
in in
""" """
return self.db.simple_insert( return self.db_pool.simple_insert(
"users_pending_deactivation", "users_pending_deactivation",
values={"user_id": user_id}, values={"user_id": user_id},
desc="add_user_pending_deactivation", desc="add_user_pending_deactivation",
@ -1265,7 +1271,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
""" """
# XXX: This should be simple_delete_one but we failed to put a unique index on # XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it. # the table, so somehow duplicate entries have ended up in it.
return self.db.simple_delete( return self.db_pool.simple_delete(
"users_pending_deactivation", "users_pending_deactivation",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation", desc="del_user_pending_deactivation",
@ -1276,7 +1282,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Gets one user from the table of users waiting to be parted from all the rooms Gets one user from the table of users waiting to be parted from all the rooms
they're in. they're in.
""" """
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
"users_pending_deactivation", "users_pending_deactivation",
keyvalues={}, keyvalues={},
retcol="user_id", retcol="user_id",
@ -1306,7 +1312,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Insert everything into a transaction in order to run atomically # Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn): def validate_threepid_session_txn(txn):
row = self.db.simple_select_one_txn( row = self.db_pool.simple_select_one_txn(
txn, txn,
table="threepid_validation_session", table="threepid_validation_session",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
@ -1324,7 +1330,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
400, "This client_secret does not match the provided session_id" 400, "This client_secret does not match the provided session_id"
) )
row = self.db.simple_select_one_txn( row = self.db_pool.simple_select_one_txn(
txn, txn,
table="threepid_validation_token", table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token}, keyvalues={"session_id": session_id, "token": token},
@ -1349,7 +1355,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
) )
# Looks good. Validate the session # Looks good. Validate the session
self.db.simple_update_txn( self.db_pool.simple_update_txn(
txn, txn,
table="threepid_validation_session", table="threepid_validation_session",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
@ -1359,7 +1365,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link return next_link
# Return next_link if it exists # Return next_link if it exists
return self.db.runInteraction( return self.db_pool.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn "validate_threepid_session_txn", validate_threepid_session_txn
) )
@ -1392,7 +1398,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
if validated_at: if validated_at:
insertion_values["validated_at"] = validated_at insertion_values["validated_at"] = validated_at
return self.db.simple_upsert( return self.db_pool.simple_upsert(
table="threepid_validation_session", table="threepid_validation_session",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt}, values={"last_send_attempt": send_attempt},
@ -1430,7 +1436,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def start_or_continue_validation_session_txn(txn): def start_or_continue_validation_session_txn(txn):
# Create or update a validation session # Create or update a validation session
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="threepid_validation_session", table="threepid_validation_session",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
@ -1443,7 +1449,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
) )
# Create a new validation token with this session ID # Create a new validation token with this session ID
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="threepid_validation_token", table="threepid_validation_token",
values={ values={
@ -1454,7 +1460,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
}, },
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"start_or_continue_validation_session", "start_or_continue_validation_session",
start_or_continue_validation_session_txn, start_or_continue_validation_session_txn,
) )
@ -1469,7 +1475,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
""" """
return txn.execute(sql, (ts,)) return txn.execute(sql, (ts,))
return self.db.runInteraction( return self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens", "cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn, cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(), self.clock.time_msec(),
@ -1484,7 +1490,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
deactivated (bool): The value to set for `deactivated`. deactivated (bool): The value to set for `deactivated`.
""" """
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"set_user_deactivated_status", "set_user_deactivated_status",
self.set_user_deactivated_status_txn, self.set_user_deactivated_status_txn,
user_id, user_id,
@ -1492,7 +1498,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
) )
def set_user_deactivated_status_txn(self, txn, user_id, deactivated): def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn=txn, txn=txn,
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
@ -1520,14 +1526,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
) )
txn.execute(sql, []) txn.execute(sql, [])
res = self.db.cursor_to_dict(txn) res = self.db_pool.cursor_to_dict(txn)
if res: if res:
for user in res: for user in res:
self.set_expiration_date_for_user_txn( self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True txn, user["name"], use_delta=True
) )
yield self.db.runInteraction( yield self.db_pool.runInteraction(
"get_users_with_no_expiration_date", "get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn, select_users_with_no_expiration_date_txn,
) )
@ -1551,7 +1557,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expiration_ts, expiration_ts,
) )
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
"account_validity", "account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},

View file

@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore): class RejectionsStore(SQLBaseStore):
def get_rejection_reason(self, event_id): def get_rejection_reason(self, event_id):
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="rejections", table="rejections",
retcol="reason", retcol="reason",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},

View file

@ -19,7 +19,7 @@ import attr
from synapse.api.constants import RelationTypes from synapse.api.constants import RelationTypes
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.stream import generate_pagination_where_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import ( from synapse.storage.relations import (
AggregationPaginationToken, AggregationPaginationToken,
PaginationChunk, PaginationChunk,
@ -129,7 +129,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn "get_recent_references_for_event", _get_recent_references_for_event_txn
) )
@ -223,7 +223,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
) )
@ -268,7 +268,7 @@ class RelationsWorkerStore(SQLBaseStore):
if row: if row:
return row[0] return row[0]
edit_id = yield self.db.runInteraction( edit_id = yield self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn "get_applicable_edit", _get_applicable_edit_txn
) )
@ -318,7 +318,7 @@ class RelationsWorkerStore(SQLBaseStore):
return bool(txn.fetchone()) return bool(txn.fetchone())
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
) )

View file

@ -27,8 +27,8 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.search import SearchStore from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import Database, LoggingTransaction from synapse.storage.databases.main.search import SearchStore
from synapse.types import ThirdPartyInstanceID from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -73,7 +73,7 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore): class RoomWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomWorkerStore, self).__init__(database, db_conn, hs) super(RoomWorkerStore, self).__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
@ -86,7 +86,7 @@ class RoomWorkerStore(SQLBaseStore):
Returns: Returns:
A dict containing the room information, or None if the room is unknown. A dict containing the room information, or None if the room is unknown.
""" """
return self.db.simple_select_one( return self.db_pool.simple_select_one(
table="rooms", table="rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"), retcols=("room_id", "is_public", "creator"),
@ -118,7 +118,7 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, [room_id]) txn.execute(sql, [room_id])
# Catch error if sql returns empty result to return "None" instead of an error # Catch error if sql returns empty result to return "None" instead of an error
try: try:
res = self.db.cursor_to_dict(txn)[0] res = self.db_pool.cursor_to_dict(txn)[0]
except IndexError: except IndexError:
return None return None
@ -126,12 +126,12 @@ class RoomWorkerStore(SQLBaseStore):
res["public"] = bool(res["public"]) res["public"] = bool(res["public"])
return res return res
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id "get_room_with_stats", get_room_with_stats_txn, room_id
) )
def get_public_room_ids(self): def get_public_room_ids(self):
return self.db.simple_select_onecol( return self.db_pool.simple_select_onecol(
table="rooms", table="rooms",
keyvalues={"is_public": True}, keyvalues={"is_public": True},
retcol="room_id", retcol="room_id",
@ -188,7 +188,9 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args) txn.execute(sql, query_args)
return txn.fetchone()[0] return txn.fetchone()[0]
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) return self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
)
async def get_largest_public_rooms( async def get_largest_public_rooms(
self, self,
@ -320,21 +322,21 @@ class RoomWorkerStore(SQLBaseStore):
def _get_largest_public_rooms_txn(txn): def _get_largest_public_rooms_txn(txn):
txn.execute(sql, query_args) txn.execute(sql, query_args)
results = self.db.cursor_to_dict(txn) results = self.db_pool.cursor_to_dict(txn)
if not forwards: if not forwards:
results.reverse() results.reverse()
return results return results
ret_val = await self.db.runInteraction( ret_val = await self.db_pool.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn "get_largest_public_rooms", _get_largest_public_rooms_txn
) )
return ret_val return ret_val
@cached(max_entries=10000) @cached(max_entries=10000)
def is_room_blocked(self, room_id): def is_room_blocked(self, room_id):
return self.db.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="blocked_rooms", table="blocked_rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcol="1", retcol="1",
@ -502,7 +504,7 @@ class RoomWorkerStore(SQLBaseStore):
room_count = txn.fetchone() room_count = txn.fetchone()
return rooms, room_count[0] return rooms, room_count[0]
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_rooms_paginate", _get_rooms_paginate_txn, "get_rooms_paginate", _get_rooms_paginate_txn,
) )
@ -519,7 +521,7 @@ class RoomWorkerStore(SQLBaseStore):
of RatelimitOverride are None or 0 then ratelimitng has been of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely. disabled for that user entirely.
""" """
row = await self.db.simple_select_one( row = await self.db_pool.simple_select_one(
table="ratelimit_override", table="ratelimit_override",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"), retcols=("messages_per_second", "burst_count"),
@ -561,9 +563,9 @@ class RoomWorkerStore(SQLBaseStore):
(room_id,), (room_id,),
) )
return self.db.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
ret = await self.db.runInteraction( ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room", get_retention_policy_for_room_txn, "get_retention_policy_for_room", get_retention_policy_for_room_txn,
) )
@ -613,7 +615,7 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs return local_media_mxcs, remote_media_mxcs
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn "get_media_ids_in_room", _get_media_mxcs_in_room_txn
) )
@ -630,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by txn, local_mxcs, remote_mxcs, quarantined_by
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn "quarantine_media_in_room", _quarantine_media_in_room_txn
) )
@ -714,7 +716,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by txn, local_mxcs, remote_mxcs, quarantined_by
) )
return self.db.runInteraction( return self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_id_txn "quarantine_media_by_user", _quarantine_media_by_id_txn
) )
@ -730,7 +732,7 @@ class RoomWorkerStore(SQLBaseStore):
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
return self.db.runInteraction( return self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_user_txn "quarantine_media_by_user", _quarantine_media_by_user_txn
) )
@ -848,7 +850,7 @@ class RoomWorkerStore(SQLBaseStore):
return updates, upto_token, limited return updates, upto_token, limited
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms "get_all_new_public_rooms", get_all_new_public_rooms
) )
@ -857,21 +859,21 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"insert_room_retention", self._background_insert_retention, "insert_room_retention", self._background_insert_retention,
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
self._remove_tombstoned_rooms_from_directory, self._remove_tombstoned_rooms_from_directory,
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
self.ADD_ROOMS_ROOM_VERSION_COLUMN, self.ADD_ROOMS_ROOM_VERSION_COLUMN,
self._background_add_rooms_room_version_column, self._background_add_rooms_room_version_column,
) )
@ -900,7 +902,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
(last_room, batch_size), (last_room, batch_size),
) )
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if not rows: if not rows:
return True return True
@ -912,7 +914,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
ev = db_to_json(row["json"]) ev = db_to_json(row["json"])
retention_policy = ev["content"] retention_policy = ev["content"]
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn=txn, txn=txn,
table="room_retention", table="room_retention",
values={ values={
@ -925,7 +927,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Inserted %d rows into room_retention", len(rows)) logger.info("Inserted %d rows into room_retention", len(rows))
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
) )
@ -934,12 +936,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
else: else:
return False return False
end = await self.db.runInteraction( end = await self.db_pool.runInteraction(
"insert_room_retention", _background_insert_retention_txn, "insert_room_retention", _background_insert_retention_txn,
) )
if end: if end:
await self.db.updates._end_background_update("insert_room_retention") await self.db_pool.updates._end_background_update("insert_room_retention")
return batch_size return batch_size
@ -983,7 +985,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
# mainly for paranoia as much badness would happen if we don't # mainly for paranoia as much badness would happen if we don't
# insert the row and then try and get the room version for the # insert the row and then try and get the room version for the
# room. # room.
self.db.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="rooms", table="rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
@ -992,19 +994,19 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
) )
new_last_room_id = room_id new_last_room_id = room_id
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id} txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id}
) )
return False return False
end = await self.db.runInteraction( end = await self.db_pool.runInteraction(
"_background_add_rooms_room_version_column", "_background_add_rooms_room_version_column",
_background_add_rooms_room_version_column_txn, _background_add_rooms_room_version_column_txn,
) )
if end: if end:
await self.db.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.ADD_ROOMS_ROOM_VERSION_COLUMN self.ADD_ROOMS_ROOM_VERSION_COLUMN
) )
@ -1038,12 +1040,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return [row[0] for row in txn] return [row[0] for row in txn]
rooms = await self.db.runInteraction( rooms = await self.db_pool.runInteraction(
"get_tombstoned_directory_rooms", _get_rooms "get_tombstoned_directory_rooms", _get_rooms
) )
if not rooms: if not rooms:
await self.db.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
) )
return 0 return 0
@ -1052,7 +1054,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Removing tombstoned room %s from the directory", room_id) logger.info("Removing tombstoned room %s from the directory", room_id)
await self.set_room_is_public(room_id, False) await self.set_room_is_public(room_id, False)
await self.db.updates._background_update_progress( await self.db_pool.updates._background_update_progress(
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]} self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
) )
@ -1068,7 +1070,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomStore, self).__init__(database, db_conn, hs) super(RoomStore, self).__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
@ -1079,7 +1081,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
Called when we join a room over federation, and overwrites any room version Called when we join a room over federation, and overwrites any room version
currently in the table. currently in the table.
""" """
await self.db.simple_upsert( await self.db_pool.simple_upsert(
desc="upsert_room_on_join", desc="upsert_room_on_join",
table="rooms", table="rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
@ -1111,7 +1113,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
try: try:
def store_room_txn(txn, next_id): def store_room_txn(txn, next_id):
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
"rooms", "rooms",
{ {
@ -1122,7 +1124,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
}, },
) )
if is_public: if is_public:
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="public_room_list_stream", table="public_room_list_stream",
values={ values={
@ -1133,7 +1135,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
await self.db.runInteraction("store_room_txn", store_room_txn, next_id) await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
except Exception as e: except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e) logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.") raise StoreError(500, "Problem creating room.")
@ -1143,7 +1147,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
When we receive an invite over federation, store the version of the room if we When we receive an invite over federation, store the version of the room if we
don't already know the room version. don't already know the room version.
""" """
await self.db.simple_upsert( await self.db_pool.simple_upsert(
desc="maybe_store_room_on_invite", desc="maybe_store_room_on_invite",
table="rooms", table="rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
@ -1160,14 +1164,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
async def set_room_is_public(self, room_id, is_public): async def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id): def set_room_is_public_txn(txn, next_id):
self.db.simple_update_one_txn( self.db_pool.simple_update_one_txn(
txn, txn,
table="rooms", table="rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
updatevalues={"is_public": is_public}, updatevalues={"is_public": is_public},
) )
entries = self.db.simple_select_list_txn( entries = self.db_pool.simple_select_list_txn(
txn, txn,
table="public_room_list_stream", table="public_room_list_stream",
keyvalues={ keyvalues={
@ -1185,7 +1189,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream: if add_to_stream:
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="public_room_list_stream", table="public_room_list_stream",
values={ values={
@ -1198,7 +1202,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
await self.db.runInteraction( await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id "set_room_is_public", set_room_is_public_txn, next_id
) )
self.hs.get_notifier().on_new_replication_data() self.hs.get_notifier().on_new_replication_data()
@ -1224,7 +1228,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def set_room_is_public_appservice_txn(txn, next_id): def set_room_is_public_appservice_txn(txn, next_id):
if is_public: if is_public:
try: try:
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="appservice_room_list", table="appservice_room_list",
values={ values={
@ -1237,7 +1241,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
# We've already inserted, nothing to do. # We've already inserted, nothing to do.
return return
else: else:
self.db.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="appservice_room_list", table="appservice_room_list",
keyvalues={ keyvalues={
@ -1247,7 +1251,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
}, },
) )
entries = self.db.simple_select_list_txn( entries = self.db_pool.simple_select_list_txn(
txn, txn,
table="public_room_list_stream", table="public_room_list_stream",
keyvalues={ keyvalues={
@ -1265,7 +1269,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream: if add_to_stream:
self.db.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="public_room_list_stream", table="public_room_list_stream",
values={ values={
@ -1278,7 +1282,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
await self.db.runInteraction( await self.db_pool.runInteraction(
"set_room_is_public_appservice", "set_room_is_public_appservice",
set_room_is_public_appservice_txn, set_room_is_public_appservice_txn,
next_id, next_id,
@ -1295,13 +1299,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone() row = txn.fetchone()
return row[0] or 0 return row[0] or 0
return self.db.runInteraction("get_rooms", f) return self.db_pool.runInteraction("get_rooms", f)
def add_event_report( def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts self, room_id, event_id, user_id, reason, content, received_ts
): ):
next_id = self._event_reports_id_gen.get_next() next_id = self._event_reports_id_gen.get_next()
return self.db.simple_insert( return self.db_pool.simple_insert(
table="event_reports", table="event_reports",
values={ values={
"id": next_id, "id": next_id,
@ -1325,14 +1329,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
room_id: Room to block room_id: Room to block
user_id: Who blocked it user_id: Who blocked it
""" """
await self.db.simple_upsert( await self.db_pool.simple_upsert(
table="blocked_rooms", table="blocked_rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
values={}, values={},
insertion_values={"user_id": user_id}, insertion_values={"user_id": user_id},
desc="block_room", desc="block_room",
) )
await self.db.runInteraction( await self.db_pool.runInteraction(
"block_room_invalidation", "block_room_invalidation",
self._invalidate_cache_and_stream, self._invalidate_cache_and_stream,
self.is_room_blocked, self.is_room_blocked,
@ -1388,7 +1392,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
txn.execute(sql, args) txn.execute(sql, args)
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
rooms_dict = {} rooms_dict = {}
for row in rows: for row in rows:
@ -1404,7 +1408,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
txn.execute(sql) txn.execute(sql)
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
# If a room isn't already in the dict (i.e. it doesn't have a retention # If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy. # policy in its state), add it with a null policy.
@ -1417,7 +1421,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return rooms_dict return rooms_dict
rooms = await self.db.runInteraction( rooms = await self.db_pool.runInteraction(
"get_rooms_for_retention_period_in_range", "get_rooms_for_retention_period_in_range",
get_rooms_for_retention_period_in_range_txn, get_rooms_for_retention_period_in_range_txn,
) )

View file

@ -28,8 +28,8 @@ from synapse.storage._base import (
db_to_json, db_to_json,
make_in_list_sql_clause, make_in_list_sql_clause,
) )
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import DatabasePool
from synapse.storage.database import Database from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import ( from synapse.storage.roommember import (
GetRoomsForUserWithStreamOrdering, GetRoomsForUserWithStreamOrdering,
@ -51,7 +51,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
# Is the current_state_events.membership up to date? Or is the # Is the current_state_events.membership up to date? Or is the
@ -116,7 +116,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query) txn.execute(query)
return list(txn)[0][0] return list(txn)[0][0]
count = yield self.db.runInteraction("get_known_servers", _transact) count = yield self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in # We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new). # room_memberships (for example, the server is new).
@ -128,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership column is up to date membership column is up to date
""" """
pending_update = self.db.simple_select_one_txn( pending_update = self.db_pool.simple_select_one_txn(
txn, txn,
table="background_updates", table="background_updates",
keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
@ -144,14 +144,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
15.0, 15.0,
run_as_background_process, run_as_background_process,
"_check_safe_current_state_events_membership_updated", "_check_safe_current_state_events_membership_updated",
self.db.runInteraction, self.db_pool.runInteraction,
"_check_safe_current_state_events_membership_updated", "_check_safe_current_state_events_membership_updated",
self._check_safe_current_state_events_membership_updated_txn, self._check_safe_current_state_events_membership_updated_txn,
) )
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id "get_users_in_room", self.get_users_in_room_txn, room_id
) )
@ -259,7 +259,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res return res
return self.db.runInteraction("get_room_summary", _get_room_summary_txn) return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
def _get_user_counts_in_room_txn(self, txn, room_id): def _get_user_counts_in_room_txn(self, txn, room_id):
""" """
@ -332,7 +332,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if not membership_list: if not membership_list:
return defer.succeed(None) return defer.succeed(None)
rooms = yield self.db.runInteraction( rooms = yield self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is", "get_rooms_for_local_user_where_membership_is",
self._get_rooms_for_local_user_where_membership_is_txn, self._get_rooms_for_local_user_where_membership_is_txn,
user_id, user_id,
@ -369,7 +369,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) )
txn.execute(sql, (user_id, *args)) txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
return results return results
@ -388,7 +388,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
the rooms the user is in currently, along with the stream ordering the rooms the user is in currently, along with the stream ordering
of the most recent join for that user and room. of the most recent join for that user and room.
""" """
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering", "get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn, self._get_rooms_for_user_with_stream_ordering_txn,
user_id, user_id,
@ -453,7 +453,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {row[0] for row in txn} return {row[0] for row in txn}
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"get_users_server_still_shares_room_with", "get_users_server_still_shares_room_with",
_get_users_server_still_shares_room_with_txn, _get_users_server_still_shares_room_with_txn,
) )
@ -624,7 +624,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
to `user_id` and ProfileInfo (or None if not join event). to `user_id` and ProfileInfo (or None if not join event).
""" """
rows = yield self.db.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="room_memberships", table="room_memberships",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
@ -664,7 +664,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain. # the returned user actually has the correct domain.
like_clause = "%:" + host like_clause = "%:" + host
rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause) rows = yield self.db_pool.execute(
"is_host_joined", None, sql, room_id, like_clause
)
if not rows: if not rows:
return False return False
@ -704,7 +706,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain. # the returned user actually has the correct domain.
like_clause = "%:" + host like_clause = "%:" + host
rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause) rows = yield self.db_pool.execute(
"was_host_joined", None, sql, room_id, like_clause
)
if not rows: if not rows:
return False return False
@ -774,7 +778,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = txn.fetchall() rows = txn.fetchall()
return rows[0][0] return rows[0][0]
count = yield self.db.runInteraction("did_forget_membership", f) count = yield self.db_pool.runInteraction("did_forget_membership", f)
return count == 0 return count == 0
@cached() @cached()
@ -811,7 +815,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return {row[0] for row in txn if row[1] == 0} return {row[0] for row in txn if row[1] == 0}
return self.db.runInteraction( return self.db_pool.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
) )
@ -826,7 +830,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Deferred[set[str]]: Set of room IDs. Deferred[set[str]]: Set of room IDs.
""" """
room_ids = yield self.db.simple_select_onecol( room_ids = yield self.db_pool.simple_select_onecol(
table="room_memberships", table="room_memberships",
keyvalues={"membership": Membership.JOIN, "user_id": user_id}, keyvalues={"membership": Membership.JOIN, "user_id": user_id},
retcol="room_id", retcol="room_id",
@ -841,7 +845,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Get user_id and membership of a set of event IDs. """Get user_id and membership of a set of event IDs.
""" """
return self.db.simple_select_many_batch( return self.db_pool.simple_select_many_batch(
table="room_memberships", table="room_memberships",
column="event_id", column="event_id",
iterable=member_event_ids, iterable=member_event_ids,
@ -877,23 +881,23 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return bool(txn.fetchone()) return bool(txn.fetchone())
return await self.db.runInteraction( return await self.db_pool.runInteraction(
"is_local_host_in_room_ignoring_users", "is_local_host_in_room_ignoring_users",
_is_local_host_in_room_ignoring_users_txn, _is_local_host_in_room_ignoring_users_txn,
) )
class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
) )
self.db.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
self._background_current_state_membership, self._background_current_state_membership,
) )
self.db.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"room_membership_forgotten_idx", "room_membership_forgotten_idx",
index_name="room_memberships_user_room_forgotten", index_name="room_memberships_user_room_forgotten",
table="room_memberships", table="room_memberships",
@ -926,7 +930,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = self.db.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if not rows: if not rows:
return 0 return 0
@ -961,18 +965,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive": min_stream_id, "max_stream_id_exclusive": min_stream_id,
} }
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
) )
return len(rows) return len(rows)
result = yield self.db.runInteraction( result = yield self.db_pool.runInteraction(
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
) )
if not result: if not result:
yield self.db.updates._end_background_update( yield self.db_pool.updates._end_background_update(
_MEMBERSHIP_PROFILE_UPDATE_NAME _MEMBERSHIP_PROFILE_UPDATE_NAME
) )
@ -1013,7 +1017,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
last_processed_room = next_room last_processed_room = next_room
self.db.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, txn,
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
{"last_processed_room": last_processed_room}, {"last_processed_room": last_processed_room},
@ -1025,14 +1029,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
# string, which will compare before all room IDs correctly. # string, which will compare before all room IDs correctly.
last_processed_room = progress.get("last_processed_room", "") last_processed_room = progress.get("last_processed_room", "")
row_count, finished = yield self.db.runInteraction( row_count, finished = yield self.db_pool.runInteraction(
"_background_current_state_membership_update", "_background_current_state_membership_update",
_background_current_state_membership_txn, _background_current_state_membership_txn,
last_processed_room, last_processed_room,
) )
if finished: if finished:
yield self.db.updates._end_background_update( yield self.db_pool.updates._end_background_update(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
) )
@ -1040,7 +1044,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs) super(RoomMemberStore, self).__init__(database, db_conn, hs)
def forget(self, user_id, room_id): def forget(self, user_id, room_id):
@ -1064,7 +1068,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,) txn, self.get_forgotten_rooms_for_user, (user_id,)
) )
return self.db.runInteraction("forget_membership", f) return self.db_pool.runInteraction("forget_membership", f)
class _JoinedHostsCache(object): class _JoinedHostsCache(object):

Some files were not shown because too many files have changed in this diff Show more