Merge pull request #6469 from matrix-org/erikj/make_database_class

Create a Database class and move methods out of SQLBaseStore
This commit is contained in:
Erik Johnston 2019-12-06 11:56:59 +00:00 committed by GitHub
commit f3ea2f5a08
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
68 changed files with 2686 additions and 2515 deletions

1
changelog.d/6469.misc Normal file
View file

@ -0,0 +1 @@
Move per database functionality out of the data stores and into a dedicated `Database` class.

View file

@ -58,10 +58,10 @@ if __name__ == "__main__":
" on it."
)
)
parser.add_argument("-v", action='store_true')
parser.add_argument("-v", action="store_true")
parser.add_argument(
"--database-config",
type=argparse.FileType('r'),
type=argparse.FileType("r"),
required=True,
help="A database config file for either a SQLite3 database or a PostgreSQL one.",
)
@ -101,10 +101,7 @@ if __name__ == "__main__":
# Instantiate and initialise the homeserver object.
hs = MockHomeserver(
config,
database_engine,
db_conn,
db_config=config.database_config,
config, database_engine, db_conn, db_config=config.database_config,
)
# setup instantiates the store within the homeserver object.
hs.setup()
@ -112,13 +109,13 @@ if __name__ == "__main__":
@defer.inlineCallbacks
def run_background_updates():
yield store.run_background_updates(sleep=False)
yield store.db.updates.run_background_updates(sleep=False)
# Stop the reactor to exit the script once every background update is run.
reactor.stop()
# Apply all background updates on the database.
reactor.callWhenRunning(lambda: run_as_background_process(
"background_updates", run_background_updates
))
reactor.callWhenRunning(
lambda: run_as_background_process("background_updates", run_background_updates)
)
reactor.run()

View file

@ -173,14 +173,14 @@ class Store(
return (yield self.db_pool.runWithConnection(r))
def execute(self, f, *args, **kwargs):
return self.runInteraction(f.__name__, f, *args, **kwargs)
return self.db.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql, *args):
def r(txn):
txn.execute(sql, args)
return txn.fetchall()
return self.runInteraction("execute_sql", r)
return self.db.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
@ -223,7 +223,7 @@ class Porter(object):
def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
row = yield self.postgres_store.simple_select_one(
row = yield self.postgres_store.db.simple_select_one(
table="port_from_sqlite3",
keyvalues={"table_name": table},
retcols=("forward_rowid", "backward_rowid"),
@ -233,12 +233,14 @@ class Porter(object):
total_to_port = None
if row is None:
if table == "sent_transactions":
forward_chunk, already_ported, total_to_port = (
yield self._setup_sent_transactions()
)
(
forward_chunk,
already_ported,
total_to_port,
) = yield self._setup_sent_transactions()
backward_chunk = 0
else:
yield self.postgres_store.simple_insert(
yield self.postgres_store.db.simple_insert(
table="port_from_sqlite3",
values={
"table_name": table,
@ -268,7 +270,7 @@ class Porter(object):
yield self.postgres_store.execute(delete_all)
yield self.postgres_store.simple_insert(
yield self.postgres_store.db.simple_insert(
table="port_from_sqlite3",
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
)
@ -322,7 +324,7 @@ class Porter(object):
if table == "user_directory_stream_pos":
# We need to make sure there is a single row, `(X, null), as that is
# what synapse expects to be there.
yield self.postgres_store.simple_insert(
yield self.postgres_store.db.simple_insert(
table=table, values={"stream_id": None}
)
self.progress.update(table, table_size) # Mark table as done
@ -363,7 +365,9 @@ class Porter(object):
return headers, forward_rows, backward_rows
headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)
headers, frows, brows = yield self.sqlite_store.db.runInteraction(
"select", r
)
if frows or brows:
if frows:
@ -377,7 +381,7 @@ class Porter(object):
def insert(txn):
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
self.postgres_store.simple_update_one_txn(
self.postgres_store.db.simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
@ -416,7 +420,7 @@ class Porter(object):
return headers, rows
headers, rows = yield self.sqlite_store.runInteraction("select", r)
headers, rows = yield self.sqlite_store.db.runInteraction("select", r)
if rows:
forward_chunk = rows[-1][0] + 1
@ -433,8 +437,8 @@ class Porter(object):
rows_dict = []
for row in rows:
d = dict(zip(headers, row))
if "\0" in d['value']:
logger.warning('dropping search row %s', d)
if "\0" in d["value"]:
logger.warning("dropping search row %s", d)
else:
rows_dict.append(d)
@ -454,7 +458,7 @@ class Porter(object):
],
)
self.postgres_store.simple_update_one_txn(
self.postgres_store.db.simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": "event_search"},
@ -504,17 +508,14 @@ class Porter(object):
self.progress.set_state("Preparing %s" % config["name"])
conn = self.setup_db(config, engine)
db_pool = adbapi.ConnectionPool(
config["name"], **config["args"]
)
db_pool = adbapi.ConnectionPool(config["name"], **config["args"])
hs = MockHomeserver(self.hs_config, engine, conn, db_pool)
store = Store(conn, hs)
yield store.runInteraction(
"%s_engine.check_database" % config["name"],
engine.check_database,
yield store.db.runInteraction(
"%s_engine.check_database" % config["name"], engine.check_database,
)
return store
@ -522,7 +523,9 @@ class Porter(object):
@defer.inlineCallbacks
def run_background_updates_on_postgres(self):
# Manually apply all background updates on the PostgreSQL database.
postgres_ready = yield self.postgres_store.has_completed_background_updates()
postgres_ready = (
yield self.postgres_store.db.updates.has_completed_background_updates()
)
if not postgres_ready:
# Only say that we're running background updates when there are background
@ -530,9 +533,9 @@ class Porter(object):
self.progress.set_state("Running background updates on PostgreSQL")
while not postgres_ready:
yield self.postgres_store.do_next_background_update(100)
yield self.postgres_store.db.updates.do_next_background_update(100)
postgres_ready = yield (
self.postgres_store.has_completed_background_updates()
self.postgres_store.db.updates.has_completed_background_updates()
)
@defer.inlineCallbacks
@ -541,7 +544,9 @@ class Porter(object):
self.sqlite_store = yield self.build_db_store(self.sqlite_config)
# Check if all background updates are done, abort if not.
updates_complete = yield self.sqlite_store.has_completed_background_updates()
updates_complete = (
yield self.sqlite_store.db.updates.has_completed_background_updates()
)
if not updates_complete:
sys.stderr.write(
"Pending background updates exist in the SQLite3 database."
@ -582,22 +587,22 @@ class Porter(object):
)
try:
yield self.postgres_store.runInteraction("alter_table", alter_table)
yield self.postgres_store.db.runInteraction("alter_table", alter_table)
except Exception:
# On Error Resume Next
pass
yield self.postgres_store.runInteraction(
yield self.postgres_store.db.runInteraction(
"create_port_table", create_port_table
)
# Step 2. Get tables.
self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store.simple_select_onecol(
sqlite_tables = yield self.sqlite_store.db.simple_select_onecol(
table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
)
postgres_tables = yield self.postgres_store.simple_select_onecol(
postgres_tables = yield self.postgres_store.db.simple_select_onecol(
table="information_schema.tables",
keyvalues={},
retcol="distinct table_name",
@ -687,11 +692,11 @@ class Porter(object):
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
ts_ind = headers.index('ts')
ts_ind = headers.index("ts")
return headers, [r for r in rows if r[ts_ind] < yesterday]
headers, rows = yield self.sqlite_store.runInteraction("select", r)
headers, rows = yield self.sqlite_store.db.runInteraction("select", r)
rows = self._convert_rows("sent_transactions", headers, rows)
@ -724,7 +729,7 @@ class Porter(object):
next_chunk = yield self.sqlite_store.execute(get_start_id)
next_chunk = max(max_inserted_rowid + 1, next_chunk)
yield self.postgres_store.simple_insert(
yield self.postgres_store.db.simple_insert(
table="port_from_sqlite3",
values={
"table_name": "sent_transactions",
@ -737,7 +742,7 @@ class Porter(object):
txn.execute(
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
)
size, = txn.fetchone()
(size,) = txn.fetchone()
return int(size)
remaining_count = yield self.sqlite_store.execute(get_sent_table_size)
@ -790,7 +795,7 @@ class Porter(object):
next_id = curr_id + 1
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
##############################################
@ -871,7 +876,7 @@ class CursesProgress(Progress):
duration = int(now) - int(self.start_time)
minutes, seconds = divmod(duration, 60)
duration_str = '%02dm %02ds' % (minutes, seconds)
duration_str = "%02dm %02ds" % (minutes, seconds)
if self.finished:
status = "Time spent: %s (Done!)" % (duration_str,)
@ -881,7 +886,7 @@ class CursesProgress(Progress):
left = float(self.total_remaining) / self.total_processed
est_remaining = (int(now) - self.start_time) * left
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60)
else:
est_remaining_str = "Unknown"
status = "Time spent: %s (est. remaining: %s)" % (
@ -967,7 +972,7 @@ if __name__ == "__main__":
description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database."
)
parser.add_argument("-v", action='store_true')
parser.add_argument("-v", action="store_true")
parser.add_argument(
"--sqlite-database",
required=True,
@ -976,12 +981,12 @@ if __name__ == "__main__":
)
parser.add_argument(
"--postgres-config",
type=argparse.FileType('r'),
type=argparse.FileType("r"),
required=True,
help="The database config file for the PostgreSQL database",
)
parser.add_argument(
"--curses", action='store_true', help="display a curses based progress UI"
"--curses", action="store_true", help="display a curses based progress UI"
)
parser.add_argument(

View file

@ -269,7 +269,7 @@ def start(hs, listeners=None):
# It is now safe to start your Synapse.
hs.start_listening(listeners)
hs.get_datastore().start_profiling()
hs.get_datastore().db.start_profiling()
setup_sentry(hs)
setup_sdnotify(hs)

View file

@ -436,7 +436,7 @@ def setup(config_options):
_base.start(hs, config.listeners)
hs.get_pusherpool().start()
hs.get_datastore().start_doing_background_updates()
hs.get_datastore().db.updates.start_doing_background_updates()
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)

View file

@ -64,7 +64,7 @@ class UserDirectorySlaveStore(
super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict(
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",

View file

@ -175,4 +175,4 @@ class ModuleApi(object):
Returns:
Deferred[object]: result of func
"""
return self._store.runInteraction(desc, func, *args, **kwargs)
return self._store.db.runInteraction(desc, func, *args, **kwargs)

View file

@ -402,7 +402,7 @@ class PreviewUrlResource(DirectServeResource):
logger.info("Running url preview cache expiry")
if not (yield self.store.has_completed_background_updates()):
if not (yield self.store.db.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return

View file

@ -17,10 +17,10 @@
"""
The storage layer is split up into multiple parts to allow Synapse to run
against different configurations of databases (e.g. single or multiple
databases). The `data_stores` are classes that talk directly to a single
database and have associated schemas, background updates, etc. On top of those
there are (or will be) classes that provide high level interfaces that combine
calls to multiple `data_stores`.
databases). The `Database` class represents a single physical database. The
`data_stores` are classes that talk directly to a `Database` instance and have
associated schemas, background updates, etc. On top of those there are classes
that provide high level interfaces that combine calls to multiple `data_stores`.
There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are

File diff suppressed because it is too large Load diff

View file

@ -22,7 +22,6 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@ -74,7 +73,7 @@ class BackgroundUpdatePerformance(object):
return float(self.total_item_count) / float(self.total_duration_ms)
class BackgroundUpdateStore(SQLBaseStore):
class BackgroundUpdater(object):
""" Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
@ -86,8 +85,10 @@ class BackgroundUpdateStore(SQLBaseStore):
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, db_conn, hs):
super(BackgroundUpdateStore, self).__init__(db_conn, hs)
def __init__(self, hs, database):
self._clock = hs.get_clock()
self.db = database
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
@ -101,9 +102,7 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info("Starting background schema updates")
while True:
if sleep:
yield self.hs.get_clock().sleep(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0
)
yield self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
try:
result = yield self.do_next_background_update(
@ -139,7 +138,7 @@ class BackgroundUpdateStore(SQLBaseStore):
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
updates = yield self.simple_select_onecol(
updates = yield self.db.simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
@ -161,7 +160,7 @@ class BackgroundUpdateStore(SQLBaseStore):
if update_name in self._background_update_queue:
return False
update_exists = await self.simple_select_one_onecol(
update_exists = await self.db.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="1",
@ -184,7 +183,7 @@ class BackgroundUpdateStore(SQLBaseStore):
no more work to do.
"""
if not self._background_update_queue:
updates = yield self.simple_select_list(
updates = yield self.db.simple_select_list(
"background_updates",
keyvalues=None,
retcols=("update_name", "depends_on"),
@ -226,7 +225,7 @@ class BackgroundUpdateStore(SQLBaseStore):
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
progress_json = yield self.simple_select_one_onecol(
progress_json = yield self.db.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json",
@ -380,7 +379,7 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.debug("[SQL] %s", sql)
c.execute(sql)
if isinstance(self.database_engine, engines.PostgresEngine):
if isinstance(self.db.database_engine, engines.PostgresEngine):
runner = create_index_psql
elif psql_only:
runner = None
@ -391,7 +390,7 @@ class BackgroundUpdateStore(SQLBaseStore):
def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
yield self.runWithConnection(runner)
yield self.db.runWithConnection(runner)
yield self._end_background_update(update_name)
return 1
@ -413,7 +412,7 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue = []
progress_json = json.dumps(progress)
return self.simple_insert(
return self.db.simple_insert(
"background_updates",
{"update_name": update_name, "progress_json": progress_json},
)
@ -429,7 +428,7 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue = [
name for name in self._background_update_queue if name != update_name
]
return self.simple_delete_one(
return self.db.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
@ -444,7 +443,7 @@ class BackgroundUpdateStore(SQLBaseStore):
progress_json = json.dumps(progress)
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},

View file

@ -169,9 +169,11 @@ class DataStore(
else:
self._cache_id_gen = None
super(DataStore, self).__init__(db_conn, hs)
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self.get_cache_dict(
presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
db_conn,
"presence_stream",
entity_column="user_id",
@ -185,7 +187,7 @@ class DataStore(
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self.get_cache_dict(
device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
db_conn,
"device_inbox",
entity_column="user_id",
@ -200,7 +202,7 @@ class DataStore(
)
# The federation outbox and the local device inbox uses the same
# stream_id generator.
device_outbox_prefill, min_device_outbox_id = self.get_cache_dict(
device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
db_conn,
"device_federation_outbox",
entity_column="destination",
@ -226,7 +228,7 @@ class DataStore(
)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict(
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
@ -240,7 +242,7 @@ class DataStore(
prefilled_cache=curr_state_delta_prefill,
)
_group_updates_prefill, min_group_updates_id = self.get_cache_dict(
_group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
db_conn,
"local_group_updates",
entity_column="user_id",
@ -260,8 +262,6 @@ class DataStore(
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
super(DataStore, self).__init__(db_conn, hs)
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
@ -281,7 +281,7 @@ class DataStore(
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
txn.close()
for row in rows:
@ -294,7 +294,7 @@ class DataStore(
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
return self.runInteraction("count_daily_users", self._count_users, yesterday)
return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
def count_monthly_users(self):
"""
@ -304,7 +304,7 @@ class DataStore(
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
return self.runInteraction(
return self.db.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)
@ -404,7 +404,7 @@ class DataStore(
return results
return self.runInteraction("count_r30_users", _count_r30_users)
return self.db.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self):
"""
@ -469,7 +469,7 @@ class DataStore(
# frequently
self._last_user_visit_update = now
return self.runInteraction(
return self.db.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
@ -480,7 +480,7 @@ class DataStore(
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.simple_select_list(
return self.db.simple_select_list(
table="users",
keyvalues={},
retcols=[
@ -519,7 +519,7 @@ class DataStore(
if not deactivated:
attr_filter["deactivated"] = False
return self.simple_select_list_paginate(
return self.db.simple_select_list_paginate(
desc="get_users_paginate",
table="users",
orderby="name",
@ -547,7 +547,7 @@ class DataStore(
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.simple_search_list(
return self.db.simple_search_list(
table="users",
term=term,
col="name",

View file

@ -67,7 +67,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_user_txn(txn):
rows = self.simple_select_list_txn(
rows = self.db.simple_select_list_txn(
txn,
"account_data",
{"user_id": user_id},
@ -78,7 +78,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
rows = self.simple_select_list_txn(
rows = self.db.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id},
@ -92,7 +92,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
return self.runInteraction(
return self.db.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@ -102,7 +102,7 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns:
Deferred: A dict
"""
result = yield self.simple_select_one_onecol(
result = yield self.db.simple_select_one_onecol(
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
@ -127,7 +127,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_txn(txn):
rows = self.simple_select_list_txn(
rows = self.db.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
@ -138,7 +138,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
return self.runInteraction(
return self.db.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@ -156,7 +156,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_and_type_txn(txn):
content_json = self.simple_select_one_onecol_txn(
content_json = self.db.simple_select_one_onecol_txn(
txn,
table="room_account_data",
keyvalues={
@ -170,7 +170,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return json.loads(content_json) if content_json else None
return self.runInteraction(
return self.db.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@ -207,7 +207,7 @@ class AccountDataWorkerStore(SQLBaseStore):
room_results = txn.fetchall()
return global_results, room_results
return self.runInteraction(
return self.db.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn
)
@ -252,7 +252,7 @@ class AccountDataWorkerStore(SQLBaseStore):
if not changed:
return {}, {}
return self.runInteraction(
return self.db.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@ -302,7 +302,7 @@ class AccountDataStore(AccountDataWorkerStore):
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
yield self.simple_upsert(
yield self.db.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
@ -348,7 +348,7 @@ class AccountDataStore(AccountDataWorkerStore):
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
yield self.simple_upsert(
yield self.db.simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@ -388,4 +388,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
return self.runInteraction("update_account_data_max_stream_id", _update)
return self.db.runInteraction("update_account_data_max_stream_id", _update)

View file

@ -133,7 +133,7 @@ class ApplicationServiceTransactionWorkerStore(
A Deferred which resolves to a list of ApplicationServices, which
may be empty.
"""
results = yield self.simple_select_list(
results = yield self.db.simple_select_list(
"application_services_state", dict(state=state), ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
@ -155,7 +155,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves to ApplicationServiceState.
"""
result = yield self.simple_select_one(
result = yield self.db.simple_select_one(
"application_services_state",
dict(as_id=service.id),
["state"],
@ -175,7 +175,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves when the state was set successfully.
"""
return self.simple_upsert(
return self.db.simple_upsert(
"application_services_state", dict(as_id=service.id), dict(state=state)
)
@ -216,7 +216,7 @@ class ApplicationServiceTransactionWorkerStore(
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
return self.runInteraction("create_appservice_txn", _create_appservice_txn)
return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction.
@ -249,7 +249,7 @@ class ApplicationServiceTransactionWorkerStore(
)
# Set current txn_id for AS to 'txn_id'
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
"application_services_state",
dict(as_id=service.id),
@ -257,11 +257,13 @@ class ApplicationServiceTransactionWorkerStore(
)
# Delete txn
self.simple_delete_txn(
self.db.simple_delete_txn(
txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
)
return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
return self.db.runInteraction(
"complete_appservice_txn", _complete_appservice_txn
)
@defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
@ -283,7 +285,7 @@ class ApplicationServiceTransactionWorkerStore(
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if not rows:
return None
@ -291,7 +293,7 @@ class ApplicationServiceTransactionWorkerStore(
return entry
entry = yield self.runInteraction(
entry = yield self.db.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)
@ -321,7 +323,7 @@ class ApplicationServiceTransactionWorkerStore(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
return self.runInteraction(
return self.db.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@ -350,7 +352,7 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
upper_bound, event_ids = yield self.db.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)

View file

@ -95,7 +95,7 @@ class CacheInvalidationStore(SQLBaseStore):
txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="cache_invalidation_stream",
values={
@ -122,7 +122,9 @@ class CacheInvalidationStore(SQLBaseStore):
txn.execute(sql, (last_id, limit))
return txn.fetchall()
return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
return self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def get_cache_stream_token(self):
if self._cache_id_gen:

View file

@ -20,7 +20,7 @@ from six import iteritems
from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
@ -32,41 +32,41 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"user_ips_device_index",
index_name="user_ips_device_id",
table="user_ips",
columns=["user_id", "device_id", "last_seen"],
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"user_ips_last_seen_index",
index_name="user_ips_last_seen",
table="user_ips",
columns=["user_id", "last_seen"],
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"user_ips_last_seen_only_index",
index_name="user_ips_last_seen_only",
table="user_ips",
columns=["last_seen"],
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"user_ips_analyze", self._analyze_user_ip
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"user_ips_remove_dupes", self._remove_user_ip_dupes
)
# Register a unique index
self.register_background_index_update(
self.db.updates.register_background_index_update(
"user_ips_device_unique_index",
index_name="user_ips_user_token_ip_unique_index",
table="user_ips",
@ -75,12 +75,12 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
)
# Drop the old non-unique index
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
)
# Update the last seen info in devices.
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"devices_last_seen", self._devices_last_seen_update
)
@ -91,8 +91,8 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update("user_ips_drop_nonunique_index")
yield self.db.runWithConnection(f)
yield self.db.updates._end_background_update("user_ips_drop_nonunique_index")
return 1
@defer.inlineCallbacks
@ -106,9 +106,9 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")
yield self.runInteraction("user_ips_analyze", user_ips_analyze)
yield self.db.runInteraction("user_ips_analyze", user_ips_analyze)
yield self._end_background_update("user_ips_analyze")
yield self.db.updates._end_background_update("user_ips_analyze")
return 1
@ -140,7 +140,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
end_last_seen = yield self.runInteraction(
end_last_seen = yield self.db.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)
@ -271,14 +271,14 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
(user_id, access_token, ip, device_id, user_agent, last_seen),
)
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)
yield self.runInteraction("user_ips_dups_remove", remove)
yield self.db.runInteraction("user_ips_dups_remove", remove)
if last:
yield self._end_background_update("user_ips_remove_dupes")
yield self.db.updates._end_background_update("user_ips_remove_dupes")
return batch_size
@ -344,7 +344,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
@ -352,12 +352,12 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
return len(rows)
updated = yield self.runInteraction(
updated = yield self.db.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
yield self._end_background_update("devices_last_seen")
yield self.db.updates._end_background_update("devices_last_seen")
return updated
@ -417,12 +417,12 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
to_update = self._batch_row_update
self._batch_row_update = {}
return self.runInteraction(
return self.db.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
def _update_client_ips_batch_txn(self, txn, to_update):
if "user_ips" in self._unsafe_to_upsert_tables or (
if "user_ips" in self.db._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
self.database_engine.lock_table(txn, "user_ips")
@ -431,7 +431,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="user_ips",
keyvalues={
@ -450,7 +450,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
@ -483,7 +483,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if device_id is not None:
keyvalues["device_id"] = device_id
res = yield self.simple_select_list(
res = yield self.db.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@ -516,7 +516,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
rows = yield self.simple_select_list(
rows = yield self.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
@ -546,7 +546,9 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Nothing to do
return
if not await self.has_completed_background_update("devices_last_seen"):
if not await self.db.updates.has_completed_background_update(
"devices_last_seen"
):
# Only start pruning if we have finished populating the devices
# last seen info.
return
@ -577,4 +579,4 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def _prune_old_user_ips_txn(txn):
txn.execute(sql, (timestamp,))
await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)

View file

@ -21,7 +21,6 @@ from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__)
@ -69,7 +68,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
return self.runInteraction(
return self.db.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
)
@ -109,7 +108,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
count = yield self.runInteraction(
count = yield self.db.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
@ -178,7 +177,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
return self.runInteraction(
return self.db.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
@ -203,25 +202,25 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
return self.db.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id",
table="device_inbox",
columns=["stream_id", "user_id"],
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
@ -232,9 +231,9 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
yield self.db.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
@ -294,7 +293,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.runInteraction(
yield self.db.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
@ -314,7 +313,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
already_inserted = self.simple_select_one_txn(
already_inserted = self.db.simple_select_one_txn(
txn,
table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
@ -326,7 +325,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Add an entry for this message_id so that we know we've processed
# it.
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="device_federation_inbox",
values={
@ -344,7 +343,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.runInteraction(
yield self.db.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,
@ -465,6 +464,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
return rows
return self.runInteraction(
return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)

View file

@ -31,7 +31,6 @@ from synapse.logging.opentracing import (
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
from synapse.util.caches.descriptors import (
@ -61,7 +60,7 @@ class DeviceWorkerStore(SQLBaseStore):
Raises:
StoreError: if the device is not found
"""
return self.simple_select_one(
return self.db.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@ -80,7 +79,7 @@ class DeviceWorkerStore(SQLBaseStore):
containing "device_id", "user_id" and "display_name" for each
device.
"""
devices = yield self.simple_select_list(
devices = yield self.db.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@ -122,7 +121,7 @@ class DeviceWorkerStore(SQLBaseStore):
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
@ -283,7 +282,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
devices = (
yield self.runInteraction(
yield self.db.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
@ -340,12 +339,12 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
return self.runInteraction("get_last_device_update_for_remote_user", f)
return self.db.runInteraction("get_last_device_update_for_remote_user", f)
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
return self.db.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@ -399,7 +398,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
yield self.db.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
from_user_id,
@ -414,7 +413,7 @@ class DeviceWorkerStore(SQLBaseStore):
from_user_id,
stream_id,
)
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
"user_signature_stream",
values={
@ -466,7 +465,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
content = yield self.simple_select_one_onecol(
content = yield self.db.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
@ -476,7 +475,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
devices = yield self.simple_select_list(
devices = yield self.db.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
@ -492,7 +491,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
(stream_id, devices)
"""
return self.runInteraction(
return self.db.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn,
user_id,
@ -565,7 +564,7 @@ class DeviceWorkerStore(SQLBaseStore):
return changes
return self.runInteraction(
return self.db.runInteraction(
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
)
@ -584,7 +583,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT DISTINCT user_ids FROM user_signature_stream
WHERE from_user_id = ? AND stream_id > ?
"""
rows = yield self.execute(
rows = yield self.db.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
return set(user for row in rows for user in json.loads(row[0]))
@ -605,7 +604,7 @@ class DeviceWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
return self.execute(
return self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@ -614,7 +613,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@ -628,7 +627,7 @@ class DeviceWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@ -642,11 +641,11 @@ class DeviceWorkerStore(SQLBaseStore):
return results
class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"device_lists_stream_idx",
index_name="device_lists_stream_user_id",
table="device_lists_stream",
@ -654,7 +653,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
)
# create a unique index on device_lists_remote_cache
self.register_background_index_update(
self.db.updates.register_background_index_update(
"device_lists_remote_cache_unique_idx",
index_name="device_lists_remote_cache_unique_id",
table="device_lists_remote_cache",
@ -663,7 +662,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
)
# And one on device_lists_remote_extremeties
self.register_background_index_update(
self.db.updates.register_background_index_update(
"device_lists_remote_extremeties_unique_idx",
index_name="device_lists_remote_extremeties_unique_idx",
table="device_lists_remote_extremeties",
@ -672,7 +671,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
)
# once they complete, we can remove the old non-unique indexes.
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
self._drop_device_list_streams_non_unique_indexes,
)
@ -685,8 +684,10 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
yield self.db.runWithConnection(f)
yield self.db.updates._end_background_update(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
)
return 1
@ -722,7 +723,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return False
try:
inserted = yield self.simple_insert(
inserted = yield self.db.simple_insert(
"devices",
values={
"user_id": user_id,
@ -736,7 +737,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not inserted:
# if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else
hidden = yield self.simple_select_one_onecol(
hidden = yield self.db.simple_select_one_onecol(
"devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="hidden",
@ -771,7 +772,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
defer.Deferred
"""
yield self.simple_delete_one(
yield self.db.simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device",
@ -789,7 +790,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
defer.Deferred
"""
yield self.simple_delete_many(
yield self.db.simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
@ -818,7 +819,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
return self.simple_update_one(
return self.db.simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates,
@ -829,7 +830,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
yield self.simple_delete(
yield self.db.simple_delete(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
desc="mark_remote_user_device_list_as_unsubscribed",
@ -853,7 +854,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
Deferred[None]
"""
return self.runInteraction(
return self.db.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@ -866,7 +867,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self, txn, user_id, device_id, content, stream_id
):
if content.get("deleted"):
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@ -874,7 +875,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else:
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@ -890,7 +891,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@ -914,7 +915,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
Deferred[None]
"""
return self.runInteraction(
return self.db.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@ -923,11 +924,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
self.simple_delete_txn(
self.db.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
@ -946,7 +947,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@ -962,7 +963,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
yield self.db.runInteraction(
"add_device_change_to_streams",
self._add_device_change_txn,
user_id,
@ -995,7 +996,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[(user_id, device_id, stream_id) for device_id in device_ids],
)
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
@ -1006,7 +1007,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context = get_active_span_text_map()
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
@ -1069,7 +1070,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return run_as_background_process(
"prune_old_outbound_device_pokes",
self.runInteraction,
self.db.runInteraction,
"_prune_old_outbound_device_pokes",
_prune_txn,
)

View file

@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore):
Deferred: results in namedtuple with keys "room_id" and
"servers" or None if no association can be found
"""
room_id = yield self.simple_select_one_onecol(
room_id = yield self.db.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
if not room_id:
return None
servers = yield self.simple_select_onecol(
servers = yield self.db.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore):
@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
return self.simple_select_onecol(
return self.db.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore):
"""
def alias_txn(txn):
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
"room_aliases",
{
@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore):
},
)
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[
@ -117,7 +117,9 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
ret = yield self.runInteraction("create_room_alias_association", alias_txn)
ret = yield self.db.runInteraction(
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
@ -126,7 +128,7 @@ class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction(
room_id = yield self.db.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
@ -168,6 +170,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
return self.runInteraction(
return self.db.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)

View file

@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
yield self.simple_update_one(
yield self.db.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
@ -89,7 +89,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
}
)
yield self.simple_insert_many(
yield self.db.simple_insert_many(
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
@ -125,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
rows = yield self.simple_select_list(
rows = yield self.db.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
@ -170,7 +170,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
"""
return self.runInteraction(
return self.db.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
@ -234,7 +234,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version (str): the version ID of the backup we're querying about
"""
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
@ -267,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
yield self.simple_delete(
yield self.db.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@ -312,7 +312,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
# it isn't there.
raise StoreError(404, "No row found")
result = self.simple_select_one_txn(
result = self.db.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
@ -324,7 +324,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result["etag"] = 0
return result
return self.runInteraction(
return self.db.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
@ -352,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
new_version = str(int(current_version) + 1)
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="e2e_room_keys_versions",
values={
@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return new_version
return self.runInteraction(
return self.db.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
@ -391,7 +391,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues["etag"] = version_etag
if updatevalues:
return self.simple_update(
return self.db.simple_update(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues,
@ -420,19 +420,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
else:
this_version = version
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": this_version},
)
return self.simple_update_one_txn(
return self.db.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
return self.runInteraction(
return self.db.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)

View file

@ -48,7 +48,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
results = yield self.runInteraction(
results = yield self.db.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_txn,
query_list,
@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
result = {}
for row in rows:
@ -143,7 +143,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(signature_sql, signature_query_params)
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
# add each cross-signing signature to the correct device in the result dict.
for row in rows:
@ -186,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
key_id) to json string for key
"""
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
@ -219,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
values=[
@ -238,7 +238,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
yield self.runInteraction(
yield self.db.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
@ -261,7 +261,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
return self.db.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
"""Returns a user's cross-signing key.
@ -322,7 +324,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns:
dict of the key data or None if not found
"""
return self.runInteraction(
return self.db.runInteraction(
"get_e2e_cross_signing_key",
self._get_e2e_cross_signing_key_txn,
user_id,
@ -350,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id
"""
return self.execute(
return self.db.execute(
"get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
)
@ -367,7 +369,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
old_key_json = self.simple_select_one_onecol_txn(
old_key_json = self.db.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@ -383,7 +385,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"Message": "Device key already stored."})
return False
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@ -392,7 +394,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
@ -431,7 +433,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
return self.db.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
def delete_e2e_keys_by_device(self, user_id, device_id):
def delete_e2e_keys_by_device_txn(txn):
@ -442,12 +446,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"user_id": user_id,
}
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@ -456,7 +460,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return self.runInteraction(
return self.db.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
@ -492,7 +496,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# The "keys" property must only have one entry, which will be the public
# key, so we just grab the first value in there
pubkey = next(iter(key["keys"].values()))
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
"devices",
values={
@ -505,7 +509,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# and finally, store the key itself
with self._cross_signing_id_gen.get_next() as stream_id:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
"e2e_cross_signing_keys",
values={
@ -524,7 +528,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
return self.runInteraction(
return self.db.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
user_id,
@ -539,7 +543,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
user_id (str): the user who made the signatures
signatures (iterable[SignatureListItem]): signatures to add
"""
return self.simple_insert_many(
return self.db.simple_insert_many(
"e2e_cross_signing_signatures",
[
{

View file

@ -58,7 +58,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
list of event_ids
"""
return self.runInteraction(
return self.db.runInteraction(
"get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
)
@ -90,12 +90,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
def get_oldest_events_in_room(self, room_id):
return self.runInteraction(
return self.db.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
return self.runInteraction(
return self.db.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
@ -126,7 +126,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns
Deferred[int]
"""
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
@ -140,7 +140,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self.simple_select_onecol_txn(
return self.db.simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={"room_id": room_id},
@ -188,7 +188,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
where *hashes* is a map from algorithm to hash.
"""
return self.runInteraction(
return self.db.runInteraction(
"get_latest_event_ids_and_hashes_in_room",
self._get_latest_event_ids_and_hashes_in_room,
room_id,
@ -229,13 +229,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, query_args)
return [room_id for room_id, in txn]
return self.runInteraction(
return self.db.runInteraction(
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
)
@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
return self.simple_select_onecol(
return self.db.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
@ -266,12 +266,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it.
"""
return self.runInteraction(
return self.db.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(self, txn, room_id):
min_depth = self.simple_select_one_onecol_txn(
min_depth = self.db.simple_select_one_onecol_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@ -337,7 +337,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
return self.runInteraction(
return self.db.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
@ -352,7 +352,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit (int)
"""
return (
self.runInteraction(
self.db.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
@ -383,7 +383,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
queue = PriorityQueue()
for event_id in event_list:
depth = self.simple_select_one_onecol_txn(
depth = self.db.simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={"event_id": event_id, "room_id": room_id},
@ -415,7 +415,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
@defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events, limit):
ids = yield self.runInteraction(
ids = yield self.db.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id,
@ -468,7 +468,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
Deferred[list[str]]
"""
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
@ -494,7 +494,7 @@ class EventFederationStore(EventFederationWorkerStore):
def __init__(self, db_conn, hs):
super(EventFederationStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
@ -508,7 +508,7 @@ class EventFederationStore(EventFederationWorkerStore):
if min_depth and depth >= min_depth:
return
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@ -520,7 +520,7 @@ class EventFederationStore(EventFederationWorkerStore):
For the given event, update the event edges table and forward and
backward extremities tables.
"""
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="event_edges",
values=[
@ -604,13 +604,13 @@ class EventFederationStore(EventFederationWorkerStore):
return run_as_background_process(
"delete_old_forward_extrem_cache",
self.runInteraction,
self.db.runInteraction,
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn,
)
def clean_room_for_join(self, room_id):
return self.runInteraction(
return self.db.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
@ -654,17 +654,17 @@ class EventFederationStore(EventFederationWorkerStore):
"max_stream_id_exclusive": min_stream_id,
}
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, self.EVENT_AUTH_STATE_ONLY, new_progress
)
return min_stream_id >= target_min_stream_id
result = yield self.runInteraction(
result = yield self.db.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
)
if not result:
yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY)
return batch_size

View file

@ -93,7 +93,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
ret = yield self.runInteraction(
ret = yield self.db.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
@ -177,7 +177,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f)
ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
return ret
@defer.inlineCallbacks
@ -229,7 +229,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
after_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@ -257,7 +257,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
no_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@ -329,7 +329,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
after_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@ -357,7 +357,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
no_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@ -407,7 +407,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
return self.runInteraction(
return self.db.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
@ -458,7 +458,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
return self.runInteraction(
return self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
@ -472,7 +472,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
try:
res = yield self.simple_delete(
res = yield self.db.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
@ -489,7 +489,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _find_stream_orderings_for_times(self):
return run_as_background_process(
"event_push_action_stream_orderings",
self.runInteraction,
self.db.runInteraction,
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn,
)
@ -525,7 +525,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
Deferred[int]: stream ordering of the first event received on/after
the timestamp
"""
return self.runInteraction(
return self.db.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
@ -614,14 +614,14 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
def __init__(self, db_conn, hs):
super(EventPushActionsStore, self).__init__(db_conn, hs)
self.register_background_index_update(
self.db.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
index_name="event_push_actions_u_highlight",
table="event_push_actions",
columns=["user_id", "stream_ordering"],
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"event_push_actions_highlights_index",
index_name="event_push_actions_highlights_index",
table="event_push_actions",
@ -677,7 +677,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
for event, _ in events_and_contexts:
user_ids = self.simple_select_onecol_txn(
user_ids = self.db.simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
keyvalues={"event_id": event.event_id},
@ -727,9 +727,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
return self.cursor_to_dict(txn)
return self.db.cursor_to_dict(txn)
push_actions = yield self.runInteraction("get_push_actions_for_user", f)
push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
@ -748,7 +748,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
result = yield self.runInteraction("get_time_of_last_push_action_before", f)
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None
@defer.inlineCallbacks
@ -757,7 +757,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
result = yield self.db.runInteraction(
"get_latest_push_action_stream_ordering", f
)
return result[0] or 0
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
@ -830,7 +832,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True:
logger.info("Rotating notifications")
caught_up = yield self.runInteraction(
caught_up = yield self.db.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
@ -844,7 +846,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
the archiving process has caught up or not.
"""
old_rotate_stream_ordering = self.simple_select_one_onecol_txn(
old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@ -880,7 +882,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
old_rotate_stream_ordering = self.simple_select_one_onecol_txn(
old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@ -912,7 +914,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
# existing table.
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="event_push_summary",
values=[

View file

@ -38,7 +38,6 @@ from synapse.logging.utils import log_function
from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.data_stores.main.event_federation import EventFederationStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore
@ -94,10 +93,7 @@ def _retry_on_integrity_error(func):
# inherits from EventFederationStore so that we can call _update_backward_extremities
# and _handle_mult_prev_events (though arguably those could both be moved in here)
class EventsStore(
StateGroupWorkerStore,
EventFederationStore,
EventsWorkerStore,
BackgroundUpdateStore,
StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
):
def __init__(self, db_conn, hs):
super(EventsStore, self).__init__(db_conn, hs)
@ -143,7 +139,7 @@ class EventsStore(
)
return txn.fetchall()
res = yield self.runInteraction("read_forward_extremities", fetch)
res = yield self.db.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
@_retry_on_integrity_error
@ -208,7 +204,7 @@ class EventsStore(
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
yield self.runInteraction(
yield self.db.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
@ -281,7 +277,7 @@ class EventsStore(
results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
yield self.runInteraction(
yield self.db.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
@ -345,7 +341,7 @@ class EventsStore(
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
yield self.runInteraction(
yield self.db.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
@ -432,7 +428,7 @@ class EventsStore(
# event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@ -580,12 +576,12 @@ class EventsStore(
self, txn, new_forward_extremities, max_stream_order
):
for room_id, new_extrem in iteritems(new_forward_extremities):
self.simple_delete_txn(
self.db.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
@ -598,7 +594,7 @@ class EventsStore(
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
@ -722,7 +718,7 @@ class EventsStore(
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="ex_outlier_stream",
values={
@ -794,7 +790,7 @@ class EventsStore(
d.pop("redacted_because", None)
return d
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="event_json",
values=[
@ -811,7 +807,7 @@ class EventsStore(
],
)
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="events",
values=[
@ -841,7 +837,7 @@ class EventsStore(
# If we're persisting an unredacted event we go and ensure
# that we mark any redactions that reference this event as
# requiring censoring.
self.simple_update_txn(
self.db.simple_update_txn(
txn,
table="redactions",
keyvalues={"redacts": event.event_id},
@ -983,7 +979,7 @@ class EventsStore(
state_values.append(vals)
self.simple_insert_many_txn(txn, table="state_events", values=state_values)
self.db.simple_insert_many_txn(txn, table="state_events", values=state_values)
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@ -1014,7 +1010,7 @@ class EventsStore(
)
txn.execute(sql + clause, args)
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
@ -1032,7 +1028,7 @@ class EventsStore(
# invalidate the cache for the redacted event
txn.call_after(self._invalidate_get_event_cache, event.redacts)
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="redactions",
values={
@ -1077,7 +1073,9 @@ class EventsStore(
LIMIT ?
"""
rows = yield self.execute("_censor_redactions_fetch", None, sql, before_ts, 100)
rows = yield self.db.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100
)
updates = []
@ -1109,14 +1107,14 @@ class EventsStore(
if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json)
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
table="redactions",
keyvalues={"event_id": redaction_id},
updatevalues={"have_censored": True},
)
yield self.runInteraction("_update_censor_txn", _update_censor_txn)
yield self.db.runInteraction("_update_censor_txn", _update_censor_txn)
def _censor_event_txn(self, txn, event_id, pruned_json):
"""Censor an event by replacing its JSON in the event_json table with the
@ -1127,7 +1125,7 @@ class EventsStore(
event_id (str): The ID of the event to censor.
pruned_json (str): The pruned JSON
"""
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
table="event_json",
keyvalues={"event_id": event_id},
@ -1153,7 +1151,7 @@ class EventsStore(
(count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_messages", _count_messages)
ret = yield self.db.runInteraction("count_messages", _count_messages)
return ret
@defer.inlineCallbacks
@ -1174,7 +1172,7 @@ class EventsStore(
(count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
return ret
@defer.inlineCallbacks
@ -1189,7 +1187,7 @@ class EventsStore(
(count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_daily_active_rooms", _count)
ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
return ret
def get_current_backfill_token(self):
@ -1241,7 +1239,7 @@ class EventsStore(
return new_event_updates
return self.runInteraction(
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
@ -1286,7 +1284,7 @@ class EventsStore(
return new_event_updates
return self.runInteraction(
return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
@ -1379,7 +1377,7 @@ class EventsStore(
backward_ex_outliers,
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
def purge_history(self, room_id, token, delete_local_events):
"""Deletes room history before a certain point
@ -1399,7 +1397,7 @@ class EventsStore(
deleted events.
"""
return self.runInteraction(
return self.db.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
@ -1647,7 +1645,7 @@ class EventsStore(
Deferred[List[int]]: The list of state groups to delete.
"""
return self.runInteraction("purge_room", self._purge_room_txn, room_id)
return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
def _purge_room_txn(self, txn, room_id):
# First we fetch all the state groups that should be deleted, before
@ -1766,7 +1764,7 @@ class EventsStore(
to delete.
"""
return self.runInteraction(
return self.db.runInteraction(
"purge_unreferenced_state_groups",
self._purge_unreferenced_state_groups,
room_id,
@ -1778,7 +1776,7 @@ class EventsStore(
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
rows = self.simple_select_many_txn(
rows = self.db.simple_select_many_txn(
txn,
table="state_group_edges",
column="prev_state_group",
@ -1805,15 +1803,15 @@ class EventsStore(
curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
curr_state = curr_state[sg]
self.simple_delete_txn(
self.db.simple_delete_txn(
txn, table="state_groups_state", keyvalues={"state_group": sg}
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn, table="state_group_edges", keyvalues={"state_group": sg}
)
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@ -1850,7 +1848,7 @@ class EventsStore(
state group.
"""
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="state_group_edges",
column="prev_state_group",
iterable=state_groups,
@ -1869,7 +1867,7 @@ class EventsStore(
state_groups_to_delete (list[int]): State groups to delete
"""
return self.runInteraction(
return self.db.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,
@ -1880,7 +1878,7 @@ class EventsStore(
# first we have to delete the state groups states
logger.info("[purge] removing %s from state_groups_state", room_id)
self.simple_delete_many_txn(
self.db.simple_delete_many_txn(
txn,
table="state_groups_state",
column="state_group",
@ -1891,7 +1889,7 @@ class EventsStore(
# ... and the state group edges
logger.info("[purge] removing %s from state_group_edges", room_id)
self.simple_delete_many_txn(
self.db.simple_delete_many_txn(
txn,
table="state_group_edges",
column="state_group",
@ -1902,7 +1900,7 @@ class EventsStore(
# ... and the state groups
logger.info("[purge] removing %s from state_groups", room_id)
self.simple_delete_many_txn(
self.db.simple_delete_many_txn(
txn,
table="state_groups",
column="id",
@ -1919,7 +1917,7 @@ class EventsStore(
@cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
res = yield self.simple_select_one(
res = yield self.db.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@ -1942,7 +1940,7 @@ class EventsStore(
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
return self.runInteraction(
return self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@ -1960,7 +1958,7 @@ class EventsStore(
room_id (str): The ID of the room the event was sent to.
topological_ordering (int): The position of the event in the room's topology.
"""
return self.simple_insert_many_txn(
return self.db.simple_insert_many_txn(
txn=txn,
table="event_labels",
values=[
@ -1982,7 +1980,7 @@ class EventsStore(
event_id (str): The event ID the expiry timestamp is associated with.
expiry_ts (int): The timestamp at which to expire (delete) the event.
"""
return self.simple_insert_txn(
return self.db.simple_insert_txn(
txn=txn,
table="event_expiry",
values={"event_id": event_id, "expiry_ts": expiry_ts},
@ -2031,7 +2029,7 @@ class EventsStore(
txn, "_get_event_cache", (event.event_id,)
)
yield self.runInteraction("delete_expired_event", delete_expired_event_txn)
yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
def _delete_event_expiry_txn(self, txn, event_id):
"""Delete the expiry timestamp associated with an event ID without deleting the
@ -2041,7 +2039,7 @@ class EventsStore(
txn (LoggingTransaction): The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of.
"""
return self.simple_delete_txn(
return self.db.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
)
@ -2065,7 +2063,7 @@ class EventsStore(
return txn.fetchone()
return self.runInteraction(
return self.db.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)

View file

@ -22,13 +22,12 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventContentFields
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
logger = logging.getLogger(__name__)
class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
@ -37,15 +36,15 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"event_contains_url_index",
index_name="event_contains_url_index",
table="events",
@ -56,7 +55,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
# an event_id index on event_search is useful for the purge_history
# api. Plus it means we get to enforce some integrity with a UNIQUE
# clause
self.register_background_index_update(
self.db.updates.register_background_index_update(
"event_search_event_id_idx",
index_name="event_search_event_id_idx",
table="event_search",
@ -65,16 +64,16 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
psql_only=True,
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"redactions_received_ts", self._redactions_received_ts
)
# This index gets deleted in `event_fix_redactions_bytes` update
self.register_background_index_update(
self.db.updates.register_background_index_update(
"event_fix_redactions_bytes_create_index",
index_name="redactions_censored_redacts",
table="redactions",
@ -82,11 +81,11 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
where_clause="have_censored",
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"event_fix_redactions_bytes", self._event_fix_redactions_bytes
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"event_store_labels", self._event_store_labels
)
@ -145,18 +144,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(rows),
}
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
result = yield self.runInteraction(
result = yield self.db.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
yield self.db.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
@ -189,7 +190,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
ev_rows = self.simple_select_many_txn(
ev_rows = self.db.simple_select_many_txn(
txn,
table="event_json",
column="event_id",
@ -222,18 +223,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(rows_to_update),
}
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows_to_update)
result = yield self.runInteraction(
result = yield self.db.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
yield self.db.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
@ -366,7 +369,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
to_delete.intersection_update(original_set)
deleted = self.simple_delete_many_txn(
deleted = self.db.simple_delete_many_txn(
txn=txn,
table="event_forward_extremities",
column="event_id",
@ -382,7 +385,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if deleted:
# We now need to invalidate the caches of these rooms
rows = self.simple_select_many_txn(
rows = self.db.simple_select_many_txn(
txn,
table="events",
column="event_id",
@ -396,7 +399,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
self.simple_delete_many_txn(
self.db.simple_delete_many_txn(
txn=txn,
table="_extremities_to_check",
column="event_id",
@ -406,17 +409,19 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
return len(original_set)
num_handled = yield self.runInteraction(
num_handled = yield self.db.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES)
yield self.db.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
yield self.runInteraction(
yield self.db.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
@ -464,18 +469,18 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, "redactions_received_ts", {"last_event_id": upper_event_id}
)
return len(rows)
count = yield self.runInteraction(
count = yield self.db.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
yield self._end_background_update("redactions_received_ts")
yield self.db.updates._end_background_update("redactions_received_ts")
return count
@ -501,11 +506,11 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
txn.execute("DROP INDEX redactions_censored_redacts")
yield self.runInteraction(
yield self.db.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
yield self._end_background_update("event_fix_redactions_bytes")
yield self.db.updates._end_background_update("event_fix_redactions_bytes")
return 1
@ -533,7 +538,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
try:
event_json = json.loads(event_json_raw)
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn=txn,
table="event_labels",
values=[
@ -559,17 +564,17 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
nbrows += 1
last_row_event_id = event_id
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, "event_store_labels", {"last_event_id": last_row_event_id}
)
return nbrows
num_rows = yield self.runInteraction(
num_rows = yield self.db.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
if not num_rows:
yield self._end_background_update("event_store_labels")
yield self.db.updates._end_background_update("event_store_labels")
return num_rows

View file

@ -78,7 +78,7 @@ class EventsWorkerStore(SQLBaseStore):
Deferred[int|None]: Timestamp in milliseconds, or None for events
that were persisted before received_ts was implemented.
"""
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
@ -117,7 +117,7 @@ class EventsWorkerStore(SQLBaseStore):
return ts
return self.runInteraction(
return self.db.runInteraction(
"get_approximate_received_ts", _get_approximate_received_ts_txn
)
@ -452,7 +452,7 @@ class EventsWorkerStore(SQLBaseStore):
event_id for events, _ in event_list for event_id in events
)
row_dict = self.new_transaction(
row_dict = self.db.new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)
@ -584,7 +584,7 @@ class EventsWorkerStore(SQLBaseStore):
if should_start:
run_as_background_process(
"fetch_events", self.runWithConnection, self._do_fetch
"fetch_events", self.db.runWithConnection, self._do_fetch
)
logger.debug("Loading %d events: %s", len(events), events)
@ -745,7 +745,7 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
@ -780,7 +780,9 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
yield self.db.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
def _get_total_state_event_counts_txn(self, txn, room_id):
@ -807,7 +809,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred[int]
"""
return self.runInteraction(
return self.db.runInteraction(
"get_total_state_event_counts",
self._get_total_state_event_counts_txn,
room_id,
@ -832,7 +834,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred[int]
"""
return self.runInteraction(
return self.db.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn,
room_id,

View file

@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
def_json = yield self.simple_select_one_onecol(
def_json = yield self.db.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
@ -71,4 +71,4 @@ class FilteringStore(SQLBaseStore):
return filter_id
return self.runInteraction("add_user_filter", _do_txn)
return self.db.runInteraction("add_user_filter", _do_txn)

View file

@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore):
* "invite"
* "open"
"""
return self.simple_update_one(
return self.db.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy},
@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore):
)
def get_group(self, group_id):
return self.simple_select_one(
return self.db.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
return self.simple_select_list(
return self.db.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore):
def get_invited_users_in_group(self, group_id):
# TODO: Pagination
return self.simple_select_onecol(
return self.db.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
return self.simple_select_list(
return self.db.simple_select_list(
table="group_rooms",
keyvalues=keyvalues,
retcols=("room_id", "is_public"),
@ -153,10 +153,12 @@ class GroupServerStore(SQLBaseStore):
return rooms, categories
return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn)
return self.db.runInteraction(
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
return self.runInteraction(
return self.db.runInteraction(
"add_room_to_summary",
self._add_room_to_summary_txn,
group_id,
@ -180,7 +182,7 @@ class GroupServerStore(SQLBaseStore):
an order of 1 will put the room first. Otherwise, the room gets
added to the end.
"""
room_in_group = self.simple_select_one_onecol_txn(
room_in_group = self.db.simple_select_one_onecol_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
@ -193,7 +195,7 @@ class GroupServerStore(SQLBaseStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
else:
cat_exists = self.simple_select_one_onecol_txn(
cat_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@ -204,7 +206,7 @@ class GroupServerStore(SQLBaseStore):
raise SynapseError(400, "Category doesn't exist")
# TODO: Check category is part of summary already
cat_exists = self.simple_select_one_onecol_txn(
cat_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_summary_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@ -224,7 +226,7 @@ class GroupServerStore(SQLBaseStore):
(group_id, category_id, group_id, category_id),
)
existing = self.simple_select_one_txn(
existing = self.db.simple_select_one_txn(
txn,
table="group_summary_rooms",
keyvalues={
@ -257,7 +259,7 @@ class GroupServerStore(SQLBaseStore):
to_update["room_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
self.simple_update_txn(
self.db.simple_update_txn(
txn,
table="group_summary_rooms",
keyvalues={
@ -271,7 +273,7 @@ class GroupServerStore(SQLBaseStore):
if is_public is None:
is_public = True
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="group_summary_rooms",
values={
@ -287,7 +289,7 @@ class GroupServerStore(SQLBaseStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
return self.simple_delete(
return self.db.simple_delete(
table="group_summary_rooms",
keyvalues={
"group_id": group_id,
@ -299,7 +301,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_categories(self, group_id):
rows = yield self.simple_select_list(
rows = yield self.db.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
retcols=("category_id", "is_public", "profile"),
@ -316,7 +318,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_category(self, group_id, category_id):
category = yield self.simple_select_one(
category = yield self.db.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
retcols=("is_public", "profile"),
@ -343,7 +345,7 @@ class GroupServerStore(SQLBaseStore):
else:
update_values["is_public"] = is_public
return self.simple_upsert(
return self.db.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
@ -352,7 +354,7 @@ class GroupServerStore(SQLBaseStore):
)
def remove_group_category(self, group_id, category_id):
return self.simple_delete(
return self.db.simple_delete(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
desc="remove_group_category",
@ -360,7 +362,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_roles(self, group_id):
rows = yield self.simple_select_list(
rows = yield self.db.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
retcols=("role_id", "is_public", "profile"),
@ -377,7 +379,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_role(self, group_id, role_id):
role = yield self.simple_select_one(
role = yield self.db.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
retcols=("is_public", "profile"),
@ -404,7 +406,7 @@ class GroupServerStore(SQLBaseStore):
else:
update_values["is_public"] = is_public
return self.simple_upsert(
return self.db.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
@ -413,14 +415,14 @@ class GroupServerStore(SQLBaseStore):
)
def remove_group_role(self, group_id, role_id):
return self.simple_delete(
return self.db.simple_delete(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
desc="remove_group_role",
)
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
return self.runInteraction(
return self.db.runInteraction(
"add_user_to_summary",
self._add_user_to_summary_txn,
group_id,
@ -444,7 +446,7 @@ class GroupServerStore(SQLBaseStore):
an order of 1 will put the user first. Otherwise, the user gets
added to the end.
"""
user_in_group = self.simple_select_one_onecol_txn(
user_in_group = self.db.simple_select_one_onecol_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@ -457,7 +459,7 @@ class GroupServerStore(SQLBaseStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
else:
role_exists = self.simple_select_one_onecol_txn(
role_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@ -468,7 +470,7 @@ class GroupServerStore(SQLBaseStore):
raise SynapseError(400, "Role doesn't exist")
# TODO: Check role is part of the summary already
role_exists = self.simple_select_one_onecol_txn(
role_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_summary_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@ -488,7 +490,7 @@ class GroupServerStore(SQLBaseStore):
(group_id, role_id, group_id, role_id),
)
existing = self.simple_select_one_txn(
existing = self.db.simple_select_one_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@ -517,7 +519,7 @@ class GroupServerStore(SQLBaseStore):
to_update["user_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
self.simple_update_txn(
self.db.simple_update_txn(
txn,
table="group_summary_users",
keyvalues={
@ -531,7 +533,7 @@ class GroupServerStore(SQLBaseStore):
if is_public is None:
is_public = True
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="group_summary_users",
values={
@ -547,7 +549,7 @@ class GroupServerStore(SQLBaseStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
return self.simple_delete(
return self.db.simple_delete(
table="group_summary_users",
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
desc="remove_user_from_summary",
@ -561,7 +563,7 @@ class GroupServerStore(SQLBaseStore):
Deferred[list[str]]: A twisted.Deferred containing a list of group ids
containing this room
"""
return self.simple_select_onecol(
return self.db.simple_select_onecol(
table="group_rooms",
keyvalues={"room_id": room_id},
retcol="group_id",
@ -625,12 +627,12 @@ class GroupServerStore(SQLBaseStore):
return users, roles
return self.runInteraction(
return self.db.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
def is_user_in_group(self, user_id, group_id):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@ -639,7 +641,7 @@ class GroupServerStore(SQLBaseStore):
).addCallback(lambda r: bool(r))
def is_user_admin_in_group(self, group_id, user_id):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@ -650,7 +652,7 @@ class GroupServerStore(SQLBaseStore):
def add_group_invite(self, group_id, user_id):
"""Record that the group server has invited a user
"""
return self.simple_insert(
return self.db.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
@ -659,7 +661,7 @@ class GroupServerStore(SQLBaseStore):
def is_user_invited_to_local_group(self, group_id, user_id):
"""Has the group server invited a user?
"""
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@ -682,7 +684,7 @@ class GroupServerStore(SQLBaseStore):
"""
def _get_users_membership_in_group_txn(txn):
row = self.simple_select_one_txn(
row = self.db.simple_select_one_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@ -697,7 +699,7 @@ class GroupServerStore(SQLBaseStore):
"is_privileged": row["is_admin"],
}
row = self.simple_select_one_onecol_txn(
row = self.db.simple_select_one_onecol_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
@ -710,7 +712,7 @@ class GroupServerStore(SQLBaseStore):
return {}
return self.runInteraction(
return self.db.runInteraction(
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
)
@ -738,7 +740,7 @@ class GroupServerStore(SQLBaseStore):
"""
def _add_user_to_group_txn(txn):
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="group_users",
values={
@ -749,14 +751,14 @@ class GroupServerStore(SQLBaseStore):
},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
if local_attestation:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@ -766,7 +768,7 @@ class GroupServerStore(SQLBaseStore):
},
)
if remote_attestation:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
@ -777,49 +779,49 @@ class GroupServerStore(SQLBaseStore):
},
)
return self.runInteraction("add_user_to_group", _add_user_to_group_txn)
return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn)
def remove_user_from_group(self, group_id, user_id):
def _remove_user_from_group_txn(txn):
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
return self.runInteraction(
return self.db.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn
)
def add_room_to_group(self, group_id, room_id, is_public):
return self.simple_insert(
return self.db.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
)
def update_room_in_group_visibility(self, group_id, room_id, is_public):
return self.simple_update(
return self.db.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
@ -828,26 +830,26 @@ class GroupServerStore(SQLBaseStore):
def remove_room_from_group(self, group_id, room_id):
def _remove_room_from_group_txn(txn):
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="group_summary_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
return self.runInteraction(
return self.db.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn
)
def get_publicised_groups_for_user(self, user_id):
"""Get all groups a user is publicising
"""
return self.simple_select_onecol(
return self.db.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id",
@ -857,7 +859,7 @@ class GroupServerStore(SQLBaseStore):
def update_group_publicity(self, group_id, user_id, publicise):
"""Update whether the user is publicising their membership of the group
"""
return self.simple_update_one(
return self.db.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
@ -893,12 +895,12 @@ class GroupServerStore(SQLBaseStore):
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
)
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="local_group_membership",
values={
@ -911,7 +913,7 @@ class GroupServerStore(SQLBaseStore):
},
)
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="local_group_updates",
values={
@ -930,7 +932,7 @@ class GroupServerStore(SQLBaseStore):
if membership == "join":
if local_attestation:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@ -940,7 +942,7 @@ class GroupServerStore(SQLBaseStore):
},
)
if remote_attestation:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
@ -951,12 +953,12 @@ class GroupServerStore(SQLBaseStore):
},
)
else:
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
@ -965,7 +967,7 @@ class GroupServerStore(SQLBaseStore):
return next_id
with self._group_updates_id_gen.get_next() as next_id:
res = yield self.runInteraction(
res = yield self.db.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
next_id,
@ -976,7 +978,7 @@ class GroupServerStore(SQLBaseStore):
def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
):
yield self.simple_insert(
yield self.db.simple_insert(
table="groups",
values={
"group_id": group_id,
@ -991,7 +993,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def update_group_profile(self, group_id, profile):
yield self.simple_update_one(
yield self.db.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues=profile,
@ -1008,16 +1010,16 @@ class GroupServerStore(SQLBaseStore):
WHERE valid_until_ms <= ?
"""
txn.execute(sql, (valid_until_ms,))
return self.cursor_to_dict(txn)
return self.db.cursor_to_dict(txn)
return self.runInteraction(
return self.db.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
def update_attestation_renewal(self, group_id, user_id, attestation):
"""Update an attestation that we have renewed
"""
return self.simple_update_one(
return self.db.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
@ -1027,7 +1029,7 @@ class GroupServerStore(SQLBaseStore):
def update_remote_attestion(self, group_id, user_id, attestation):
"""Update an attestation that a remote has renewed
"""
return self.simple_update_one(
return self.db.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
@ -1046,7 +1048,7 @@ class GroupServerStore(SQLBaseStore):
group_id (str)
user_id (str)
"""
return self.simple_delete(
return self.db.simple_delete(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
desc="remove_attestation_renewal",
@ -1057,7 +1059,7 @@ class GroupServerStore(SQLBaseStore):
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
row = yield self.simple_select_one(
row = yield self.db.simple_select_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("valid_until_ms", "attestation_json"),
@ -1072,7 +1074,7 @@ class GroupServerStore(SQLBaseStore):
return None
def get_joined_groups(self, user_id):
return self.simple_select_onecol(
return self.db.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id",
@ -1099,7 +1101,7 @@ class GroupServerStore(SQLBaseStore):
for row in txn
]
return self.runInteraction(
return self.db.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
@ -1129,7 +1131,7 @@ class GroupServerStore(SQLBaseStore):
for group_id, membership, gtype, content_json in txn
]
return self.runInteraction(
return self.db.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
)
@ -1154,7 +1156,7 @@ class GroupServerStore(SQLBaseStore):
for stream_id, group_id, user_id, gtype, content_json in txn
]
return self.runInteraction(
return self.db.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn
)
@ -1188,8 +1190,8 @@ class GroupServerStore(SQLBaseStore):
]
for table in tables:
self.simple_delete_txn(
self.db.simple_delete_txn(
txn, table=table, keyvalues={"group_id": group_id}
)
return self.runInteraction("delete_group", _delete_group_txn)
return self.db.runInteraction("delete_group", _delete_group_txn)

View file

@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch)
return keys
return self.runInteraction("get_server_verify_keys", _txn)
return self.db.runInteraction("get_server_verify_keys", _txn)
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
"""Stores NACL verification keys for remote servers.
@ -127,9 +127,9 @@ class KeyStore(SQLBaseStore):
f((i,))
return res
return self.runInteraction(
return self.db.runInteraction(
"store_server_verify_keys",
self.simple_upsert_many_txn,
self.db.simple_upsert_many_txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore):
ts_valid_until_ms (int): The time when this json stops being valid.
key_json (bytes): The encoded JSON.
"""
return self.simple_upsert(
return self.db.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore):
keyvalues["key_id"] = key_id
if from_server is not None:
keyvalues["from_server"] = from_server
rows = self.simple_select_list_txn(
rows = self.db.simple_select_list_txn(
txn,
"server_keys_json",
keyvalues=keyvalues,
@ -211,4 +211,4 @@ class KeyStore(SQLBaseStore):
results[(server_name, key_id, from_server)] = rows
return results
return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)

View file

@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage._base import SQLBaseStore
class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
self.db.updates.register_background_index_update(
update_name="local_media_repository_url_idx",
index_name="local_media_repository_url_idx",
table="local_media_repository",
@ -39,7 +39,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
Returns:
None if the media_id doesn't exist.
"""
return self.simple_select_one(
return self.db.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@ -64,7 +64,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id,
url_cache=None,
):
return self.simple_insert(
return self.db.simple_insert(
"local_media_repository",
{
"media_id": media_id,
@ -124,12 +124,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
)
return self.runInteraction("get_url_cache", get_url_cache_txn)
return self.db.runInteraction("get_url_cache", get_url_cache_txn)
def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
return self.simple_insert(
return self.db.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
@ -144,7 +144,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
def get_local_media_thumbnails(self, media_id):
return self.simple_select_list(
return self.db.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@ -166,7 +166,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
return self.simple_insert(
return self.db.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
@ -180,7 +180,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
def get_cached_remote_media(self, origin, media_id):
return self.simple_select_one(
return self.db.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@ -205,7 +205,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name,
filesystem_id,
):
return self.simple_insert(
return self.db.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
@ -250,10 +250,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
return self.db.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
def get_remote_media_thumbnails(self, origin, media_id):
return self.simple_select_list(
return self.db.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@ -278,7 +280,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
return self.simple_insert(
return self.db.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
@ -300,24 +302,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE last_access_ts < ?"
)
return self.execute(
"get_remote_media_before", self.cursor_to_dict, sql, before_ts
return self.db.execute(
"get_remote_media_before", self.db.cursor_to_dict, sql, before_ts
)
def delete_remote_media(self, media_origin, media_id):
def delete_remote_media_txn(txn):
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
"remote_media_cache",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
return self.db.runInteraction("delete_remote_media", delete_remote_media_txn)
def get_expired_url_cache(self, now_ts):
sql = (
@ -331,7 +333,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
return self.db.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)
def delete_url_cache(self, media_ids):
if len(media_ids) == 0:
@ -342,7 +346,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
def get_url_cache_media_before(self, before_ts):
sql = (
@ -356,7 +360,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
return self.runInteraction(
return self.db.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
@ -373,6 +377,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction(
return self.db.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)

View file

@ -32,7 +32,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
# Do not add more reserved users than the total allowable number
self.new_transaction(
self.db.new_transaction(
dbconn,
"initialise_mau_threepids",
[],
@ -146,7 +146,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn.execute(sql, query_args)
reserved_users = yield self.get_registered_reserved_users()
yield self.runInteraction(
yield self.db.runInteraction(
"reap_monthly_active_users", _reap_users, reserved_users
)
# It seems poor to invalidate the whole cache, Postgres supports
@ -174,7 +174,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
return self.runInteraction("count_users", _count_users)
return self.db.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
def get_registered_reserved_users(self):
@ -217,7 +217,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if is_support:
return
yield self.runInteraction(
yield self.db.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
@ -261,7 +261,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
is_insert = self.simple_upsert_txn(
is_insert = self.db.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
@ -281,7 +281,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
"""
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",

View file

@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
return self.simple_insert(
return self.db.simple_insert(
table="open_id_tokens",
values={
"token": token,
@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
return self.db.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)

View file

@ -29,7 +29,7 @@ class PresenceStore(SQLBaseStore):
)
with stream_ordering_manager as stream_orderings:
yield self.runInteraction(
yield self.db.runInteraction(
"update_presence",
self._update_presence_txn,
stream_orderings,
@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore):
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Actually insert new rows
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="presence_stream",
values=[
@ -88,7 +88,7 @@ class PresenceStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
return self.runInteraction(
return self.db.runInteraction(
"get_all_presence_updates", get_all_presence_updates_txn
)
@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_presence_for_users(self, user_ids):
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore):
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self.simple_insert(
return self.db.simple_insert(
table="presence_allow_inbound",
values={
"observed_user_id": observed_localpart,
@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore):
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
return self.simple_delete_one(
return self.db.simple_delete_one(
table="presence_allow_inbound",
keyvalues={
"observed_user_id": observed_localpart,

View file

@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_profileinfo(self, user_localpart):
try:
profile = yield self.simple_select_one(
profile = yield self.db.simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_profile_displayname(self, user_localpart):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_profile_avatar_url(self, user_localpart):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_from_remote_profile_cache(self, user_id):
return self.simple_select_one(
return self.db.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore):
)
def create_profile(self, user_localpart):
return self.simple_insert(
return self.db.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
def set_profile_displayname(self, user_localpart, new_displayname):
return self.simple_update_one(
return self.db.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
return self.simple_update_one(
return self.db.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore):
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
return self.simple_upsert(
return self.db.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore):
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
return self.simple_update(
return self.db.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore):
"""
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
yield self.simple_delete(
yield self.db.simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
@ -144,9 +144,9 @@ class ProfileStore(ProfileWorkerStore):
txn.execute(sql, (last_checked,))
return self.cursor_to_dict(txn)
return self.db.cursor_to_dict(txn)
return self.runInteraction(
return self.db.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore):
if res:
return True
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",

View file

@ -75,7 +75,7 @@ class PushRulesWorkerStore(
def __init__(self, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(db_conn, hs)
push_rules_prefill, push_rules_id = self.get_cache_dict(
push_rules_prefill, push_rules_id = self.db.get_cache_dict(
db_conn,
"push_rules_stream",
entity_column="user_id",
@ -100,7 +100,7 @@ class PushRulesWorkerStore(
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self.simple_select_list(
rows = yield self.db.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@ -124,7 +124,7 @@ class PushRulesWorkerStore(
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
results = yield self.simple_select_list(
results = yield self.db.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@ -146,7 +146,7 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone()
return bool(count)
return self.runInteraction(
return self.db.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
@ -162,7 +162,7 @@ class PushRulesWorkerStore(
results = {user_id: [] for user_id in user_ids}
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@ -320,7 +320,7 @@ class PushRulesWorkerStore(
results = {user_id: {} for user_id in user_ids}
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@ -350,7 +350,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
if before or after:
yield self.runInteraction(
yield self.db.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
@ -364,7 +364,7 @@ class PushRuleStore(PushRulesWorkerStore):
after,
)
else:
yield self.runInteraction(
yield self.db.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
@ -395,7 +395,7 @@ class PushRuleStore(PushRulesWorkerStore):
relative_to_rule = before or after
res = self.simple_select_one_txn(
res = self.db.simple_select_one_txn(
txn,
table="push_rules",
keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
@ -518,7 +518,7 @@ class PushRuleStore(PushRulesWorkerStore):
# We didn't update a row with the given rule_id so insert one
push_rule_id = self._push_rule_id_gen.get_next()
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="push_rules",
values={
@ -561,7 +561,7 @@ class PushRuleStore(PushRulesWorkerStore):
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
self.simple_delete_one_txn(
self.db.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
yield self.db.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
@ -582,7 +582,7 @@ class PushRuleStore(PushRulesWorkerStore):
def set_push_rule_enabled(self, user_id, rule_id, enabled):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
yield self.db.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
@ -596,7 +596,7 @@ class PushRuleStore(PushRulesWorkerStore):
self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
):
new_id = self._push_rules_enable_id_gen.get_next()
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
"push_rules_enable",
{"user_name": user_id, "rule_id": rule_id},
@ -636,7 +636,7 @@ class PushRuleStore(PushRulesWorkerStore):
update_stream=False,
)
else:
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
"push_rules",
{"user_name": user_id, "rule_id": rule_id},
@ -655,7 +655,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
yield self.db.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
@ -675,7 +675,7 @@ class PushRuleStore(PushRulesWorkerStore):
if data is not None:
values.update(data)
self.simple_insert_txn(txn, "push_rules_stream", values=values)
self.db.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
@ -699,7 +699,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.runInteraction(
return self.db.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)

View file

@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
ret = yield self.simple_select_one_onecol(
ret = yield self.db.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
ret = yield self.simple_select_list(
ret = yield self.db.simple_select_list(
"pushers",
keyvalues,
[
@ -100,11 +100,11 @@ class PusherWorkerStore(SQLBaseStore):
def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
rows = yield self.runInteraction("get_all_pushers", get_pushers)
rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
return rows
def get_all_updated_pushers(self, last_id, current_id, limit):
@ -134,7 +134,7 @@ class PusherWorkerStore(SQLBaseStore):
return updated, deleted
return self.runInteraction(
return self.db.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@ -177,7 +177,7 @@ class PusherWorkerStore(SQLBaseStore):
return results
return self.runInteraction(
return self.db.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_if_users_have_pushers(self, user_ids):
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@ -230,7 +230,7 @@ class PusherStore(PusherWorkerStore):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
yield self.simple_upsert(
yield self.db.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@ -255,7 +255,7 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
# invalidate, since we the user might not have had a pusher before
yield self.runInteraction(
yield self.db.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
self.simple_delete_one_txn(
self.db.simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore):
# it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter.
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="deleted_pushers",
values={
@ -290,13 +290,13 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
@defer.inlineCallbacks
def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
):
yield self.simple_update_one(
yield self.db.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore):
Returns:
Deferred[bool]: True if the pusher still exists; False if it has been deleted.
"""
updated = yield self.simple_update(
updated = yield self.db.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore):
@defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
yield self.simple_update(
yield self.db.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore):
@defer.inlineCallbacks
def get_throttle_params_by_room(self, pusher_id):
res = yield self.simple_select_list(
res = yield self.db.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@ -362,7 +362,7 @@ class PusherStore(PusherWorkerStore):
def set_throttle_params(self, pusher_id, room_id, params):
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
yield self.simple_upsert(
yield self.db.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,

View file

@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self.simple_select_list(
return self.db.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
@ -70,7 +70,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@ -84,7 +84,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
rows = yield self.simple_select_list(
rows = yield self.db.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@ -108,7 +108,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
return {
row[0]: {
"event_id": row[1],
@ -187,11 +187,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (room_id, to_key))
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
return rows
rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@ -237,9 +237,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key] + list(args))
return self.cursor_to_dict(txn)
return self.db.cursor_to_dict(txn)
txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
txn_results = yield self.db.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results = {}
for row in txn_results:
@ -282,7 +284,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return list(r[0:5] + (json.loads(r[5]),) for r in txn)
return self.runInteraction(
return self.db.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
@ -335,7 +337,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
res = self.simple_select_one_txn(
res = self.db.simple_select_one_txn(
txn,
table="events",
retcols=["stream_ordering", "received_ts"],
@ -388,7 +390,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type),
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="receipts_linearized",
keyvalues={
@ -398,7 +400,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
},
)
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="receipts_linearized",
values={
@ -453,13 +455,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = yield self.runInteraction(
linearized_event_id = yield self.db.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
event_ts = yield self.runInteraction(
event_ts = yield self.db.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@ -488,7 +490,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
return stream_id, max_persisted_id
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
return self.runInteraction(
return self.db.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@ -514,7 +516,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
@ -523,7 +525,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
},
)
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="receipts_graph",
values={

View file

@ -26,7 +26,6 @@ from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@ -45,7 +44,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@cached()
def get_user_by_id(self, user_id):
return self.simple_select_one(
return self.db.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@ -94,7 +93,7 @@ class RegistrationWorkerStore(SQLBaseStore):
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
"""
return self.runInteraction(
return self.db.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@ -109,7 +108,7 @@ class RegistrationWorkerStore(SQLBaseStore):
otherwise int representation of the timestamp (as a number of
milliseconds since epoch).
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
@ -137,7 +136,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def set_account_validity_for_user_txn(txn):
self.simple_update_txn(
self.db.simple_update_txn(
txn=txn,
table="account_validity",
keyvalues={"user_id": user_id},
@ -151,7 +150,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_expiration_ts_for_user, (user_id,)
)
yield self.runInteraction(
yield self.db.runInteraction(
"set_account_validity_for_user", set_account_validity_for_user_txn
)
@ -167,7 +166,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Raises:
StoreError: The provided token is already set for another user.
"""
yield self.simple_update_one(
yield self.db.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
@ -184,7 +183,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The ID of the user to which the token belongs.
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
@ -203,7 +202,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The renewal token associated with this user ID.
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
@ -229,9 +228,9 @@ class RegistrationWorkerStore(SQLBaseStore):
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
return self.cursor_to_dict(txn)
return self.db.cursor_to_dict(txn)
res = yield self.runInteraction(
res = yield self.db.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
@ -250,7 +249,7 @@ class RegistrationWorkerStore(SQLBaseStore):
email_sent (bool): Flag which indicates whether a renewal email has been sent
to this user.
"""
yield self.simple_update_one(
yield self.db.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
@ -265,7 +264,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Args:
user_id (str): ID of the user to remove from the account validity table.
"""
yield self.simple_delete_one(
yield self.db.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
@ -281,7 +280,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool):
true iff the user is a server admin, false otherwise.
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
@ -299,7 +298,7 @@ class RegistrationWorkerStore(SQLBaseStore):
admin (bool): true iff the user is to be a server admin,
false otherwise.
"""
return self.simple_update_one(
return self.db.simple_update_one(
table="users",
keyvalues={"name": user.to_string()},
updatevalues={"admin": 1 if admin else 0},
@ -316,7 +315,7 @@ class RegistrationWorkerStore(SQLBaseStore):
)
txn.execute(sql, (token,))
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]
@ -332,7 +331,9 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if user 'user_type' is null or empty string
"""
res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id)
res = yield self.db.runInteraction(
"is_real_user", self.is_real_user_txn, user_id
)
return res
@cachedInlineCallbacks()
@ -345,13 +346,13 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT
"""
res = yield self.runInteraction(
res = yield self.db.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
return res
def is_real_user_txn(self, txn, user_id):
res = self.simple_select_one_onecol_txn(
res = self.db.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@ -361,7 +362,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return res is None
def is_support_user_txn(self, txn, user_id):
res = self.simple_select_one_onecol_txn(
res = self.db.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@ -380,7 +381,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
return self.db.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@ -394,7 +395,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: the mxid of the user, or None if they are not known
"""
return await self.simple_select_one_onecol(
return await self.db.simple_select_one_onecol(
table="user_external_ids",
keyvalues={"auth_provider": auth_provider, "external_id": external_id},
retcol="user_id",
@ -408,12 +409,12 @@ class RegistrationWorkerStore(SQLBaseStore):
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
ret = yield self.db.runInteraction("count_users", _count_users)
return ret
def count_daily_user_type(self):
@ -445,7 +446,7 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
return self.db.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
@ -459,7 +460,7 @@ class RegistrationWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
ret = yield self.db.runInteraction("count_users", _count_users)
return ret
@defer.inlineCallbacks
@ -468,12 +469,12 @@ class RegistrationWorkerStore(SQLBaseStore):
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_real_users", _count_users)
ret = yield self.db.runInteraction("count_real_users", _count_users)
return ret
@defer.inlineCallbacks
@ -503,7 +504,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return (
(
yield self.runInteraction(
yield self.db.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
)
)
@ -520,7 +521,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
user_id = yield self.runInteraction(
user_id = yield self.db.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
return user_id
@ -536,7 +537,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
ret = self.simple_select_one_txn(
ret = self.db.simple_select_one_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
@ -549,7 +550,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self.simple_upsert(
yield self.db.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
@ -557,7 +558,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self.simple_select_list(
ret = yield self.db.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
@ -566,7 +567,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret
def user_delete_threepid(self, user_id, medium, address):
return self.simple_delete(
return self.db.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
desc="user_delete_threepid",
@ -579,7 +580,7 @@ class RegistrationWorkerStore(SQLBaseStore):
user_id: The user id to delete all threepids of
"""
return self.simple_delete(
return self.db.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id},
desc="user_delete_threepids",
@ -601,7 +602,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
# We need to use an upsert, in case they user had already bound the
# threepid
return self.simple_upsert(
return self.db.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@ -627,7 +628,7 @@ class RegistrationWorkerStore(SQLBaseStore):
medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com")
"""
return self.simple_select_list(
return self.db.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
@ -648,7 +649,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred
"""
return self.simple_delete(
return self.db.simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@ -671,7 +672,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[list[str]]: Resolves to a list of identity servers
"""
return self.simple_select_onecol(
return self.db.simple_select_onecol(
table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
@ -689,7 +690,7 @@ class RegistrationWorkerStore(SQLBaseStore):
defer.Deferred(bool): The requested value.
"""
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
@ -756,13 +757,13 @@ class RegistrationWorkerStore(SQLBaseStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if not rows:
return None
return rows[0]
return self.runInteraction(
return self.db.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
@ -776,39 +777,37 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def delete_threepid_session_txn(txn):
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
)
return self.runInteraction(
return self.db.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
class RegistrationBackgroundUpdateStore(
RegistrationWorkerStore, background_updates.BackgroundUpdateStore
):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, db_conn, hs):
super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
self.clock = hs.get_clock()
self.config = hs.config
self.register_background_index_update(
self.db.updates.register_background_index_update(
"access_tokens_device_index",
index_name="access_tokens_device_id",
table="access_tokens",
columns=["user_id", "device_id"],
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"users_creation_ts",
index_name="users_creation_ts",
table="users",
@ -818,13 +817,13 @@ class RegistrationBackgroundUpdateStore(
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
self.register_noop_background_update("refresh_tokens_device_index")
self.db.updates.register_noop_background_update("refresh_tokens_device_index")
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"user_threepids_grandfather", self._bg_user_threepids_grandfather
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
@ -857,7 +856,7 @@ class RegistrationBackgroundUpdateStore(
(last_user, batch_size),
)
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if not rows:
return True, 0
@ -871,7 +870,7 @@ class RegistrationBackgroundUpdateStore(
logger.info("Marked %d rows as deactivated", rows_processed_nb)
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
)
@ -880,12 +879,12 @@ class RegistrationBackgroundUpdateStore(
else:
return False, len(rows)
end, nb_processed = yield self.runInteraction(
end, nb_processed = yield self.db.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
yield self._end_background_update("users_set_deactivated_flag")
yield self.db.updates._end_background_update("users_set_deactivated_flag")
return nb_processed
@ -911,11 +910,11 @@ class RegistrationBackgroundUpdateStore(
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
yield self.db.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
yield self.db.updates._end_background_update("user_threepids_grandfather")
return 1
@ -961,7 +960,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
next_id = self._access_tokens_id_gen.get_next()
yield self.simple_insert(
yield self.db.simple_insert(
"access_tokens",
{
"id": next_id,
@ -1003,7 +1002,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Raises:
StoreError if the user_id could not be registered.
"""
return self.runInteraction(
return self.db.runInteraction(
"register_user",
self._register_user,
user_id,
@ -1037,7 +1036,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Ensure that the guest user actually exists
# ``allow_none=False`` makes this raise an exception
# if the row isn't in the database.
self.simple_select_one_txn(
self.db.simple_select_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@ -1045,7 +1044,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
allow_none=False,
)
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@ -1059,7 +1058,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
else:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
"users",
values={
@ -1114,7 +1113,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
return self.simple_insert(
return self.db.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
@ -1132,12 +1131,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def user_set_password_hash_txn(txn):
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
return self.db.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
def user_set_consent_version(self, user_id, consent_version):
"""Updates the user table to record privacy policy consent
@ -1152,7 +1153,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def f(txn):
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@ -1160,7 +1161,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_consent_version", f)
return self.db.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
"""Updates the user table to record that we have sent the user a server
@ -1176,7 +1177,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def f(txn):
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@ -1184,7 +1185,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_consent_server_notice_sent", f)
return self.db.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
"""
@ -1230,11 +1231,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
return self.runInteraction("user_delete_access_tokens", f)
return self.db.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
def f(txn):
self.simple_delete_one_txn(
self.db.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@ -1242,11 +1243,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
return self.runInteraction("delete_access_token", f)
return self.db.runInteraction("delete_access_token", f)
@cachedInlineCallbacks()
def is_guest(self, user_id):
res = yield self.simple_select_one_onecol(
res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@ -1261,7 +1262,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
return self.simple_insert(
return self.db.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
@ -1274,7 +1275,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
# XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it.
return self.simple_delete(
return self.db.simple_delete(
"users_pending_deactivation",
keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
@ -1285,7 +1286,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@ -1315,7 +1316,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn):
row = self.simple_select_one_txn(
row = self.db.simple_select_one_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@ -1333,7 +1334,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
400, "This client_secret does not match the provided session_id"
)
row = self.simple_select_one_txn(
row = self.db.simple_select_one_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token},
@ -1358,7 +1359,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Looks good. Validate the session
self.simple_update_txn(
self.db.simple_update_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@ -1368,7 +1369,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
return self.runInteraction(
return self.db.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
@ -1401,7 +1402,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
if validated_at:
insertion_values["validated_at"] = validated_at
return self.simple_upsert(
return self.db.simple_upsert(
table="threepid_validation_session",
keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt},
@ -1439,7 +1440,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def start_or_continue_validation_session_txn(txn):
# Create or update a validation session
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@ -1452,7 +1453,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Create a new validation token with this session ID
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="threepid_validation_token",
values={
@ -1463,7 +1464,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
return self.runInteraction(
return self.db.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
@ -1478,7 +1479,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
return txn.execute(sql, (ts,))
return self.runInteraction(
return self.db.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
@ -1493,7 +1494,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
deactivated (bool): The value to set for `deactivated`.
"""
yield self.runInteraction(
yield self.db.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
@ -1501,7 +1502,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@ -1529,14 +1530,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
txn.execute(sql, [])
res = self.cursor_to_dict(txn)
res = self.db.cursor_to_dict(txn)
if res:
for user in res:
self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True
)
yield self.runInteraction(
yield self.db.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)
@ -1560,7 +1561,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expiration_ts,
)
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
"account_validity",
keyvalues={"user_id": user_id},

View file

@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def _store_rejections_txn(self, txn, event_id, reason):
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="rejections",
values={
@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore):
)
def get_rejection_reason(self, event_id):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},

View file

@ -129,7 +129,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
return self.runInteraction(
return self.db.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
@ -223,7 +223,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
return self.runInteraction(
return self.db.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
@ -268,7 +268,7 @@ class RelationsWorkerStore(SQLBaseStore):
if row:
return row[0]
edit_id = yield self.runInteraction(
edit_id = yield self.db.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
)
@ -318,7 +318,7 @@ class RelationsWorkerStore(SQLBaseStore):
return bool(txn.fetchone())
return self.runInteraction(
return self.db.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore):
aggregation_key = relation.get("key")
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="event_relations",
values={
@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore):
redacted_event_id (str): The event that was redacted.
"""
self.simple_delete_txn(
self.db.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)

View file

@ -28,7 +28,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.data_stores.main.search import SearchStore
from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@ -54,7 +53,7 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
A dict containing the room information, or None if the room is unknown.
"""
return self.simple_select_one(
return self.db.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@ -63,7 +62,7 @@ class RoomWorkerStore(SQLBaseStore):
)
def get_public_room_ids(self):
return self.simple_select_onecol(
return self.db.simple_select_onecol(
table="rooms",
keyvalues={"is_public": True},
retcol="room_id",
@ -120,7 +119,7 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args)
return txn.fetchone()[0]
return self.runInteraction("count_public_rooms", _count_public_rooms_txn)
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
@defer.inlineCallbacks
def get_largest_public_rooms(
@ -253,21 +252,21 @@ class RoomWorkerStore(SQLBaseStore):
def _get_largest_public_rooms_txn(txn):
txn.execute(sql, query_args)
results = self.cursor_to_dict(txn)
results = self.db.cursor_to_dict(txn)
if not forwards:
results.reverse()
return results
ret_val = yield self.runInteraction(
ret_val = yield self.db.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
defer.returnValue(ret_val)
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
@ -288,7 +287,7 @@ class RoomWorkerStore(SQLBaseStore):
of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely.
"""
row = yield self.simple_select_one(
row = yield self.db.simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
@ -330,9 +329,9 @@ class RoomWorkerStore(SQLBaseStore):
(room_id,),
)
return self.cursor_to_dict(txn)
return self.db.cursor_to_dict(txn)
ret = yield self.runInteraction(
ret = yield self.db.runInteraction(
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
)
@ -361,13 +360,13 @@ class RoomWorkerStore(SQLBaseStore):
defer.returnValue(row)
class RoomBackgroundUpdateStore(BackgroundUpdateStore):
class RoomBackgroundUpdateStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs)
self.config = hs.config
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"insert_room_retention", self._background_insert_retention,
)
@ -396,7 +395,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
(last_room, batch_size),
)
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if not rows:
return True
@ -408,7 +407,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
ev = json.loads(row["json"])
retention_policy = json.dumps(ev["content"])
self.simple_insert_txn(
self.db.simple_insert_txn(
txn=txn,
table="room_retention",
values={
@ -421,7 +420,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
logger.info("Inserted %d rows into room_retention", len(rows))
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
)
@ -430,12 +429,12 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
else:
return False
end = yield self.runInteraction(
end = yield self.db.runInteraction(
"insert_room_retention", _background_insert_retention_txn,
)
if end:
yield self._end_background_update("insert_room_retention")
yield self.db.updates._end_background_update("insert_room_retention")
defer.returnValue(batch_size)
@ -461,7 +460,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
try:
def store_room_txn(txn, next_id):
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
"rooms",
{
@ -471,7 +470,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
if is_public:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@ -482,7 +481,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction("store_room_txn", store_room_txn, next_id)
yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@ -490,14 +489,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"is_public": is_public},
)
entries = self.simple_select_list_txn(
entries = self.db.simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
@ -515,7 +514,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@ -528,7 +527,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
yield self.db.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
self.hs.get_notifier().on_new_replication_data()
@ -555,7 +554,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def set_room_is_public_appservice_txn(txn, next_id):
if is_public:
try:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="appservice_room_list",
values={
@ -568,7 +567,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
# We've already inserted, nothing to do.
return
else:
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="appservice_room_list",
keyvalues={
@ -578,7 +577,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
entries = self.simple_select_list_txn(
entries = self.db.simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
@ -596,7 +595,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@ -609,7 +608,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
yield self.db.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
next_id,
@ -626,7 +625,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
return self.runInteraction("get_rooms", f)
return self.db.runInteraction("get_rooms", f)
def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content:
@ -660,7 +659,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
# Ignore the event if one of the value isn't an integer.
return
self.simple_insert_txn(
self.db.simple_insert_txn(
txn=txn,
table="room_retention",
values={
@ -679,7 +678,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self, room_id, event_id, user_id, reason, content, received_ts
):
next_id = self._event_reports_id_gen.get_next()
return self.simple_insert(
return self.db.simple_insert(
table="event_reports",
values={
"id": next_id,
@ -712,7 +711,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
if prev_id == current_id:
return defer.succeed([])
return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
return self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
@ -725,14 +726,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
Returns:
Deferred
"""
yield self.simple_upsert(
yield self.db.simple_upsert(
table="blocked_rooms",
keyvalues={"room_id": room_id},
values={},
insertion_values={"user_id": user_id},
desc="block_room",
)
yield self.runInteraction(
yield self.db.runInteraction(
"block_room_invalidation",
self._invalidate_cache_and_stream,
self.is_room_blocked,
@ -763,7 +764,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return local_media_mxcs, remote_media_mxcs
return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
return self.db.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
@ -802,7 +805,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return total_media_quarantined
return self.runInteraction(
return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@ -907,7 +910,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
rooms_dict = {}
for row in rows:
@ -923,7 +926,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
txn.execute(sql)
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
# If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy.
@ -936,7 +939,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return rooms_dict
rooms = yield self.runInteraction(
rooms = yield self.db.runInteraction(
"get_rooms_for_retention_period_in_range",
get_rooms_for_retention_period_in_range_txn,
)

View file

@ -26,8 +26,11 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage._base import (
LoggingTransaction,
SQLBaseStore,
make_in_list_sql_clause,
)
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
@ -116,7 +119,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
count = yield self.runInteraction("get_known_servers", _transact)
count = yield self.db.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@ -128,7 +131,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership column is up to date
"""
pending_update = self.simple_select_one_txn(
pending_update = self.db.simple_select_one_txn(
txn,
table="background_updates",
keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
@ -144,7 +147,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
15.0,
run_as_background_process,
"_check_safe_current_state_events_membership_updated",
self.runInteraction,
self.db.runInteraction,
"_check_safe_current_state_events_membership_updated",
self._check_safe_current_state_events_membership_updated_txn,
)
@ -161,7 +164,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id):
return self.runInteraction(
return self.db.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
@ -269,7 +272,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res
return self.runInteraction("get_room_summary", _get_room_summary_txn)
return self.db.runInteraction("get_room_summary", _get_room_summary_txn)
def _get_user_counts_in_room_txn(self, txn, room_id):
"""
@ -339,7 +342,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if not membership_list:
return defer.succeed(None)
rooms = yield self.runInteraction(
rooms = yield self.db.runInteraction(
"get_rooms_for_user_where_membership_is",
self._get_rooms_for_user_where_membership_is_txn,
user_id,
@ -392,7 +395,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
if do_invite:
sql = (
@ -412,7 +415,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
stream_ordering=r["stream_ordering"],
membership=Membership.INVITE,
)
for r in self.cursor_to_dict(txn)
for r in self.db.cursor_to_dict(txn)
)
return results
@ -603,7 +606,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
to `user_id` and ProfileInfo (or None if not join event).
"""
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@ -643,7 +646,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
rows = yield self.execute("is_host_joined", None, sql, room_id, like_clause)
rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause)
if not rows:
return False
@ -683,7 +686,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
rows = yield self.execute("was_host_joined", None, sql, room_id, like_clause)
rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause)
if not rows:
return False
@ -753,7 +756,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = txn.fetchall()
return rows[0][0]
count = yield self.runInteraction("did_forget_membership", f)
count = yield self.db.runInteraction("did_forget_membership", f)
return count == 0
@cached()
@ -790,7 +793,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,))
return set(row[0] for row in txn if row[1] == 0)
return self.runInteraction(
return self.db.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
@ -805,7 +808,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Deferred[set[str]]: Set of room IDs.
"""
room_ids = yield self.simple_select_onecol(
room_ids = yield self.db.simple_select_onecol(
table="room_memberships",
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
retcol="room_id",
@ -820,7 +823,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Get user_id and membership of a set of event IDs.
"""
return self.simple_select_many_batch(
return self.db.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@ -831,17 +834,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
self._background_current_state_membership,
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
"room_membership_forgotten_idx",
index_name="room_memberships_user_room_forgotten",
table="room_memberships",
@ -874,7 +877,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if not rows:
return 0
@ -909,18 +912,20 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
"max_stream_id_exclusive": min_stream_id,
}
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
)
return len(rows)
result = yield self.runInteraction(
result = yield self.db.runInteraction(
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
)
if not result:
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
yield self.db.updates._end_background_update(
_MEMBERSHIP_PROFILE_UPDATE_NAME
)
return result
@ -959,7 +964,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
last_processed_room = next_room
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn,
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
{"last_processed_room": last_processed_room},
@ -971,14 +976,16 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
# string, which will compare before all room IDs correctly.
last_processed_room = progress.get("last_processed_room", "")
row_count, finished = yield self.runInteraction(
row_count, finished = yield self.db.runInteraction(
"_background_current_state_membership_update",
_background_current_state_membership_txn,
last_processed_room,
)
if finished:
yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME)
yield self.db.updates._end_background_update(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
)
return row_count
@ -990,7 +997,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="room_memberships",
values=[
@ -1028,7 +1035,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="local_invites",
values={
@ -1068,7 +1075,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
@ -1091,7 +1098,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.runInteraction("forget_membership", f)
return self.db.runInteraction("forget_membership", f)
class _JoinedHostsCache(object):

View file

@ -24,8 +24,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
logger = logging.getLogger(__name__)
@ -36,7 +35,7 @@ SearchEntry = namedtuple(
)
class SearchBackgroundUpdateStore(BackgroundUpdateStore):
class SearchBackgroundUpdateStore(SQLBaseStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@ -49,10 +48,10 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
if not hs.config.enable_search:
return
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
)
@ -61,9 +60,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
# a GIN index. However, it's possible that some people might still have
# the background update queued, so we register a handler to clear the
# background update.
self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
self.db.updates.register_noop_background_update(
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)
@ -93,7 +94,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
rows = self.cursor_to_dict(txn)
rows = self.db.cursor_to_dict(txn)
if not rows:
return 0
@ -153,18 +154,18 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(event_search_rows),
}
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, self.EVENT_SEARCH_UPDATE_NAME, progress
)
return len(event_search_rows)
result = yield self.runInteraction(
result = yield self.db.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
if not result:
yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
return result
@ -206,9 +207,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
yield self.runWithConnection(create_index)
yield self.db.runWithConnection(create_index)
yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
yield self.db.updates._end_background_update(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
)
return 1
@defer.inlineCallbacks
@ -237,14 +240,14 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
)
conn.set_session(autocommit=False)
yield self.runWithConnection(create_index)
yield self.db.runWithConnection(create_index)
pg = dict(progress)
pg["have_added_indexes"] = True
yield self.runInteraction(
yield self.db.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_update_progress_txn,
self.db.updates._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
pg,
)
@ -274,18 +277,20 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
"have_added_indexes": True,
}
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
)
return len(rows), True
num_rows, finished = yield self.runInteraction(
num_rows, finished = yield self.db.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
)
if not finished:
yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
yield self.db.updates._end_background_update(
self.EVENT_SEARCH_ORDER_UPDATE_NAME
)
return num_rows
@ -441,7 +446,9 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
results = yield self.execute("search_msgs", self.cursor_to_dict, sql, *args)
results = yield self.db.execute(
"search_msgs", self.db.cursor_to_dict, sql, *args
)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
@ -455,8 +462,8 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
count_results = yield self.execute(
"search_rooms_count", self.cursor_to_dict, count_sql, *count_args
count_results = yield self.db.execute(
"search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@ -586,7 +593,9 @@ class SearchStore(SearchBackgroundUpdateStore):
args.append(limit)
results = yield self.execute("search_rooms", self.cursor_to_dict, sql, *args)
results = yield self.db.execute(
"search_rooms", self.db.cursor_to_dict, sql, *args
)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
@ -600,8 +609,8 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
count_results = yield self.execute(
"search_rooms_count", self.cursor_to_dict, count_sql, *count_args
count_results = yield self.db.execute(
"search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@ -686,7 +695,7 @@ class SearchStore(SearchBackgroundUpdateStore):
return highlight_words
return self.runInteraction("_find_highlights", f)
return self.db.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):

View file

@ -48,7 +48,7 @@ class SignatureWorkerStore(SQLBaseStore):
for event_id in event_ids
}
return self.runInteraction("get_event_reference_hashes", f)
return self.db.runInteraction("get_event_reference_hashes", f)
@defer.inlineCallbacks
def add_event_hashes(self, event_ids):
@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore):
}
)
self.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)

View file

@ -27,7 +27,6 @@ from synapse.api.errors import NotFoundError
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
@ -89,7 +88,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
count = 0
while next_group:
next_group = self.simple_select_one_onecol_txn(
next_group = self.db.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@ -192,7 +191,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
):
break
next_group = self.simple_select_one_onecol_txn(
next_group = self.db.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@ -348,7 +347,9 @@ class StateGroupWorkerStore(
(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
}
return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
return self.db.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@ -392,7 +393,7 @@ class StateGroupWorkerStore(
return results
return self.runInteraction(
return self.db.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@ -431,7 +432,7 @@ class StateGroupWorkerStore(
"""
def _get_state_group_delta_txn(txn):
prev_group = self.simple_select_one_onecol_txn(
prev_group = self.db.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
@ -442,7 +443,7 @@ class StateGroupWorkerStore(
if not prev_group:
return _GetStateGroupDelta(None, None)
delta_ids = self.simple_select_list_txn(
delta_ids = self.db.simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
@ -454,7 +455,9 @@ class StateGroupWorkerStore(
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)
return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
return self.db.runInteraction(
"get_state_group_delta", _get_state_group_delta_txn
)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@ -540,7 +543,7 @@ class StateGroupWorkerStore(
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
res = yield self.runInteraction(
res = yield self.db.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn,
chunk,
@ -644,7 +647,7 @@ class StateGroupWorkerStore(
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
@ -661,7 +664,7 @@ class StateGroupWorkerStore(
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
@ -902,7 +905,7 @@ class StateGroupWorkerStore(
state_group = self.database_engine.get_next_state_group_id(txn)
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="state_groups",
values={"id": state_group, "room_id": room_id, "event_id": event_id},
@ -911,7 +914,7 @@ class StateGroupWorkerStore(
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if prev_group:
is_in_db = self.simple_select_one_onecol_txn(
is_in_db = self.db.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
@ -926,13 +929,13 @@ class StateGroupWorkerStore(
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="state_group_edges",
values={"state_group": state_group, "prev_state_group": prev_group},
)
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@ -947,7 +950,7 @@ class StateGroupWorkerStore(
],
)
else:
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@ -993,7 +996,7 @@ class StateGroupWorkerStore(
return state_group
return self.runInteraction("store_state_group", _store_state_group_txn)
return self.db.runInteraction("store_state_group", _store_state_group_txn)
@defer.inlineCallbacks
def get_referenced_state_groups(self, state_groups):
@ -1007,7 +1010,7 @@ class StateGroupWorkerStore(
referenced.
"""
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,
@ -1019,9 +1022,7 @@ class StateGroupWorkerStore(
return set(row["state_group"] for row in rows)
class StateBackgroundUpdateStore(
StateGroupBackgroundUpdateStore, BackgroundUpdateStore
):
class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
@ -1030,21 +1031,21 @@ class StateBackgroundUpdateStore(
def __init__(self, db_conn, hs):
super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
self.register_background_index_update(
self.db.updates.register_background_index_update(
self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
index_name="event_to_state_groups_sg_index",
table="event_to_state_groups",
@ -1065,7 +1066,7 @@ class StateBackgroundUpdateStore(
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
if max_group is None:
rows = yield self.execute(
rows = yield self.db.execute(
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
@ -1135,13 +1136,13 @@ class StateBackgroundUpdateStore(
if prev_state.get(key, None) != value
}
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
)
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="state_group_edges",
values={
@ -1150,13 +1151,13 @@ class StateBackgroundUpdateStore(
},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
)
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@ -1177,18 +1178,18 @@ class StateBackgroundUpdateStore(
"max_group": max_group,
}
self._background_update_progress_txn(
self.db.updates._background_update_progress_txn(
txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
)
return False, batch_size
finished, result = yield self.runInteraction(
finished, result = yield self.db.runInteraction(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
)
if finished:
yield self._end_background_update(
yield self.db.updates._end_background_update(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
@ -1218,9 +1219,9 @@ class StateBackgroundUpdateStore(
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
yield self.runWithConnection(reindex_txn)
yield self.db.runWithConnection(reindex_txn)
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
return 1
@ -1263,7 +1264,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
state_groups[event.event_id] = context.state_group
self.simple_insert_many_txn(
self.db.simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[

View file

@ -98,14 +98,14 @@ class StateDeltasStore(SQLBaseStore):
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, clipped_stream_id))
return clipped_stream_id, self.cursor_to_dict(txn)
return clipped_stream_id, self.db.cursor_to_dict(txn)
return self.runInteraction(
return self.db.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
return self.simple_select_one_onecol_txn(
return self.db.simple_select_one_onecol_txn(
txn,
table="current_state_delta_stream",
keyvalues={},
@ -113,7 +113,7 @@ class StateDeltasStore(SQLBaseStore):
)
def get_max_stream_id_in_current_state_deltas(self):
return self.runInteraction(
return self.db.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)

View file

@ -68,17 +68,17 @@ class StatsStore(StateDeltasStore):
self.stats_delta_processing_lock = DeferredLock()
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"populate_stats_process_rooms", self._populate_stats_process_rooms
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"populate_stats_process_users", self._populate_stats_process_users
)
# we no longer need to perform clean-up, but we will give ourselves
# the potential to reintroduce it in the future so documentation
# will still encourage the use of this no-op handler.
self.register_noop_background_update("populate_stats_cleanup")
self.register_noop_background_update("populate_stats_prepare")
self.db.updates.register_noop_background_update("populate_stats_cleanup")
self.db.updates.register_noop_background_update("populate_stats_prepare")
def quantise_stats_time(self, ts):
"""
@ -102,7 +102,7 @@ class StatsStore(StateDeltasStore):
This is a background update which regenerates statistics for users.
"""
if not self.stats_enabled:
yield self._end_background_update("populate_stats_process_users")
yield self.db.updates._end_background_update("populate_stats_process_users")
return 1
last_user_id = progress.get("last_user_id", "")
@ -117,22 +117,22 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_user_id, batch_size))
return [r for r, in txn]
users_to_work_on = yield self.runInteraction(
users_to_work_on = yield self.db.runInteraction(
"_populate_stats_process_users", _get_next_batch
)
# No more rooms -- complete the transaction.
if not users_to_work_on:
yield self._end_background_update("populate_stats_process_users")
yield self.db.updates._end_background_update("populate_stats_process_users")
return 1
for user_id in users_to_work_on:
yield self._calculate_and_set_initial_state_for_user(user_id)
progress["last_user_id"] = user_id
yield self.runInteraction(
yield self.db.runInteraction(
"populate_stats_process_users",
self._background_update_progress_txn,
self.db.updates._background_update_progress_txn,
"populate_stats_process_users",
progress,
)
@ -145,7 +145,7 @@ class StatsStore(StateDeltasStore):
This is a background update which regenerates statistics for rooms.
"""
if not self.stats_enabled:
yield self._end_background_update("populate_stats_process_rooms")
yield self.db.updates._end_background_update("populate_stats_process_rooms")
return 1
last_room_id = progress.get("last_room_id", "")
@ -160,22 +160,22 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_room_id, batch_size))
return [r for r, in txn]
rooms_to_work_on = yield self.runInteraction(
rooms_to_work_on = yield self.db.runInteraction(
"populate_stats_rooms_get_batch", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
yield self._end_background_update("populate_stats_process_rooms")
yield self.db.updates._end_background_update("populate_stats_process_rooms")
return 1
for room_id in rooms_to_work_on:
yield self._calculate_and_set_initial_state_for_room(room_id)
progress["last_room_id"] = room_id
yield self.runInteraction(
yield self.db.runInteraction(
"_populate_stats_process_rooms",
self._background_update_progress_txn,
self.db.updates._background_update_progress_txn,
"populate_stats_process_rooms",
progress,
)
@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore):
"""
Returns the stats processor positions.
"""
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@ -215,7 +215,7 @@ class StatsStore(StateDeltasStore):
if field and "\0" in field:
fields[col] = None
return self.simple_upsert(
return self.db.simple_upsert(
table="room_stats_state",
keyvalues={"room_id": room_id},
values=fields,
@ -236,7 +236,7 @@ class StatsStore(StateDeltasStore):
Deferred[list[dict]], where the dict has the keys of
ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
"""
return self.runInteraction(
return self.db.runInteraction(
"get_statistics_for_subject",
self._get_statistics_for_subject_txn,
stats_type,
@ -257,7 +257,7 @@ class StatsStore(StateDeltasStore):
ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
)
slice_list = self.simple_select_list_paginate_txn(
slice_list = self.db.simple_select_list_paginate_txn(
txn,
table + "_historical",
"end_ts",
@ -282,7 +282,7 @@ class StatsStore(StateDeltasStore):
"name", "topic", "canonical_alias", "avatar", "join_rules",
"history_visibility"
"""
return self.simple_select_one(
return self.db.simple_select_one(
"room_stats_state",
{"room_id": room_id},
retcols=(
@ -308,7 +308,7 @@ class StatsStore(StateDeltasStore):
"""
table, id_col = TYPE_TO_TABLE[stats_type]
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
@ -344,14 +344,14 @@ class StatsStore(StateDeltasStore):
complete_with_stream_id=stream_id,
)
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": stream_id},
)
return self.runInteraction(
return self.db.runInteraction(
"bulk_update_stats_delta", _bulk_update_stats_delta_txn
)
@ -382,7 +382,7 @@ class StatsStore(StateDeltasStore):
Does not work with per-slice fields.
"""
return self.runInteraction(
return self.db.runInteraction(
"update_stats_delta",
self._update_stats_delta_txn,
ts,
@ -517,17 +517,17 @@ class StatsStore(StateDeltasStore):
else:
self.database_engine.lock_table(txn, table)
retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
current_row = self.simple_select_one_txn(
current_row = self.db.simple_select_one_txn(
txn, table, keyvalues, retcols, allow_none=True
)
if current_row is None:
merged_dict = {**keyvalues, **absolutes, **additive_relatives}
self.simple_insert_txn(txn, table, merged_dict)
self.db.simple_insert_txn(txn, table, merged_dict)
else:
for (key, val) in additive_relatives.items():
current_row[key] += val
current_row.update(absolutes)
self.simple_update_one_txn(txn, table, keyvalues, current_row)
self.db.simple_update_one_txn(txn, table, keyvalues, current_row)
def _upsert_copy_from_table_with_additive_relatives_txn(
self,
@ -614,11 +614,11 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, qargs)
else:
self.database_engine.lock_table(txn, into_table)
src_row = self.simple_select_one_txn(
src_row = self.db.simple_select_one_txn(
txn, src_table, keyvalues, copy_columns
)
all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
dest_current_row = self.simple_select_one_txn(
dest_current_row = self.db.simple_select_one_txn(
txn,
into_table,
keyvalues=all_dest_keyvalues,
@ -634,11 +634,11 @@ class StatsStore(StateDeltasStore):
**src_row,
**additive_relatives,
}
self.simple_insert_txn(txn, into_table, merged_dict)
self.db.simple_insert_txn(txn, into_table, merged_dict)
else:
for (key, val) in additive_relatives.items():
src_row[key] = dest_current_row[key] + val
self.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
"""Fetches the counts of events in the given range of stream IDs.
@ -652,7 +652,7 @@ class StatsStore(StateDeltasStore):
changes.
"""
return self.runInteraction(
return self.db.runInteraction(
"stats_incremental_total_events_and_bytes",
self.get_changes_room_total_events_and_bytes_txn,
min_pos,
@ -735,7 +735,7 @@ class StatsStore(StateDeltasStore):
def _fetch_current_state_stats(txn):
pos = self.get_room_max_stream_ordering()
rows = self.simple_select_many_txn(
rows = self.db.simple_select_many_txn(
txn,
table="current_state_events",
column="type",
@ -791,7 +791,7 @@ class StatsStore(StateDeltasStore):
current_state_events_count,
users_in_room,
pos,
) = yield self.runInteraction(
) = yield self.db.runInteraction(
"get_initial_state_for_room", _fetch_current_state_stats
)
@ -866,7 +866,7 @@ class StatsStore(StateDeltasStore):
(count,) = txn.fetchone()
return count, pos
joined_rooms, pos = yield self.runInteraction(
joined_rooms, pos = yield self.db.runInteraction(
"calculate_and_set_initial_state_for_user",
_calculate_and_set_initial_state_for_user_txn,
)

View file

@ -255,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
super(StreamWorkerStore, self).__init__(db_conn, hs)
events_max = self.get_room_max_stream_ordering()
event_cache_prefill, min_event_val = self.get_cache_dict(
event_cache_prefill, min_event_val = self.db.get_cache_dict(
db_conn,
"events",
entity_column="room_id",
@ -400,7 +400,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
rows = yield self.db.runInteraction("get_room_events_stream_for_room", f)
ret = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
@ -450,7 +450,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
rows = yield self.runInteraction("get_membership_changes_for_user", f)
rows = yield self.db.runInteraction("get_membership_changes_for_user", f)
ret = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
@ -511,7 +511,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
rows, token = yield self.runInteraction(
rows, token = yield self.db.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@ -548,7 +548,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
return self.runInteraction("get_room_event_after_stream_ordering", _f)
return self.db.runInteraction("get_room_event_after_stream_ordering", _f)
@defer.inlineCallbacks
def get_room_events_max_id(self, room_id=None):
@ -562,7 +562,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if room_id is None:
return "s%d" % (token,)
else:
topo = yield self.runInteraction(
topo = yield self.db.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
return "t%d-%d" % (topo, token)
@ -576,7 +576,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "s%d" stream token.
"""
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,))
@ -589,7 +589,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "t%d-%d" topological token.
"""
return self.simple_select_one(
return self.db.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
@ -613,7 +613,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
return self.execute(
return self.db.execute(
"get_max_topological_token", None, sql, room_id, stream_key
).addCallback(lambda r: r[0][0] if r else 0)
@ -667,7 +667,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
results = yield self.runInteraction(
results = yield self.db.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@ -709,7 +709,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
results = self.simple_select_one_txn(
results = self.db.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
@ -788,7 +788,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
upper_bound, event_ids = yield self.db.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
@ -797,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, events
def get_federation_out_pos(self, typ):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
keyvalues={"type": typ},
@ -805,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
def update_federation_out_pos(self, typ, stream_id):
return self.simple_update_one(
return self.db.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ},
updatevalues={"stream_id": stream_id},
@ -956,7 +956,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if to_key:
to_key = RoomStreamToken.parse(to_key)
rows, token = yield self.runInteraction(
rows, token = yield self.db.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,

View file

@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag strings to tag content.
"""
deferred = self.simple_select_list(
deferred = self.db.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
@ -78,7 +78,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
tag_ids = yield self.runInteraction(
tag_ids = yield self.db.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
@ -98,7 +98,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
tags = yield self.runInteraction(
tags = yield self.db.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
@ -135,7 +135,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
if not changed:
return {}
room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
room_ids = yield self.db.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
results = {}
if room_ids:
@ -153,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A deferred list of string tags.
"""
return self.simple_select_list(
return self.db.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
@ -178,7 +180,7 @@ class TagsStore(TagsWorkerStore):
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
@ -187,7 +189,7 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id)
yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@ -210,7 +212,7 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))

View file

@ -77,7 +77,7 @@ class TransactionStore(SQLBaseStore):
this transaction or a 2-tuple of (int, dict)
"""
return self.runInteraction(
return self.db.runInteraction(
"get_received_txn_response",
self._get_received_txn_response,
transaction_id,
@ -85,7 +85,7 @@ class TransactionStore(SQLBaseStore):
)
def _get_received_txn_response(self, txn, transaction_id, origin):
result = self.simple_select_one_txn(
result = self.db.simple_select_one_txn(
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
@ -119,7 +119,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
return self.simple_insert(
return self.db.simple_insert(
table="received_transactions",
values={
"transaction_id": transaction_id,
@ -148,7 +148,7 @@ class TransactionStore(SQLBaseStore):
if result is not SENTINEL:
return result
result = yield self.runInteraction(
result = yield self.db.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings,
destination,
@ -160,7 +160,7 @@ class TransactionStore(SQLBaseStore):
return result
def _get_destination_retry_timings(self, txn, destination):
result = self.simple_select_one_txn(
result = self.db.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
@ -187,7 +187,7 @@ class TransactionStore(SQLBaseStore):
"""
self._destination_retry_cache.pop(destination, None)
return self.runInteraction(
return self.db.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
@ -227,7 +227,7 @@ class TransactionStore(SQLBaseStore):
# We need to be careful here as the data may have changed from under us
# due to a worker setting the timings.
prev_row = self.simple_select_one_txn(
prev_row = self.db.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
@ -236,7 +236,7 @@ class TransactionStore(SQLBaseStore):
)
if not prev_row:
self.simple_insert_txn(
self.db.simple_insert_txn(
txn,
table="destinations",
values={
@ -247,7 +247,7 @@ class TransactionStore(SQLBaseStore):
},
)
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
self.simple_update_one_txn(
self.db.simple_update_one_txn(
txn,
"destinations",
keyvalues={"destination": destination},
@ -270,4 +270,6 @@ class TransactionStore(SQLBaseStore):
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
return self.db.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)

View file

@ -19,7 +19,6 @@ import re
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.data_stores.main.state import StateFilter
from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@ -32,7 +31,7 @@ logger = logging.getLogger(__name__)
TEMP_TABLE = "_temp_populate_user_directory"
class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore):
class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
@ -43,19 +42,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
self.server_name = hs.hostname
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"populate_user_directory_createtables",
self._populate_user_directory_createtables,
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"populate_user_directory_process_rooms",
self._populate_user_directory_process_rooms,
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"populate_user_directory_process_users",
self._populate_user_directory_process_users,
)
self.register_background_update_handler(
self.db.updates.register_background_update_handler(
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)
@ -85,7 +84,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
txn.execute(sql)
rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
self.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
# If search all users is on, get all the users we want to add.
@ -100,15 +99,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
txn.execute("SELECT name FROM users")
users = [{"user_id": x[0]} for x in txn.fetchall()]
self.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
new_pos = yield self.get_max_stream_id_in_current_state_deltas()
yield self.runInteraction(
yield self.db.runInteraction(
"populate_user_directory_temp_build", _make_staging_area
)
yield self.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
yield self._end_background_update("populate_user_directory_createtables")
yield self.db.updates._end_background_update(
"populate_user_directory_createtables"
)
return 1
@defer.inlineCallbacks
@ -116,7 +117,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
Update the user directory stream position, then clean up the old tables.
"""
position = yield self.simple_select_one_onecol(
position = yield self.db.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
)
yield self.update_user_directory_stream_pos(position)
@ -126,11 +127,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
yield self.runInteraction(
yield self.db.runInteraction(
"populate_user_directory_cleanup", _delete_staging_area
)
yield self._end_background_update("populate_user_directory_cleanup")
yield self.db.updates._end_background_update("populate_user_directory_cleanup")
return 1
@defer.inlineCallbacks
@ -170,13 +171,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
return rooms_to_work_on
rooms_to_work_on = yield self.runInteraction(
rooms_to_work_on = yield self.db.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
yield self._end_background_update("populate_user_directory_process_rooms")
yield self.db.updates._end_background_update(
"populate_user_directory_process_rooms"
)
return 1
logger.info(
@ -243,12 +246,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
to_insert.clear()
# We've finished a room. Delete it from the table.
yield self.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
# Update the remaining counter.
progress["remaining"] -= 1
yield self.runInteraction(
yield self.db.runInteraction(
"populate_user_directory",
self._background_update_progress_txn,
self.db.updates._background_update_progress_txn,
"populate_user_directory_process_rooms",
progress,
)
@ -267,7 +270,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
If search_all_users is enabled, add all of the users to the user directory.
"""
if not self.hs.config.user_directory_search_all_users:
yield self._end_background_update("populate_user_directory_process_users")
yield self.db.updates._end_background_update(
"populate_user_directory_process_users"
)
return 1
def _get_next_batch(txn):
@ -291,13 +296,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
return users_to_work_on
users_to_work_on = yield self.runInteraction(
users_to_work_on = yield self.db.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more users -- complete the transaction.
if not users_to_work_on:
yield self._end_background_update("populate_user_directory_process_users")
yield self.db.updates._end_background_update(
"populate_user_directory_process_users"
)
return 1
logger.info(
@ -312,12 +319,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
)
# We've finished processing a user. Delete it from the table.
yield self.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
# Update the remaining counter.
progress["remaining"] -= 1
yield self.runInteraction(
yield self.db.runInteraction(
"populate_user_directory",
self._background_update_progress_txn,
self.db.updates._background_update_progress_txn,
"populate_user_directory_process_users",
progress,
)
@ -361,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
def _update_profile_in_user_dir_txn(txn):
new_entry = self.simple_upsert_txn(
new_entry = self.db.simple_upsert_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
@ -435,7 +442,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
)
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name) if display_name else user_id
self.simple_upsert_txn(
self.db.simple_upsert_txn(
txn,
table="user_directory_search",
keyvalues={"user_id": user_id},
@ -448,7 +455,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction(
return self.db.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
@ -462,7 +469,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
def _add_users_who_share_room_txn(txn):
self.simple_upsert_many_txn(
self.db.simple_upsert_many_txn(
txn,
table="users_who_share_private_rooms",
key_names=["user_id", "other_user_id", "room_id"],
@ -474,7 +481,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
value_values=None,
)
return self.runInteraction(
return self.db.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
@ -489,7 +496,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
def _add_users_in_public_rooms_txn(txn):
self.simple_upsert_many_txn(
self.db.simple_upsert_many_txn(
txn,
table="users_in_public_rooms",
key_names=["user_id", "room_id"],
@ -498,7 +505,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
value_values=None,
)
return self.runInteraction(
return self.db.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
@ -513,13 +520,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
return self.runInteraction(
return self.db.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
return self.simple_select_one(
return self.db.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@ -528,7 +535,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
)
def update_user_directory_stream_pos(self, stream_id):
return self.simple_update_one(
return self.db.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
@ -547,42 +554,42 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self.simple_delete_txn(
self.db.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
user_ids_share_pub = yield self.simple_select_onecol(
user_ids_share_pub = yield self.db.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
user_ids_share_priv = yield self.simple_select_onecol(
user_ids_share_priv = yield self.db.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
@ -605,23 +612,23 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"""
def _remove_user_who_share_room_txn(txn):
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id, "room_id": room_id},
)
self.simple_delete_txn(
self.db.simple_delete_txn(
txn,
table="users_in_public_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
return self.runInteraction(
return self.db.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
@ -636,14 +643,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
Returns:
list: user_id
"""
rows = yield self.simple_select_onecol(
rows = yield self.db.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
desc="get_rooms_user_is_in",
)
pub_rows = yield self.simple_select_onecol(
pub_rows = yield self.db.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
@ -674,14 +681,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
) f2 USING (room_id)
"""
rows = yield self.execute(
rows = yield self.db.execute(
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
)
return [room_id for room_id, in rows]
def get_user_directory_stream_pos(self):
return self.simple_select_one_onecol(
return self.db.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
@ -786,7 +793,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
results = yield self.execute("search_user_dir", self.cursor_to_dict, sql, *args)
results = yield self.db.execute(
"search_user_dir", self.db.cursor_to_dict, sql, *args
)
limited = len(results) > limit

View file

@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if the user has requested erasure
"""
return self.simple_select_onecol(
return self.db.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore):
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
rows = yield self.simple_select_many_batch(
rows = yield self.db.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@ -88,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
return self.runInteraction("mark_user_erased", f)
return self.db.runInteraction("mark_user_erased", f)

1496
synapse/storage/database.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -47,9 +47,9 @@ async def make_homeserver(reactor, config=None):
stor = hs.get_datastore()
# Run the database background updates.
if hasattr(stor, "do_next_background_update"):
while not await stor.has_completed_background_updates():
await stor.do_next_background_update(1)
if hasattr(stor.db.updates, "do_next_background_update"):
while not await stor.db.updates.has_completed_background_updates():
await stor.db.updates.do_next_background_update(1)
def cleanup():
for i in cleanup_tasks:

View file

@ -42,16 +42,16 @@ class StatsRoomTests(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
self.store._all_done = False
self.store.db.updates._all_done = False
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_rooms",
@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
def get_all_room_state(self):
return self.store.simple_select_list(
return self.store.db.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
return self.get_success(
self.store.simple_select_one(
self.store.db.simple_select_one(
table + "_historical",
{id_col: stat_id, end_ts: end_ts},
cols,
@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the stats via the background update
self._add_background_updates()
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
def test_initial_room(self):
"""
@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
r = self.get_success(self.get_all_room_state())
@ -178,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# the position that the deltas should begin at, once they take over.
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
self.store._all_done = False
self.store.db.updates._all_done = False
self.get_success(
self.store.simple_update_one(
self.store.db.simple_update_one(
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": 0},
@ -188,14 +196,18 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
# Now, before the table is actually ingested, add some more events.
self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
@ -205,13 +217,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Now do the initial ingestion.
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
)
)
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.store._all_done = False
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
self.store.db.updates._all_done = False
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
self.reactor.advance(86401)
@ -653,15 +669,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# preparation stage of the initial background update
# Ugh, have to reset this flag
self.store._all_done = False
self.store.db.updates._all_done = False
self.get_success(
self.store.simple_delete(
self.store.db.simple_delete(
"room_stats_current", {"1": 1}, "test_delete_stats"
)
)
self.get_success(
self.store.simple_delete(
self.store.db.simple_delete(
"user_stats_current", {"1": 1}, "test_delete_stats"
)
)
@ -673,9 +689,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# now do the background updates
self.store._all_done = False
self.store.db.updates._all_done = False
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_rooms",
@ -685,7 +701,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
@ -695,7 +711,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@ -705,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
r1stats_complete = self._get_current_stats("room", r1)
u1stats_complete = self._get_current_stats("user", u1)

View file

@ -158,7 +158,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_in_public_rooms(self):
r = self.get_success(
self.store.simple_select_list(
self.store.db.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
)
)
@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_who_share_private_rooms(self):
return self.get_success(
self.store.simple_select_list(
self.store.db.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
@ -181,10 +181,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
self.store._all_done = False
self.store.db.updates._all_done = False
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_createtables",
@ -193,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_rooms",
@ -203,7 +203,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_users",
@ -213,7 +213,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_cleanup",
@ -255,8 +255,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
shares_private = self.get_users_who_share_private_rooms()
public_users = self.get_users_in_public_rooms()
@ -290,8 +294,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
shares_private = self.get_users_who_share_private_rooms()
public_users = self.get_users_in_public_rooms()

View file

@ -632,7 +632,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
"state_groups_state",
):
count = self.get_success(
self.store.simple_select_one_onecol(
self.store.db.simple_select_one_onecol(
table=table,
keyvalues={"room_id": room_id},
retcol="COUNT(*)",

View file

@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
self.table_name = "table_" + hs.get_secrets().token_hex(6)
self.get_success(
self.storage.runInteraction(
self.storage.db.runInteraction(
"create",
lambda x, *a: x.execute(*a),
"CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
)
self.get_success(
self.storage.runInteraction(
self.storage.db.runInteraction(
"index",
lambda x, *a: x.execute(*a),
"CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["hello"], ["there"]]
self.get_success(
self.storage.runInteraction(
self.storage.db.runInteraction(
"test",
self.storage.simple_upsert_many_txn,
self.storage.db.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
self.storage.simple_select_list(
self.storage.db.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["bleb"]]
self.get_success(
self.storage.runInteraction(
self.storage.db.runInteraction(
"test",
self.storage.simple_upsert_many_txn,
self.storage.db.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
self.storage.simple_select_list(
self.storage.db.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)

View file

@ -15,7 +15,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
self.update_handler = Mock()
yield self.store.register_background_update_handler(
yield self.store.db.updates.register_background_update_handler(
"test_update", self.update_handler
)
@ -23,7 +23,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
# (perhaps we should run them as part of the test HS setup, since we
# run all of the other schema setup stuff there?)
while True:
res = yield self.store.do_next_background_update(1000)
res = yield self.store.db.updates.do_next_background_update(1000)
if res is None:
break
@ -37,9 +37,9 @@ class BackgroundUpdateTestCase(unittest.TestCase):
def update(progress, count):
self.clock.advance_time_msec(count * duration_ms)
progress = {"my_key": progress["my_key"] + 1}
yield self.store.runInteraction(
yield self.store.db.runInteraction(
"update_progress",
self.store._background_update_progress_txn,
self.store.db.updates._background_update_progress_txn,
"test_update",
progress,
)
@ -47,29 +47,37 @@ class BackgroundUpdateTestCase(unittest.TestCase):
self.update_handler.side_effect = update
yield self.store.start_background_update("test_update", {"my_key": 1})
yield self.store.db.updates.start_background_update(
"test_update", {"my_key": 1}
)
self.update_handler.reset_mock()
result = yield self.store.do_next_background_update(duration_ms * desired_count)
result = yield self.store.db.updates.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
self.update_handler.assert_called_once_with(
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
{"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE
)
# second step: complete the update
@defer.inlineCallbacks
def update(progress, count):
yield self.store._end_background_update("test_update")
yield self.store.db.updates._end_background_update("test_update")
return count
self.update_handler.side_effect = update
self.update_handler.reset_mock()
result = yield self.store.do_next_background_update(duration_ms * desired_count)
result = yield self.store.db.updates.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
result = yield self.store.do_next_background_update(duration_ms * desired_count)
result = yield self.store.db.updates.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNone(result)
self.assertFalse(self.update_handler.called)

View file

@ -65,7 +65,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
yield self.datastore.simple_insert(
yield self.datastore.db.simple_insert(
table="tablename", values={"columname": "Value"}
)
@ -77,7 +77,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
yield self.datastore.simple_insert(
yield self.datastore.db.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore.simple_select_one_onecol(
value = yield self.datastore.db.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
@ -106,7 +106,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
ret = yield self.datastore.simple_select_one(
ret = yield self.datastore.db.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
@ -122,7 +122,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
ret = yield self.datastore.simple_select_one(
ret = yield self.datastore.db.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
@ -137,7 +137,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
ret = yield self.datastore.simple_select_list(
ret = yield self.datastore.db.simple_select_list(
table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
)
@ -150,7 +150,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self):
self.mock_txn.rowcount = 1
yield self.datastore.simple_update_one(
yield self.datastore.db.simple_update_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columnname": "New Value"},
@ -165,7 +165,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self):
self.mock_txn.rowcount = 1
yield self.datastore.simple_update_one(
yield self.datastore.db.simple_update_one(
table="tablename",
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self):
self.mock_txn.rowcount = 1
yield self.datastore.simple_delete_one(
yield self.datastore.db.simple_delete_one(
table="tablename", keyvalues={"keycol": "Go away"}
)

View file

@ -46,7 +46,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"""Re run the background update to clean up the extremities.
"""
# Make sure we don't clash with in progress updates.
self.assertTrue(self.store._all_done, "Background updates are still ongoing")
self.assertTrue(
self.store.db.updates._all_done, "Background updates are still ongoing"
)
schema_path = os.path.join(
prepare_database.dir_path,
@ -62,14 +64,20 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
prepare_database.executescript(txn, schema_path)
self.get_success(
self.store.runInteraction("test_delete_forward_extremities", run_delta_file)
self.store.db.runInteraction(
"test_delete_forward_extremities", run_delta_file
)
)
# Ugh, have to reset this flag
self.store._all_done = False
self.store.db.updates._all_done = False
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
def test_soft_failed_extremities_handled_correctly(self):
"""Test that extremities are correctly calculated in the presence of

View file

@ -81,7 +81,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
self.store.simple_select_list(
self.store.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@ -112,7 +112,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
self.store.simple_select_list(
self.store.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@ -202,8 +202,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_devices_last_seen_bg_update(self):
# First make sure we have completed all updates.
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
# Insert a user IP
user_id = "@user:id"
@ -218,7 +222,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# But clear the associated entry in devices table
self.get_success(
self.store.simple_update(
self.store.db.simple_update(
table="devices",
keyvalues={"user_id": user_id, "device_id": "device_id"},
updatevalues={"last_seen": None, "ip": None, "user_agent": None},
@ -245,7 +249,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
table="background_updates",
values={
"update_name": "devices_last_seen",
@ -256,11 +260,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
self.store._all_done = False
self.store.db.updates._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
# We should now get the correct result again
result = self.get_success(
@ -281,8 +289,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_old_user_ips_pruned(self):
# First make sure we have completed all updates.
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
# Insert a user IP
user_id = "@user:id"
@ -297,7 +309,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should see that in the DB
result = self.get_success(
self.store.simple_select_list(
self.store.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@ -323,7 +335,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should get no results.
result = self.get_success(
self.store.simple_select_list(
self.store.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],

View file

@ -61,7 +61,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
)
for i in range(0, 11):
yield self.store.runInteraction("insert", insert_event, i)
yield self.store.db.runInteraction("insert", insert_event, i)
# this should get the last five and five others
r = yield self.store.get_prev_events_for_room(room_id)
@ -93,9 +93,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
)
for i in range(0, 20):
yield self.store.runInteraction("insert", insert_event, i, room1)
yield self.store.runInteraction("insert", insert_event, i, room2)
yield self.store.runInteraction("insert", insert_event, i, room3)
yield self.store.db.runInteraction("insert", insert_event, i, room1)
yield self.store.db.runInteraction("insert", insert_event, i, room2)
yield self.store.db.runInteraction("insert", insert_event, i, room3)
# Test simple case
r = yield self.store.get_rooms_with_many_extremities(5, 5, [])

View file

@ -55,7 +55,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
counts = yield self.store.runInteraction(
counts = yield self.store.db.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
self.assertEquals(
@ -74,7 +74,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield self.store.add_push_actions_to_staging(
event.event_id, {user_id: action}
)
yield self.store.runInteraction(
yield self.store.db.runInteraction(
"",
self.store._set_push_actions_for_event_and_users_txn,
[(event, None)],
@ -82,12 +82,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
return self.store.runInteraction(
return self.store.db.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
def _mark_read(stream, depth):
return self.store.runInteraction(
return self.store.db.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
room_id,
@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
yield self.store.simple_delete(
yield self.store.db.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
return self.store.simple_insert(
return self.store.db.simple_insert(
"events",
{
"stream_ordering": so,

View file

@ -65,7 +65,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
self.store.runInteraction(
self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.pump()
@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.hs.config.mau_limits_reserved_threepids = threepids
self.store.runInteraction(
self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
count = self.store.get_monthly_active_count()
@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": user2_email},
]
self.hs.config.mau_limits_reserved_threepids = threepids
self.store.runInteraction(
self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)

View file

@ -338,7 +338,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
event_json = self.get_success(
self.store.simple_select_one_onecol(
self.store.db.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
@ -356,7 +356,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.reactor.advance(60 * 60 * 2)
event_json = self.get_success(
self.store.simple_select_one_onecol(
self.store.db.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",

View file

@ -122,8 +122,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def test_can_rerun_update(self):
# First make sure we have completed all updates.
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
@ -132,7 +136,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
self.store.simple_insert(
self.store.db.simple_insert(
table="background_updates",
values={
"update_name": "current_state_events_membership",
@ -143,8 +147,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
self.store._all_done = False
self.store.db.updates._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
while not self.get_success(
self.store.db.updates.has_completed_background_updates()
):
self.get_success(
self.store.db.updates.do_next_background_update(100), by=0.1
)

View file

@ -401,10 +401,12 @@ class HomeserverTestCase(TestCase):
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
# Run the database background updates.
if hasattr(stor, "do_next_background_update"):
while not self.get_success(stor.has_completed_background_updates()):
self.get_success(stor.do_next_background_update(1))
# Run the database background updates, when running against "master".
if hs.__class__.__name__ == "TestHomeServer":
while not self.get_success(
stor.db.updates.has_completed_background_updates()
):
self.get_success(stor.db.updates.do_next_background_update(1))
return hs
@ -544,7 +546,7 @@ class HomeserverTestCase(TestCase):
Add the given event as an extremity to the room.
"""
self.get_success(
self.hs.get_datastore().simple_insert(
self.hs.get_datastore().db.simple_insert(
table="event_forward_extremities",
values={"room_id": room_id, "event_id": event_id},
desc="test_add_extremity",