diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index f0353929da..8e17800364 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -436,15 +436,27 @@ class DeviceStore(SQLBaseStore): ) def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + # First we DELETE all rows such that only the latest row for each + # (destination, user_id) is left. We do this by selecting first and + # deleting. + sql = """ + SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + GROUP BY user_id + HAVING count(*) > 1 + """ + txn.execute(sql, (destination, stream_id,)) + rows = txn.fetchall() + sql = """ DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id < ( - SELECT coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id <= ? - ) + WHERE destination = ? AND user_id = ? AND stream_id < ? """ - txn.execute(sql, (destination, destination, stream_id,)) + txn.executemany( + sql, ((destination, row[0], row[1],) for row in rows) + ) + # Mark everything that is left as sent sql = """ UPDATE device_lists_outbound_pokes SET sent = ? WHERE destination = ? AND stream_id <= ? @@ -545,18 +557,22 @@ class DeviceStore(SQLBaseStore): (destination, user_id) tuple to ensure that the prev_ids remain correct if the server does come back. """ - now = self._clock.time_msec() + yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 def _prune_txn(txn): select_sql = """ SELECT destination, user_id, max(stream_id) as stream_id FROM device_lists_outbound_pokes GROUP BY destination, user_id + HAVING min(ts) < ? AND count(*) > 1 """ - txn.execute(select_sql) + txn.execute(select_sql, (yesterday,)) rows = txn.fetchall() + if not rows: + return + delete_sql = """ DELETE FROM device_lists_outbound_pokes WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? 
@@ -565,11 +581,13 @@ class DeviceStore(SQLBaseStore): txn.executemany( delete_sql, ( - (now, row["destination"], row["user_id"], row["stream_id"]) + (yesterday, row[0], row[1], row[2]) for row in rows ) ) + logger.info("Pruned %d device list outbound pokes", txn.rowcount) + return self.runInteraction( "_prune_old_outbound_device_pokes", _prune_txn )