diff --git a/changelog.d/9035.doc b/changelog.d/9035.doc new file mode 100644 index 0000000000..2a7f0db518 --- /dev/null +++ b/changelog.d/9035.doc @@ -0,0 +1 @@ +Corrected a typo in the `systemd-with-workers` documentation. diff --git a/changelog.d/9041.misc b/changelog.d/9041.misc new file mode 100644 index 0000000000..4952fbe8a2 --- /dev/null +++ b/changelog.d/9041.misc @@ -0,0 +1 @@ +Various cleanups to device inbox store. diff --git a/changelog.d/9042.feature b/changelog.d/9042.feature new file mode 100644 index 0000000000..4ec319f1f2 --- /dev/null +++ b/changelog.d/9042.feature @@ -0,0 +1 @@ +Add experimental support for handling and persistence of to-device messages to happen on worker processes. diff --git a/changelog.d/9043.feature b/changelog.d/9043.feature new file mode 100644 index 0000000000..4ec319f1f2 --- /dev/null +++ b/changelog.d/9043.feature @@ -0,0 +1 @@ +Add experimental support for handling and persistence of to-device messages to happen on worker processes. diff --git a/changelog.d/9044.feature b/changelog.d/9044.feature new file mode 100644 index 0000000000..4ec319f1f2 --- /dev/null +++ b/changelog.d/9044.feature @@ -0,0 +1 @@ +Add experimental support for handling and persistence of to-device messages to happen on worker processes. diff --git a/docs/systemd-with-workers/README.md b/docs/systemd-with-workers/README.md index 8e57d4f62e..cfa36be7b4 100644 --- a/docs/systemd-with-workers/README.md +++ b/docs/systemd-with-workers/README.md @@ -31,7 +31,7 @@ There is no need for a separate configuration file for the master process. 1. Adjust synapse configuration files as above. 1. Copy the `*.service` and `*.target` files in [system](system) to `/etc/systemd/system`. -1. Run `systemctl deamon-reload` to tell systemd to load the new unit files. +1. Run `systemctl daemon-reload` to tell systemd to load the new unit files. 1. Run `systemctl enable matrix-synapse.service`. This will configure the synapse master process to be started as part of the `matrix-synapse.target` target. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 5ad17aa90f..22dd169bfb 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -629,6 +629,7 @@ class Porter(object): await self._setup_state_group_id_seq() await self._setup_user_id_seq() await self._setup_events_stream_seqs() + await self._setup_device_inbox_seq() # Step 3. Get tables. self.progress.set_state("Fetching tables") @@ -911,6 +912,32 @@ class Porter(object): "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos, ) + async def _setup_device_inbox_seq(self): + """Set the device inbox sequence to the correct value. + """ + curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="device_inbox", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), 1)", + allow_none=True, + ) + + curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="device_federation_outbox", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), 1)", + allow_none=True, + ) + + next_id = max(curr_local_id, curr_federation_id) + 1 + + def r(txn): + txn.execute( + "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,) + ) + + return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r) + ############################################## # The following is simply UI stuff diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index fa23d9bb20..4428472707 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -108,6 +108,7 @@ from synapse.rest.client.v2_alpha.account_data import ( ) from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet +from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource @@ -520,6 +521,8 @@ class GenericWorkerServer(HomeServer): room.register_deprecated_servlets(self, resource) InitialSyncRestServlet(self).register(resource) + SendToDeviceRestServlet(self).register(resource) + user_directory.register_servlets(self, resource) # If presence is disabled, use the stub servlet that does diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 7ca9efec52..364583f48b 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -53,6 +53,9 @@ class WriterLocations: default=["master"], type=List[str], converter=_instance_to_list_converter ) typing = attr.ib(default="master", type=str) + to_device = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter, + ) class WorkerConfig(Config): @@ -124,7 +127,7 @@ class WorkerConfig(Config): # Check that the configured writers for events and typing also appears in # `instance_map`. - for stream in ("events", "typing"): + for stream in ("events", "typing", "to_device"): instances = _instance_to_list_converter(getattr(self.writers, stream)) for instance in instances: if instance != "master" and instance not in self.instance_map: @@ -133,6 +136,11 @@ class WorkerConfig(Config): % (instance, stream) ) + if len(self.writers.to_device) != 1: + raise ConfigError( + "Must only specify one instance to handle `to_device` messages." + ) + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) # Whether this worker should run background tasks or not. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 00a1738e7c..30970f4ad3 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random from typing import ( TYPE_CHECKING, Any, @@ -860,8 +861,10 @@ class FederationHandlerRegistry: ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] - # Map from type to instance name that we should route EDU handling to. - self._edu_type_to_instance = {} # type: Dict[str, str] + # Map from type to instance names that we should route EDU handling to. + # We randomly choose one instance from the list to route to for each new + # EDU received. + self._edu_type_to_instance = {} # type: Dict[str, List[str]] def register_edu_handler( self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] @@ -905,7 +908,12 @@ class FederationHandlerRegistry: def register_instance_for_edu(self, edu_type: str, instance_name: str): """Register that the EDU handler is on a different instance than master. """ - self._edu_type_to_instance[edu_type] = instance_name + self._edu_type_to_instance[edu_type] = [instance_name] + + def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): + """Register that the EDU handler is on multiple instances. + """ + self._edu_type_to_instance[edu_type] = instance_names async def on_edu(self, edu_type: str, origin: str, content: dict): if not self.config.use_presence and edu_type == "m.presence": @@ -928,8 +936,11 @@ class FederationHandlerRegistry: return # Check if we can route it somewhere else that isn't us - route_to = self._edu_type_to_instance.get(edu_type, "master") - if route_to != self._instance_name: + instances = self._edu_type_to_instance.get(edu_type, ["master"]) + if self._instance_name not in instances: + # Pick an instance randomly so that we don't overload one. + route_to = random.choice(instances) + try: await self._send_edu( instance_name=route_to, diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 5aa56013a4..109dc7932f 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -24,6 +24,7 @@ from synapse.logging.opentracing import ( set_tag, start_active_span, ) +from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util import json_encoder from synapse.util.stringutils import random_string @@ -44,13 +45,37 @@ class DeviceMessageHandler: self.store = hs.get_datastore() self.notifier = hs.get_notifier() self.is_mine = hs.is_mine - self.federation = hs.get_federation_sender() - hs.get_federation_registry().register_edu_handler( - "m.direct_to_device", self.on_direct_to_device_edu - ) + # We only need to poke the federation sender explicitly if its on the + # same instance. Other federation sender instances will get notified by + # `synapse.app.generic_worker.FederationSenderHandler` when it sees it + # in the to-device replication stream. + self.federation_sender = None + if hs.should_send_federation(): + self.federation_sender = hs.get_federation_sender() - self._device_list_updater = hs.get_device_handler().device_list_updater + # If we can handle the to device EDUs we do so, otherwise we route them + # to the appropriate worker. + if hs.get_instance_name() in hs.config.worker.writers.to_device: + hs.get_federation_registry().register_edu_handler( + "m.direct_to_device", self.on_direct_to_device_edu + ) + else: + hs.get_federation_registry().register_instances_for_edu( + "m.direct_to_device", hs.config.worker.writers.to_device, + ) + + # The handler to call when we think a user's device list might be out of + # sync. We do all device list resyncing on the master instance, so if + # we're on a worker we hit the device resync replication API. + if hs.config.worker.worker_app is None: + self._user_device_resync = ( + hs.get_device_handler().device_list_updater.user_device_resync + ) + else: + self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( + hs + ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: local_messages = {} @@ -138,9 +163,7 @@ class DeviceMessageHandler: await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) # Immediately attempt a resync in the background - run_in_background( - self._device_list_updater.user_device_resync, sender_user_id - ) + run_in_background(self._user_device_resync, sender_user_id) async def send_device_message( self, @@ -199,7 +222,8 @@ class DeviceMessageHandler: ) log_kv({"remote_messages": remote_messages}) - for destination in remote_messages.keys(): - # Enqueue a new federation transaction to send the new - # device messages to each remote destination. - self.federation.send_device_messages(destination) + if self.federation_sender: + for destination in remote_messages.keys(): + # Enqueue a new federation transaction to send the new + # device messages to each remote destination. + self.federation_sender.send_device_messages(destination) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 5b045bed02..1260f6d141 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -14,46 +14,8 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.replication.tcp.streams import ToDeviceStream -from synapse.storage.database import DatabasePool from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore -from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_inbox", "stream_id" - ) - self._device_inbox_stream_cache = StreamChangeCache( - "DeviceInboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token(), - ) - self._device_federation_outbox_stream_cache = StreamChangeCache( - "DeviceFederationOutboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token(), - ) - - self._last_device_delete_cache = ExpiringCache( - cache_name="last_device_delete_cache", - clock=self._clock, - max_len=10000, - expiry_ms=30 * 60 * 1000, - ) - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == ToDeviceStream.NAME: - self._device_inbox_id_gen.advance(instance_name, token) - for row in rows: - if row.entity.startswith("@"): - self._device_inbox_stream_cache.entity_has_changed( - row.entity, token - ) - else: - self._device_federation_outbox_stream_cache.entity_has_changed( - row.entity, token - ) - return super().process_replication_rows(stream_name, instance_name, token, rows) + pass diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 95e5502bf2..1f89249475 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -56,6 +56,7 @@ from synapse.replication.tcp.streams import ( EventsStream, FederationStream, Stream, + ToDeviceStream, TypingStream, ) @@ -115,6 +116,14 @@ class ReplicationCommandHandler: continue + if isinstance(stream, ToDeviceStream): + # Only add ToDeviceStream as a source on instances in charge of + # sending to device messages. + if hs.get_instance_name() in hs.config.worker.writers.to_device: + self._streams_to_replicate.append(stream) + + continue + if isinstance(stream, TypingStream): # Only add TypingStream as a source on the instance in charge of # typing. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 701748f93b..c4de07a0a8 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -127,9 +127,6 @@ class DataStore( self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) - self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_inbox", "stream_id" - ) self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) @@ -189,36 +186,6 @@ class DataStore( prefilled_cache=presence_cache_prefill, ) - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( - db_conn, - "device_inbox", - entity_column="user_id", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_inbox_stream_cache = StreamChangeCache( - "DeviceInboxStreamChangeCache", - min_device_inbox_id, - prefilled_cache=device_inbox_prefill, - ) - # The federation outbox and the local device inbox uses the same - # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( - db_conn, - "device_federation_outbox", - entity_column="destination", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_federation_outbox_stream_cache = StreamChangeCache( - "DeviceFederationOutboxStreamChangeCache", - min_device_outbox_id, - prefilled_cache=device_outbox_prefill, - ) - device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index d42faa3f1f..58d3f71e45 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -17,15 +17,100 @@ import logging from typing import List, Tuple from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.replication.tcp.streams import ToDeviceStream +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) class DeviceInboxWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._instance_name = hs.get_instance_name() + + # Map of (user_id, device_id) to the last stream_id that has been + # deleted up to. This is so that we can no op deletions. + self._last_device_delete_cache = ExpiringCache( + cache_name="last_device_delete_cache", + clock=self._clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) + + if isinstance(database.engine, PostgresEngine): + self._can_write_to_device = ( + self._instance_name in hs.config.worker.writers.to_device + ) + + self._device_inbox_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="to_device", + instance_name=self._instance_name, + table="device_inbox", + instance_column="instance_name", + id_column="stream_id", + sequence_name="device_inbox_sequence", + writers=hs.config.worker.writers.to_device, + ) + else: + self._can_write_to_device = True + self._device_inbox_id_gen = StreamIdGenerator( + db_conn, "device_inbox", "stream_id" + ) + + max_device_inbox_id = self._device_inbox_id_gen.get_current_token() + device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( + db_conn, + "device_inbox", + entity_column="user_id", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_inbox_stream_cache = StreamChangeCache( + "DeviceInboxStreamChangeCache", + min_device_inbox_id, + prefilled_cache=device_inbox_prefill, + ) + + # The federation outbox and the local device inbox uses the same + # stream_id generator. + device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( + db_conn, + "device_federation_outbox", + entity_column="destination", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", + min_device_outbox_id, + prefilled_cache=device_outbox_prefill, + ) + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) + for row in rows: + if row.entity.startswith("@"): + self._device_inbox_stream_cache.entity_has_changed( + row.entity, token + ) + else: + self._device_federation_outbox_stream_cache.entity_has_changed( + row.entity, token + ) + return super().process_replication_rows(stream_name, instance_name, token, rows) + def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() @@ -278,52 +363,6 @@ class DeviceInboxWorkerStore(SQLBaseStore): "get_all_new_device_messages", get_all_new_device_messages_txn ) - -class DeviceInboxBackgroundUpdateStore(SQLBaseStore): - DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self.db_pool.updates.register_background_index_update( - "device_inbox_stream_index", - index_name="device_inbox_stream_id_user_id", - table="device_inbox", - columns=["stream_id", "user_id"], - ) - - self.db_pool.updates.register_background_update_handler( - self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox - ) - - async def _background_drop_index_device_inbox(self, progress, batch_size): - def reindex_txn(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") - txn.close() - - await self.db_pool.runWithConnection(reindex_txn) - - await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) - - return 1 - - -class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): - DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - # Map of (user_id, device_id) to the last stream_id that has been - # deleted up to. This is so that we can no op deletions. - self._last_device_delete_cache = ExpiringCache( - cache_name="last_device_delete_cache", - clock=self._clock, - max_len=10000, - expiry_ms=30 * 60 * 1000, - ) - @trace async def add_messages_to_device_inbox( self, @@ -342,6 +381,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) The new stream_id. """ + assert self._can_write_to_device + def add_messages_txn(txn, now_ms, stream_id): # Add the local messages directly to the local inbox. self._add_messages_to_local_device_inbox_txn( @@ -351,16 +392,20 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Add the remote messages to the federation outbox. # We'll send them to a remote server when we next send a # federation transaction to that destination. - sql = ( - "INSERT INTO device_federation_outbox" - " (destination, stream_id, queued_ts, messages_json)" - " VALUES (?,?,?,?)" + self.db_pool.simple_insert_many_txn( + txn, + table="device_federation_outbox", + values=[ + { + "destination": destination, + "stream_id": stream_id, + "queued_ts": now_ms, + "messages_json": json_encoder.encode(edu), + "instance_name": self._instance_name, + } + for destination, edu in remote_messages_by_destination.items() + ], ) - rows = [] - for destination, edu in remote_messages_by_destination.items(): - edu_json = json_encoder.encode(edu) - rows.append((destination, stream_id, now_ms, edu_json)) - txn.executemany(sql, rows) async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() @@ -379,6 +424,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) async def add_messages_from_remote_to_device_inbox( self, origin: str, message_id: str, local_messages_by_user_then_device: dict ) -> int: + assert self._can_write_to_device + def add_messages_txn(txn, now_ms, stream_id): # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our @@ -427,38 +474,45 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device ): + assert self._can_write_to_device + local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. - sql = "SELECT device_id FROM devices WHERE user_id = ?" - txn.execute(sql, (user_id,)) + devices = self.db_pool.simple_select_onecol_txn( + txn, + table="devices", + keyvalues={"user_id": user_id}, + retcol="device_id", + ) + message_json = json_encoder.encode(messages_by_device["*"]) - for row in txn: + for device_id in devices: # Add the message for all devices for this user on this # server. - device = row[0] - messages_json_for_user[device] = message_json + messages_json_for_user[device_id] = message_json else: if not devices: continue - clause, args = make_in_list_sql_clause( - txn.database_engine, "device_id", devices + rows = self.db_pool.simple_select_many_txn( + txn, + table="devices", + keyvalues={"user_id": user_id}, + column="device_id", + iterable=devices, + retcols=("device_id",), ) - sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause - # TODO: Maybe this needs to be done in batches if there are - # too many local devices for a given user. - txn.execute(sql, [user_id] + list(args)) - for row in txn: + for row in rows: # Only insert into the local inbox if the device exists on # this server - device = row[0] - message_json = json_encoder.encode(messages_by_device[device]) - messages_json_for_user[device] = message_json + device_id = row["device_id"] + message_json = json_encoder.encode(messages_by_device[device_id]) + messages_json_for_user[device_id] = message_json if messages_json_for_user: local_by_user_then_device[user_id] = messages_json_for_user @@ -466,14 +520,52 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) if not local_by_user_then_device: return - sql = ( - "INSERT INTO device_inbox" - " (user_id, device_id, stream_id, message_json)" - " VALUES (?,?,?,?)" + self.db_pool.simple_insert_many_txn( + txn, + table="device_inbox", + values=[ + { + "user_id": user_id, + "device_id": device_id, + "stream_id": stream_id, + "message_json": message_json, + "instance_name": self._instance_name, + } + for user_id, messages_by_device in local_by_user_then_device.items() + for device_id, message_json in messages_by_device.items() + ], ) - rows = [] - for user_id, messages_by_device in local_by_user_then_device.items(): - for device_id, message_json in messages_by_device.items(): - rows.append((user_id, device_id, stream_id, message_json)) - txn.executemany(sql, rows) + +class DeviceInboxBackgroundUpdateStore(SQLBaseStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "device_inbox_stream_index", + index_name="device_inbox_stream_id_user_id", + table="device_inbox", + columns=["stream_id", "user_id"], + ) + + self.db_pool.updates.register_background_update_handler( + self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox + ) + + async def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") + txn.close() + + await self.db_pool.runWithConnection(reindex_txn) + + await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + return 1 + + +class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): + pass diff --git a/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql new file mode 100644 index 0000000000..d781a92fec --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql @@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE device_inbox ADD COLUMN instance_name TEXT; +ALTER TABLE device_federation_inbox ADD COLUMN instance_name TEXT; +ALTER TABLE device_federation_outbox ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres new file mode 100644 index 0000000000..45a845a3a5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres @@ -0,0 +1,25 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS device_inbox_sequence; + +-- We need to take the max across both device_inbox and device_federation_outbox +-- tables as they share the ID generator +SELECT setval('device_inbox_sequence', ( + SELECT GREATEST( + (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox), + (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox) + ) +));