diff --git a/changelog.d/12252.feature b/changelog.d/12252.feature new file mode 100644 index 0000000000..82b9e82f86 --- /dev/null +++ b/changelog.d/12252.feature @@ -0,0 +1 @@ +Move `update_client_ip` background job from the main process to the background worker. \ No newline at end of file diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 12750d9b89..5eb545c86e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1268,6 +1268,7 @@ class DatabasePool: value_names: Collection[str], value_values: Collection[Collection[Any]], desc: str, + lock: bool = True, ) -> None: """ Upsert, many times. @@ -1279,6 +1280,8 @@ class DatabasePool: value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. + lock: True to lock the table when doing the upsert. Unused if the database engine + supports native upserts. """ # We can autocommit if we are going to use native upserts @@ -1286,7 +1289,7 @@ class DatabasePool: self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables ) - return await self.runInteraction( + await self.runInteraction( desc, self.simple_upsert_many_txn, table, @@ -1294,6 +1297,7 @@ class DatabasePool: key_values, value_names, value_values, + lock=lock, db_autocommit=autocommit, ) @@ -1305,6 +1309,7 @@ class DatabasePool: key_values: Collection[Iterable[Any]], value_names: Collection[str], value_values: Iterable[Iterable[Any]], + lock: bool = True, ) -> None: """ Upsert, many times. @@ -1316,6 +1321,8 @@ class DatabasePool: value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. + lock: True to lock the table when doing the upsert. Unused if the database engine + supports native upserts. """ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( @@ -1323,7 +1330,7 @@ class DatabasePool: ) else: return self.simple_upsert_many_txn_emulated( - txn, table, key_names, key_values, value_names, value_values + txn, table, key_names, key_values, value_names, value_values, lock=lock ) def simple_upsert_many_txn_emulated( @@ -1334,6 +1341,7 @@ class DatabasePool: key_values: Collection[Iterable[Any]], value_names: Collection[str], value_values: Iterable[Iterable[Any]], + lock: bool = True, ) -> None: """ Upsert, many times, but without native UPSERT support or batching. @@ -1345,17 +1353,24 @@ class DatabasePool: value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. + lock: True to lock the table when doing the upsert. """ # No value columns, therefore make a blank list so that the following # zip() works correctly. if not value_names: value_values = [() for x in range(len(key_values))] + if lock: + # Lock the table just once, to prevent it being done once per row. + # Note that, according to Postgres' documentation, once obtained, + # the lock is held for the remainder of the current transaction. + self.engine.lock_table(txn, "user_ips") + for keyv, valv in zip(key_values, value_values): _keys = {x: y for x, y in zip(key_names, keyv)} _vals = {x: y for x, y in zip(value_names, valv)} - self.simple_upsert_txn_emulated(txn, table, _keys, _vals) + self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False) def simple_upsert_many_txn_native_upsert( self, @@ -1792,6 +1807,86 @@ class DatabasePool: return txn.rowcount + async def simple_update_many( + self, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[Any]], + desc: str, + ) -> None: + """ + Update, many times, using batching where possible. + If the keys don't match anything, nothing will be updated. + + Args: + table: The table to update + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The names of value columns to update. + value_values: A list of each row's value column values. + """ + + await self.runInteraction( + desc, + self.simple_update_many_txn, + table, + key_names, + key_values, + value_names, + value_values, + ) + + @staticmethod + def simple_update_many_txn( + txn: LoggingTransaction, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Collection[Iterable[Any]], + ) -> None: + """ + Update, many times, using batching where possible. + If the keys don't match anything, nothing will be updated. + + Args: + table: The table to update + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The names of value columns to update. + value_values: A list of each row's value column values. + """ + + if len(value_values) != len(key_values): + raise ValueError( + f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number." + ) + + # List of tuples of (value values, then key values) + # (This matches the order needed for the query) + args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)] + + for ks, vs in zip(key_values, value_values): + args.append(tuple(vs) + tuple(ks)) + + # 'col1 = ?, col2 = ?, ...' + set_clause = ", ".join(f"{n} = ?" for n in value_names) + + if key_names: + # 'WHERE col3 = ? AND col4 = ? AND col5 = ?' + where_clause = "WHERE " + (" AND ".join(f"{n} = ?" for n in key_names)) + else: + where_clause = "" + + # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ? + sql = f""" + UPDATE {table} SET {set_clause} {where_clause} + """ + + txn.execute_batch(sql, args) + async def simple_update_one( self, table: str, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 8480ea4e1c..0df160d2b0 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -616,9 +616,10 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke to_update = self._batch_row_update self._batch_row_update = {} - await self.db_pool.runInteraction( - "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update - ) + if to_update: + await self.db_pool.runInteraction( + "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update + ) def _update_client_ips_batch_txn( self, @@ -629,42 +630,43 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke self._update_on_this_worker ), "This worker is not designated to update client IPs" - if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( - not self.database_engine.can_native_upsert - ): - self.database_engine.lock_table(txn, "user_ips") + # Keys and values for the `user_ips` upsert. + user_ips_keys = [] + user_ips_values = [] + + # Keys and values for the `devices` update. + devices_keys = [] + devices_values = [] for entry in to_update.items(): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry - - self.db_pool.simple_upsert_txn( - txn, - table="user_ips", - keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip}, - values={ - "user_agent": user_agent, - "device_id": device_id, - "last_seen": last_seen, - }, - lock=False, - ) + user_ips_keys.append((user_id, access_token, ip)) + user_ips_values.append((user_agent, device_id, last_seen)) # Technically an access token might not be associated with # a device so we need to check. if device_id: - # this is always an update rather than an upsert: the row should - # already exist, and if it doesn't, that may be because it has been - # deleted, and we don't want to re-create it. - self.db_pool.simple_update_txn( - txn, - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - updatevalues={ - "user_agent": user_agent, - "last_seen": last_seen, - "ip": ip, - }, - ) + devices_keys.append((user_id, device_id)) + devices_values.append((user_agent, last_seen, ip)) + + self.db_pool.simple_upsert_many_txn( + txn, + table="user_ips", + key_names=("user_id", "access_token", "ip"), + key_values=user_ips_keys, + value_names=("user_agent", "device_id", "last_seen"), + value_values=user_ips_values, + ) + + if devices_values: + self.db_pool.simple_update_many_txn( + txn, + table="devices", + key_names=("user_id", "device_id"), + key_values=devices_keys, + value_names=("user_agent", "last_seen", "ip"), + value_values=devices_values, + ) async def get_last_client_ip_by_device( self, user_id: str, device_id: Optional[str] diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 366398e39d..09cb06d614 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -14,7 +14,7 @@ # limitations under the License. import secrets -from typing import Any, Dict, Generator, List, Tuple +from typing import Generator, Tuple from twisted.test.proto_helpers import MemoryReactor @@ -24,7 +24,7 @@ from synapse.util import Clock from tests import unittest -class UpsertManyTests(unittest.HomeserverTestCase): +class UpdateUpsertManyTests(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.storage = hs.get_datastores().main @@ -46,9 +46,13 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) - def _dump_to_tuple( - self, res: List[Dict[str, Any]] - ) -> Generator[Tuple[int, str, str], None, None]: + def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]: + res = self.get_success( + self.storage.db_pool.simple_select_list( + self.table_name, None, ["id, username, value"] + ) + ) + for i in res: yield (i["id"], i["username"], i["value"]) @@ -75,13 +79,8 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) # Check results are what we expect - res = self.get_success( - self.storage.db_pool.simple_select_list( - self.table_name, None, ["id, username, value"] - ) - ) self.assertEqual( - set(self._dump_to_tuple(res)), + set(self._dump_table_to_tuple()), {(1, "user1", "hello"), (2, "user2", "there")}, ) @@ -102,12 +101,54 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) # Check results are what we expect - res = self.get_success( - self.storage.db_pool.simple_select_list( - self.table_name, None, ["id, username, value"] - ) - ) self.assertEqual( - set(self._dump_to_tuple(res)), + set(self._dump_table_to_tuple()), {(1, "user1", "hello"), (2, "user2", "bleb")}, ) + + def test_simple_update_many(self): + """ + simple_update_many performs many updates at once. + """ + # First add some data. + self.get_success( + self.storage.db_pool.simple_insert_many( + table=self.table_name, + keys=("id", "username", "value"), + values=[(1, "alice", "A"), (2, "bob", "B"), (3, "charlie", "C")], + desc="insert", + ) + ) + + # Check the data made it to the table + self.assertEqual( + set(self._dump_table_to_tuple()), + {(1, "alice", "A"), (2, "bob", "B"), (3, "charlie", "C")}, + ) + + # Now use simple_update_many + self.get_success( + self.storage.db_pool.simple_update_many( + table=self.table_name, + key_names=("username",), + key_values=( + ("alice",), + ("bob",), + ("stranger",), + ), + value_names=("value",), + value_values=( + ("aaa!",), + ("bbb!",), + ("???",), + ), + desc="update_many1", + ) + ) + + # Check the table is how we expect: + # charlie has been left alone + self.assertEqual( + set(self._dump_table_to_tuple()), + {(1, "alice", "aaa!"), (2, "bob", "bbb!"), (3, "charlie", "C")}, + )