diff --git a/changelog.d/13005.misc b/changelog.d/13005.misc new file mode 100644 index 0000000000..3bb51962e7 --- /dev/null +++ b/changelog.d/13005.misc @@ -0,0 +1 @@ +Reduce DB usage of `/sync` when a large number of unread messages have recently been sent in a room. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 9586086c03..9c06c837dc 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -58,6 +58,9 @@ from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateSt from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore +from synapse.storage.databases.main.event_push_actions import ( + EventPushActionsWorkerStore, +) from synapse.storage.databases.main.events_bg_updates import ( EventsBackgroundUpdatesStore, ) @@ -199,6 +202,7 @@ R = TypeVar("R") class Store( + EventPushActionsWorkerStore, ClientIpBackgroundUpdateStore, DeviceInboxBackgroundUpdateStore, DeviceBackgroundUpdateStore, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index af19c513be..6ad053f678 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup import attr from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership, ReceiptTypes +from synapse.api.constants import EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -1054,14 +1054,10 @@ class SyncHandler: self, room_id: str, sync_config: SyncConfig ) -> NotifCounts: with Measure(self.clock, "unread_notifs_for_room_id"): - last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( - user_id=sync_config.user.to_string(), - room_id=room_id, - receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), - ) return await self.store.get_unread_event_push_actions_by_room_for_user( - room_id, sync_config.user.to_string(), last_unread_event_id + room_id, + sync_config.user.to_string(), ) async def generate_sync_result( diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 8397229ccb..6661887d9f 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Dict -from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage.controllers import StorageControllers @@ -24,30 +23,24 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user( - user_id, (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) - ) - badge = len(invites) for room_id in joins: - if room_id in my_receipts_by_room: - last_unread_event_id = my_receipts_by_room[room_id] - - notifs = await ( - store.get_unread_event_push_actions_by_room_for_user( - room_id, user_id, last_unread_event_id - ) + notifs = await ( + store.get_unread_event_push_actions_by_room_for_user( + room_id, + user_id, ) - if notifs.notify_count == 0: - continue + ) + if notifs.notify_count == 0: + continue - if group_by_room: - # return one badge count per conversation - badge += 1 - else: - # increment the badge count by the number of unread messages in the room - badge += notifs.notify_count + if group_by_room: + # return one badge count per conversation + badge += 1 + else: + # increment the badge count by the number of unread messages in the room + badge += notifs.notify_count return badge diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a78d68a9d7..e8c63cf567 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -92,6 +92,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "event_search": "event_search_event_id_idx", "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", + "event_push_summary": "event_push_summary_unique_index", } diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 9121badb3a..cb3d1242bb 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -104,13 +104,14 @@ class DataStore( PusherStore, PushRuleStore, ApplicationServiceTransactionStore, + EventPushActionsStore, + ServerMetricsStore, ReceiptsStore, EndToEndKeyStore, EndToEndRoomKeyStore, SearchStore, TagsStore, AccountDataStore, - EventPushActionsStore, OpenIdStore, ClientIpWorkerStore, DeviceStore, @@ -124,7 +125,6 @@ class DataStore( UIAuthStore, EventForwardExtremitiesStore, CacheInvalidationWorkerStore, - ServerMetricsStore, LockStore, SessionStore, ): diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index b019979350..ae705889a5 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast import attr +from synapse.api.constants import ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -24,6 +25,8 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -79,15 +82,15 @@ class UserPushAction(EmailPushAction): profile_tag: str -@attr.s(slots=True, frozen=True, auto_attribs=True) +@attr.s(slots=True, auto_attribs=True) class NotifCounts: """ The per-user, per-room count of notifications. Used by sync and push. """ - notify_count: int - unread_count: int - highlight_count: int + notify_count: int = 0 + unread_count: int = 0 + highlight_count: int = 0 def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str: @@ -119,7 +122,7 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, st return DEFAULT_NOTIF_ACTION -class EventPushActionsWorkerStore(SQLBaseStore): +class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore): def __init__( self, database: DatabasePool, @@ -148,12 +151,20 @@ class EventPushActionsWorkerStore(SQLBaseStore): self._rotate_notifs, 30 * 60 * 1000 ) - @cached(num_args=3, tree=True, max_entries=5000) + self.db_pool.updates.register_background_index_update( + "event_push_summary_unique_index", + index_name="event_push_summary_unique_index", + table="event_push_summary", + columns=["user_id", "room_id"], + unique=True, + replaces_index="event_push_summary_user_rm", + ) + + @cached(tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( self, room_id: str, user_id: str, - last_read_event_id: Optional[str], ) -> NotifCounts: """Get the notification count, the highlight count and the unread message count for a given user in a given room after the given read receipt. @@ -165,8 +176,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): Args: room_id: The room to retrieve the counts in. user_id: The user to retrieve the counts for. - last_read_event_id: The event associated with the latest read receipt for - this user in this room. None if no receipt for this user in this room. Returns A dict containing the counts mentioned earlier in this docstring, @@ -178,7 +187,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): self._get_unread_counts_by_receipt_txn, room_id, user_id, - last_read_event_id, ) def _get_unread_counts_by_receipt_txn( @@ -186,16 +194,17 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn: LoggingTransaction, room_id: str, user_id: str, - last_read_event_id: Optional[str], ) -> NotifCounts: - stream_ordering = None + result = self.get_last_receipt_for_user_txn( + txn, + user_id, + room_id, + receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), + ) - if last_read_event_id is not None: - stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined] - txn, - last_read_event_id, - allow_none=True, - ) + stream_ordering = None + if result: + _, stream_ordering = result if stream_ordering is None: # Either last_read_event_id is None, or it's an event we don't have (e.g. @@ -218,49 +227,95 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _get_unread_counts_by_pos_txn( self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int ) -> NotifCounts: - sql = ( - "SELECT" - " COUNT(CASE WHEN notif = 1 THEN 1 END)," - " COUNT(CASE WHEN highlight = 1 THEN 1 END)," - " COUNT(CASE WHEN unread = 1 THEN 1 END)" - " FROM event_push_actions ea" - " WHERE user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" - ) + """Get the number of unread messages for a user/room that have happened + since the given stream ordering. + """ - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - - (notif_count, highlight_count, unread_count) = (0, 0, 0) - - if row: - (notif_count, highlight_count, unread_count) = row + counts = NotifCounts() + # First we pull the counts from the summary table txn.execute( """ - SELECT notif_count, unread_count FROM event_push_summary + SELECT stream_ordering, notif_count, COALESCE(unread_count, 0) + FROM event_push_summary WHERE room_id = ? AND user_id = ? AND stream_ordering > ? """, (room_id, user_id, stream_ordering), ) row = txn.fetchone() + summary_stream_ordering = 0 if row: - notif_count += row[0] + summary_stream_ordering = row[0] + counts.notify_count += row[1] + counts.unread_count += row[2] - if row[1] is not None: - # The unread_count column of event_push_summary is NULLable, so we need - # to make sure we don't try increasing the unread counts if it's NULL - # for this row. - unread_count += row[1] + # Next we need to count highlights, which aren't summarized + sql = """ + SELECT COUNT(*) FROM event_push_actions + WHERE user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND highlight = 1 + """ + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + if row: + counts.highlight_count += row[0] - return NotifCounts( - notify_count=notif_count, - unread_count=unread_count, - highlight_count=highlight_count, + # Finally we need to count push actions that haven't been summarized + # yet. + # We only want to pull out push actions that we haven't summarized yet. + stream_ordering = max(stream_ordering, summary_stream_ordering) + notify_count, unread_count = self._get_notif_unread_count_for_user_room( + txn, room_id, user_id, stream_ordering ) + counts.notify_count += notify_count + counts.unread_count += unread_count + + return counts + + def _get_notif_unread_count_for_user_room( + self, + txn: LoggingTransaction, + room_id: str, + user_id: str, + stream_ordering: int, + max_stream_ordering: Optional[int] = None, + ) -> Tuple[int, int]: + """Returns the notify and unread counts from `event_push_actions` for + the given user/room in the given range. + + Does not consult `event_push_summary` table, which may include push + actions that have been deleted from `event_push_actions` table. + """ + + clause = "" + args = [user_id, room_id, stream_ordering] + if max_stream_ordering is not None: + clause = "AND ea.stream_ordering <= ?" + args.append(max_stream_ordering) + + sql = f""" + SELECT + COUNT(CASE WHEN notif = 1 THEN 1 END), + COUNT(CASE WHEN unread = 1 THEN 1 END) + FROM event_push_actions ea + WHERE user_id = ? + AND room_id = ? + AND ea.stream_ordering > ? + {clause} + """ + + txn.execute(sql, args) + row = txn.fetchone() + + if row: + return cast(Tuple[int, int], row) + + return 0, 0 + async def get_push_action_users_in_range( self, min_stream_ordering: int, max_stream_ordering: int ) -> List[str]: @@ -754,6 +809,8 @@ class EventPushActionsWorkerStore(SQLBaseStore): if caught_up: break await self.hs.get_clock().sleep(self._rotate_delay) + + await self._remove_old_push_actions_that_have_rotated() finally: self._doing_notif_rotation = False @@ -782,20 +839,16 @@ class EventPushActionsWorkerStore(SQLBaseStore): stream_row = txn.fetchone() if stream_row: (offset_stream_ordering,) = stream_row - assert self.stream_ordering_day_ago is not None - rotate_to_stream_ordering = min( - self.stream_ordering_day_ago, offset_stream_ordering - ) - caught_up = offset_stream_ordering >= self.stream_ordering_day_ago + rotate_to_stream_ordering = offset_stream_ordering + caught_up = False else: - rotate_to_stream_ordering = self.stream_ordering_day_ago + rotate_to_stream_ordering = self._stream_id_gen.get_current_token() caught_up = True logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) - # We have caught up iff we were limited by `stream_ordering_day_ago` return caught_up def _rotate_notifs_before_txn( @@ -819,7 +872,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): max(stream_ordering) as stream_ordering FROM event_push_actions WHERE ? <= stream_ordering AND stream_ordering < ? - AND highlight = 0 AND %s = 1 GROUP BY user_id, room_id ) AS upd @@ -914,19 +966,73 @@ class EventPushActionsWorkerStore(SQLBaseStore): ), ) - txn.execute( - "DELETE FROM event_push_actions" - " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", - (old_rotate_stream_ordering, rotate_to_stream_ordering), - ) - - logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) - txn.execute( "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", (rotate_to_stream_ordering,), ) + async def _remove_old_push_actions_that_have_rotated( + self, + ) -> None: + """Clear out old push actions that have been summarized.""" + + # We want to clear out anything that older than a day that *has* already + # been rotated. + rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol( + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + max_stream_ordering_to_delete = min( + rotated_upto_stream_ordering, self.stream_ordering_day_ago + ) + + def remove_old_push_actions_that_have_rotated_txn( + txn: LoggingTransaction, + ) -> bool: + # We don't want to clear out too much at a time, so we bound our + # deletes. + batch_size = 10000 + + txn.execute( + """ + SELECT stream_ordering FROM event_push_actions + WHERE stream_ordering < ? AND highlight = 0 + ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? + """, + ( + max_stream_ordering_to_delete, + batch_size, + ), + ) + stream_row = txn.fetchone() + + if stream_row: + (stream_ordering,) = stream_row + else: + stream_ordering = max_stream_ordering_to_delete + + txn.execute( + """ + DELETE FROM event_push_actions + WHERE stream_ordering < ? AND highlight = 0 + """, + (stream_ordering,), + ) + + logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) + + return txn.rowcount < batch_size + + while True: + done = await self.db_pool.runInteraction( + "_remove_old_push_actions_that_have_rotated", + remove_old_push_actions_that_have_rotated_txn, + ) + if done: + break + def _remove_old_push_actions_before_txn( self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int ) -> None: @@ -965,12 +1071,26 @@ class EventPushActionsWorkerStore(SQLBaseStore): (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), ) - txn.execute( - """ - DELETE FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? - """, - (room_id, user_id, stream_ordering), + old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + notif_count, unread_count = self._get_notif_unread_count_for_user_room( + txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering + ) + + self.db_pool.simple_upsert_txn( + txn, + table="event_push_summary", + keyvalues={"room_id": room_id, "user_id": user_id}, + values={ + "notif_count": notif_count, + "unread_count": unread_count, + "stream_ordering": old_rotate_stream_ordering, + }, ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index d5aefe02b6..86649c1e6c 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -110,9 +110,9 @@ def _load_rules( # the abstract methods being implemented. class PushRulesWorkerStore( ApplicationServiceWorkerStore, - ReceiptsWorkerStore, PusherWorkerStore, RoomMemberWorkerStore, + ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta, diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index b6106affa6..bec6d60577 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -118,7 +118,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return self._receipts_id_gen.get_current_token() async def get_last_receipt_event_id_for_user( - self, user_id: str, room_id: str, receipt_types: Iterable[str] + self, user_id: str, room_id: str, receipt_types: Collection[str] ) -> Optional[str]: """ Fetch the event ID for the latest receipt in a room with one of the given receipt types. @@ -126,58 +126,63 @@ class ReceiptsWorkerStore(SQLBaseStore): Args: user_id: The user to fetch receipts for. room_id: The room ID to fetch the receipt for. - receipt_type: The receipt types to fetch. Earlier receipt types - are given priority if multiple receipts point to the same event. + receipt_type: The receipt types to fetch. Returns: The latest receipt, if one exists. """ - latest_event_id: Optional[str] = None - latest_stream_ordering = 0 - for receipt_type in receipt_types: - result = await self._get_last_receipt_event_id_for_user( - user_id, room_id, receipt_type - ) - if result is None: - continue - event_id, stream_ordering = result + result = await self.db_pool.runInteraction( + "get_last_receipt_event_id_for_user", + self.get_last_receipt_for_user_txn, + user_id, + room_id, + receipt_types, + ) + if not result: + return None - if latest_event_id is None or latest_stream_ordering < stream_ordering: - latest_event_id = event_id - latest_stream_ordering = stream_ordering + event_id, _ = result + return event_id - return latest_event_id - - @cached() - async def _get_last_receipt_event_id_for_user( - self, user_id: str, room_id: str, receipt_type: str + def get_last_receipt_for_user_txn( + self, + txn: LoggingTransaction, + user_id: str, + room_id: str, + receipt_types: Collection[str], ) -> Optional[Tuple[str, int]]: """ - Fetch the event ID and stream ordering for the latest receipt. + Fetch the event ID and stream_ordering for the latest receipt in a room + with one of the given receipt types. Args: user_id: The user to fetch receipts for. room_id: The room ID to fetch the receipt for. - receipt_type: The receipt type to fetch. + receipt_type: The receipt types to fetch. Returns: - The event ID and stream ordering of the latest receipt, if one exists; - otherwise `None`. + The latest receipt, if one exists. """ - sql = """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "receipt_type", receipt_types + ) + + sql = f""" SELECT event_id, stream_ordering FROM receipts_linearized INNER JOIN events USING (room_id, event_id) - WHERE user_id = ? + WHERE {clause} + AND user_id = ? AND room_id = ? - AND receipt_type = ? + ORDER BY stream_ordering DESC + LIMIT 1 """ - def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]: - txn.execute(sql, (user_id, room_id, receipt_type)) - return cast(Optional[Tuple[str, int]], txn.fetchone()) + args.extend((user_id, room_id)) + txn.execute(sql, args) - return await self.db_pool.runInteraction("get_own_receipt_for_user", f) + return cast(Optional[Tuple[str, int]], txn.fetchone()) async def get_receipts_for_user( self, user_id: str, receipt_types: Iterable[str] @@ -577,8 +582,11 @@ class ReceiptsWorkerStore(SQLBaseStore): ) -> None: self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) - self._get_last_receipt_event_id_for_user.invalidate( - (user_id, room_id, receipt_type) + + # We use this method to invalidate so that we don't end up with circular + # dependencies between the receipts and push action stores. + self._attempt_to_invalidate_cache( + "get_unread_event_push_actions_by_room_for_user", (room_id,) ) def process_replication_rows( diff --git a/synapse/storage/schema/main/delta/40/event_push_summary.sql b/synapse/storage/schema/main/delta/40/event_push_summary.sql index 3918f0b794..499bf60178 100644 --- a/synapse/storage/schema/main/delta/40/event_push_summary.sql +++ b/synapse/storage/schema/main/delta/40/event_push_summary.sql @@ -13,9 +13,10 @@ * limitations under the License. */ --- Aggregate of old notification counts that have been deleted out of the --- main event_push_actions table. This count does not include those that were --- highlights, as they remain in the event_push_actions table. +-- Aggregate of notification counts up to `stream_ordering`, including those +-- that may have been deleted out of the main event_push_actions table. This +-- count does not include those that were highlights, as they remain in the +-- event_push_actions table. CREATE TABLE event_push_summary ( user_id TEXT NOT NULL, room_id TEXT NOT NULL, diff --git a/synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql b/synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql new file mode 100644 index 0000000000..9cdcea21ae --- /dev/null +++ b/synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql @@ -0,0 +1,18 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a unique index to `event_push_summary` +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7002, 'event_push_summary_unique_index', '{}'); diff --git a/tests/push/test_http.py b/tests/push/test_http.py index ba158f5d93..d9c68cdd2d 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -577,7 +577,7 @@ class HTTPPusherTests(HomeserverTestCase): # Carry out our option-value specific test # # This push should still only contain an unread count of 1 (for 1 unread room) - self._check_push_attempt(6, 1) + self._check_push_attempt(7, 1) @override_config({"push": {"group_unread_count_by_room": False}}) def test_push_unread_count_message_count(self) -> None: @@ -591,7 +591,7 @@ class HTTPPusherTests(HomeserverTestCase): # # We're counting every unread message, so there should now be 3 since the # last read receipt - self._check_push_attempt(6, 3) + self._check_push_attempt(7, 3) def _test_push_unread_count(self) -> None: """ @@ -641,18 +641,18 @@ class HTTPPusherTests(HomeserverTestCase): response = self.helper.send( room_id, body="Hello there!", tok=other_access_token ) - # To get an unread count, the user who is getting notified has to have a read - # position in the room. We'll set the read position to this event in a moment + first_message_event_id = response["event_id"] expected_push_attempts = 1 - self._check_push_attempt(expected_push_attempts, 0) + self._check_push_attempt(expected_push_attempts, 1) self._send_read_request(access_token, first_message_event_id, room_id) - # Unread count has not changed. Therefore, ensure that read request does not - # trigger a push notification. - self.assertEqual(len(self.push_attempts), 1) + # Unread count has changed. Therefore, ensure that read request triggers + # a push notification. + expected_push_attempts += 1 + self.assertEqual(len(self.push_attempts), expected_push_attempts) # Send another message response2 = self.helper.send( diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 6d3d4afe52..531a0db2d0 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -15,7 +15,9 @@ import logging from typing import Iterable, Optional from canonicaljson import encode_canonical_json +from parameterized import parameterized +from synapse.api.constants import ReceiptTypes from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource @@ -156,17 +158,26 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ], ) - def test_push_actions_for_user(self): + @parameterized.expand([(True,), (False,)]) + def test_push_actions_for_user(self, send_receipt: bool): self.persist(type="m.room.create", key="", creator=USER_ID) - self.persist(type="m.room.join", key=USER_ID, membership="join") + self.persist(type="m.room.member", key=USER_ID, membership="join") self.persist( - type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join" + type="m.room.member", sender=USER_ID, key=USER_ID_2, membership="join" ) event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello") self.replicate() + + if send_receipt: + self.get_success( + self.master_store.insert_receipt( + ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {} + ) + ) + self.check( "get_unread_event_push_actions_by_room_for_user", - [ROOM_ID, USER_ID_2, event1.event_id], + [ROOM_ID, USER_ID_2], NotifCounts(highlight_count=0, unread_count=0, notify_count=0), ) @@ -179,7 +190,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.replicate() self.check( "get_unread_event_push_actions_by_room_for_user", - [ROOM_ID, USER_ID_2, event1.event_id], + [ROOM_ID, USER_ID_2], NotifCounts(highlight_count=0, unread_count=0, notify_count=1), ) @@ -194,7 +205,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.replicate() self.check( "get_unread_event_push_actions_by_room_for_user", - [ROOM_ID, USER_ID_2, event1.event_id], + [ROOM_ID, USER_ID_2], NotifCounts(highlight_count=1, unread_count=0, notify_count=2), ) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 0f9add4841..4273524c4c 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -51,10 +51,16 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): room_id = "!foo:example.com" user_id = "@user1235:example.com" + last_read_stream_ordering = [0] + def _assert_counts(noitf_count, highlight_count): counts = self.get_success( self.store.db_pool.runInteraction( - "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 + "", + self.store._get_unread_counts_by_pos_txn, + room_id, + user_id, + last_read_stream_ordering[0], ) ) self.assertEqual( @@ -98,6 +104,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) def _mark_read(stream, depth): + last_read_stream_ordering[0] = stream self.get_success( self.store.db_pool.runInteraction( "", @@ -144,8 +151,19 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _assert_counts(1, 1) _rotate(9) _assert_counts(1, 1) - _rotate(10) - _assert_counts(1, 1) + + # Check that adding another notification and rotating after highlight + # works. + _inject_actions(10, PlAIN_NOTIF) + _rotate(11) + _assert_counts(2, 1) + + # Check that sending read receipts at different points results in the + # right counts. + _mark_read(8, 8) + _assert_counts(1, 0) + _mark_read(10, 10) + _assert_counts(0, 0) def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts):