diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 70bc2bb2cc..f04aad0743 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -20,7 +20,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.expiringcache import ExpiringCache @@ -378,15 +378,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) else: if not devices: continue - sql = ( - "SELECT device_id FROM devices" - " WHERE user_id = ? AND device_id IN (" - + ",".join("?" * len(devices)) - + ")" + + clause, args = make_in_list_sql_clause( + txn.database_engine, "device_id", devices ) + sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause + # TODO: Maybe this needs to be done in batches if there are # too many local devices for a given user. - txn.execute(sql, [user_id] + devices) + txn.execute(sql, [user_id] + list(args)) for row in txn: # Only insert into the local inbox if the device exists on # this server diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 111bfb3d64..ac5239e50a 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -28,7 +28,12 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import Cache, SQLBaseStore, db_to_json +from synapse.storage._base import ( + Cache, + SQLBaseStore, + db_to_json, + make_in_list_sql_clause, +) from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList @@ -448,11 +453,14 @@ class DeviceWorkerStore(SQLBaseStore): sql = """ SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? - AND user_id IN (%s) + AND """ for chunk in batch_iter(to_check, 100): - txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk) + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", chunk + ) + txn.execute(sql + clause, (from_key,) + tuple(args)) changes.update(user_id for user_id, in txn) return changes diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index f5e8c39262..47cc10d32a 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -25,7 +25,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.signatures import SignatureWorkerStore from synapse.util.caches.descriptors import cached @@ -68,7 +68,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: results = set() - base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" + base_sql = "SELECT auth_id FROM event_auth WHERE " front = set(event_ids) while front: @@ -76,7 +76,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas front_list = list(front) chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)] for chunk in chunks: - txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk) + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", chunk + ) + txn.execute(base_sql + clause, list(args)) new_front.update([r[0] for r in txn]) new_front -= results diff --git a/synapse/storage/events.py b/synapse/storage/events.py index bb6ff0595a..ee49ef235d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -39,6 +39,7 @@ from synapse.logging.utils import log_function from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateResolutionStore +from synapse.storage._base import make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore @@ -641,14 +642,16 @@ class EventsStore( LEFT JOIN rejections USING (event_id) LEFT JOIN event_json USING (event_id) WHERE - prev_event_id IN (%s) - AND NOT events.outlier + NOT events.outlier AND rejections.event_id IS NULL - """ % ( - ",".join("?" for _ in batch), + AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "prev_event_id", batch ) - txn.execute(sql, batch) + txn.execute(sql + clause, args) results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): @@ -695,13 +698,15 @@ class EventsStore( LEFT JOIN rejections USING (event_id) LEFT JOIN event_json USING (event_id) WHERE - event_id IN (%s) - AND NOT events.outlier - """ % ( - ",".join("?" for _ in to_recursively_check), + NOT events.outlier + AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", to_recursively_check ) - txn.execute(sql, to_recursively_check) + txn.execute(sql + clause, args) to_recursively_check = [] for event_id, prev_event_id, metadata, rejected in txn: @@ -1543,10 +1548,14 @@ class EventsStore( " FROM events as e" " LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(ev_map)),) + " WHERE " + ) - txn.execute(sql, list(ev_map)) + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", list(ev_map) + ) + + txn.execute(sql + clause, args) rows = self.cursor_to_dict(txn) for row in rows: event = ev_map[row["event_id"]] @@ -2249,11 +2258,12 @@ class EventsStore( sql = """ SELECT DISTINCT state_group FROM event_to_state_groups LEFT JOIN events_to_purge AS ep USING (event_id) - WHERE state_group IN (%s) AND ep.event_id IS NULL - """ % ( - ",".join("?" for _ in current_search), + WHERE ep.event_id IS NULL AND + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "state_group", current_search ) - txn.execute(sql, list(current_search)) + txn.execute(sql + clause, list(args)) referenced = set(sg for sg, in txn) referenced_groups |= referenced diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py index 5717baf48c..97728a6da2 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/events_bg_updates.py @@ -21,6 +21,7 @@ from canonicaljson import json from twisted.internet import defer +from synapse.storage._base import make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) @@ -312,12 +313,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): INNER JOIN event_json USING (event_id) LEFT JOIN rejections USING (event_id) WHERE - prev_event_id IN (%s) - AND NOT events.outlier - """ % ( - ",".join("?" for _ in to_check), + NOT events.outlier + AND + """ + clause, args = make_in_list_sql_clause( + self.database_engine, "prev_event_id", to_check ) - txn.execute(sql, to_check) + txn.execute(sql + clause, list(args)) for prev_event_id, event_id, metadata, rejected in txn: if event_id in graph: diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 57ce0304e9..4c4b76bd93 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -31,12 +31,11 @@ from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.types import get_domain_from_id from synapse.util import batch_iter from synapse.util.metrics import Measure -from ._base import SQLBaseStore - logger = logging.getLogger(__name__) @@ -623,10 +622,14 @@ class EventsWorkerStore(SQLBaseStore): " rej.reason " " FROM event_json as e" " LEFT JOIN rejections as rej USING (event_id)" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(evs)),) + " WHERE " + ) - txn.execute(sql, evs) + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", evs + ) + + txn.execute(sql + clause, args) for row in txn: event_id = row[0] @@ -640,11 +643,11 @@ class EventsWorkerStore(SQLBaseStore): } # check for redactions - redactions_sql = ( - "SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)" - ) % (",".join(["?"] * len(evs)),) + redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " - txn.execute(redactions_sql, evs) + clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs) + + txn.execute(redactions_sql + clause, args) for (redacter, redacted) in txn: d = event_dict.get(redacted) @@ -753,10 +756,11 @@ class EventsWorkerStore(SQLBaseStore): results = set() def have_seen_events_txn(txn, chunk): - sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % ( - ",".join("?" * len(chunk)), + sql = "SELECT event_id FROM events as e WHERE " + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", chunk ) - txn.execute(sql, chunk) + txn.execute(sql + clause, args) for (event_id,) in txn: results.add(event_id) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 5db6f2d84a..3a641f538b 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -18,11 +18,10 @@ from collections import namedtuple from twisted.internet import defer from synapse.api.constants import PresenceState +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedList -from ._base import SQLBaseStore - class UserPresenceState( namedtuple( @@ -119,14 +118,13 @@ class PresenceStore(SQLBaseStore): ) # Delete old rows to stop database from getting really big - sql = ( - "DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)" - ) + sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " for states in batch_iter(presence_states, 50): - args = [stream_id] - args.extend(s.user_id for s in states) - txn.execute(sql % (",".join("?" for _ in states),), args) + clause, args = make_in_list_sql_clause( + self.database_engine, "user_id", [s.user_id for s in states] + ) + txn.execute(sql + clause, [stream_id] + list(args)) def get_all_presence_updates(self, last_id, current_id): if last_id == current_id: diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 290ddb30e8..0c24430f28 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -21,12 +21,11 @@ from canonicaljson import json from twisted.internet import defer +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import SQLBaseStore -from .util.id_generators import StreamIdGenerator - logger = logging.getLogger(__name__) @@ -217,24 +216,26 @@ class ReceiptsWorkerStore(SQLBaseStore): def f(txn): if from_key: - sql = ( - "SELECT * FROM receipts_linearized WHERE" - " room_id IN (%s) AND stream_id > ? AND stream_id <= ?" - ) % (",".join(["?"] * len(room_ids))) - args = list(room_ids) - args.extend([from_key, to_key]) + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id > ? AND stream_id <= ? AND + """ + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) - txn.execute(sql, args) + txn.execute(sql + clause, [from_key, to_key] + list(args)) else: - sql = ( - "SELECT * FROM receipts_linearized WHERE" - " room_id IN (%s) AND stream_id <= ?" - ) % (",".join(["?"] * len(room_ids))) + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id <= ? AND + """ - args = list(room_ids) - args.append(to_key) + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) - txn.execute(sql, args) + txn.execute(sql + clause, [to_key] + list(args)) return self.cursor_to_dict(txn) @@ -433,13 +434,19 @@ class ReceiptsStore(ReceiptsWorkerStore): # we need to points in graph -> linearized form. # TODO: Make this better. def graph_to_linear(txn): - query = ( - "SELECT event_id WHERE room_id = ? AND stream_ordering IN (" - " SELECT max(stream_ordering) WHERE event_id IN (%s)" - ")" - ) % (",".join(["?"] * len(event_ids))) + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", event_ids + ) - txn.execute(query, [room_id] + event_ids) + sql = """ + SELECT event_id WHERE room_id = ? AND stream_ordering IN ( + SELECT max(stream_ordering) WHERE %s + ) + """ % ( + clause, + ) + + txn.execute(sql, [room_id] + list(args)) rows = txn.fetchall() if rows: return rows[0][0] diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 4e606a8380..ff63487823 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction +from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.events_worker import EventsWorkerStore @@ -372,6 +372,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): results = [] if membership_list: if self._current_state_events_membership_up_to_date: + clause, args = make_in_list_sql_clause( + self.database_engine, "c.membership", membership_list + ) sql = """ SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering FROM current_state_events AS c @@ -379,11 +382,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): WHERE c.type = 'm.room.member' AND state_key = ? - AND c.membership IN (%s) + AND %s """ % ( - ",".join("?" * len(membership_list)) + clause, ) else: + clause, args = make_in_list_sql_clause( + self.database_engine, "m.membership", membership_list + ) sql = """ SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering FROM current_state_events AS c @@ -392,12 +398,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): WHERE c.type = 'm.room.member' AND state_key = ? - AND m.membership IN (%s) + AND %s """ % ( - ",".join("?" * len(membership_list)) + clause, ) - txn.execute(sql, (user_id, *membership_list)) + txn.execute(sql, (user_id, *args)) results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] if do_invite: diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 6ba4190f1a..4be6e56dfa 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -24,6 +24,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.storage._base import add_in_list_sql_clause from synapse.storage.engines import PostgresEngine, Sqlite3Engine from .background_updates import BackgroundUpdateStore @@ -385,8 +386,9 @@ class SearchStore(SearchBackgroundUpdateStore): # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. if len(room_ids) < 500: - clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) - args.extend(room_ids) + add_in_list_sql_clause( + self.database_engine, "room_id", room_ids, clauses, args + ) local_clauses = [] for key in keys: @@ -492,8 +494,9 @@ class SearchStore(SearchBackgroundUpdateStore): # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. if len(room_ids) < 500: - clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) - args.extend(room_ids) + add_in_list_sql_clause( + self.database_engine, "room_id", room_ids, clauses, args + ) local_clauses = [] for key in keys: diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py index 05cabc2282..aa4f0da5f0 100644 --- a/synapse/storage/user_erasure_store.py +++ b/synapse/storage/user_erasure_store.py @@ -56,15 +56,15 @@ class UserErasureWorkerStore(SQLBaseStore): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - def _get_erased_users(txn): - txn.execute( - "SELECT user_id FROM erased_users WHERE user_id IN (%s)" - % (",".join("?" * len(user_ids))), - user_ids, - ) - return set(r[0] for r in txn) + rows = yield self._simple_select_many_batch( + table="erased_users", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="are_users_erased", + ) + erased_users = set(row["user_id"] for row in rows) - erased_users = yield self.runInteraction("are_users_erased", _get_erased_users) res = dict((u, u in erased_users) for u in user_ids) return res