Return thread notification counts down sync.

This commit is contained in:
Patrick Cloke 2022-06-09 13:18:25 -04:00
parent cbbe77f620
commit f03935dcb7
3 changed files with 73 additions and 41 deletions

View file

@@ -1052,7 +1052,7 @@ class SyncHandler:
    async def unread_notifs_for_room_id(
        self, room_id: str, sync_config: SyncConfig
    ) -> Dict[Optional[str], NotifCounts]:
        with Measure(self.clock, "unread_notifs_for_room_id"):
            last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                user_id=sync_config.user.to_string(),
@@ -2122,7 +2122,7 @@ class SyncHandler:
            )

            if room_builder.rtype == "joined":
                unread_notifications: JsonDict = {}
                room_sync = JoinedSyncResult(
                    room_id=room_id,
                    timeline=batch,
@@ -2137,10 +2137,18 @@ class SyncHandler:
                if room_sync or always_include:
                    notifs = await self.unread_notifs_for_room_id(room_id, sync_config)

                    # Notifications for the main timeline.
                    main_notifs = notifs[None]
                    unread_notifications.update(main_notifs.to_dict())

                    room_sync.unread_count = main_notifs.unread_count

                    # And add info for each thread.
                    unread_notifications["unread_thread_notifications"] = {
                        thread_id: thread_notifs.to_dict()
                        for thread_id, thread_notifs in notifs.items()
                        if thread_id is not None
                    }

                    sync_result_builder.joined.append(room_sync)

View file

@@ -39,7 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
                room_id, user_id, last_unread_event_id
            )
        )
        # Combine the counts from all the threads.
        notify_count = sum(n.notify_count for n in notifs.values())

        if notify_count == 0:
            continue

        if group_by_room:
@@ -47,7 +50,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
            badge += 1
        else:
            # increment the badge count by the number of unread messages in the room
            badge += notify_count
    return badge

View file

@@ -24,6 +24,7 @@ from synapse.storage.database import (
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -79,7 +80,7 @@ class UserPushAction(EmailPushAction):
    profile_tag: str


@attr.s(slots=True, auto_attribs=True)
class NotifCounts:
    """
    The per-user, per-room count of notifications. Used by sync and push.
@@ -89,6 +90,12 @@ class NotifCounts:
    unread_count: int
    highlight_count: int

    def to_dict(self) -> JsonDict:
        return {
            "notification_count": self.notify_count,
            "highlight_count": self.highlight_count,
        }
def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
    """Custom serializer for actions. This allows us to "compress" common actions.
@@ -148,13 +155,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
            self._rotate_notifs, 30 * 60 * 1000
        )

    @cached(max_entries=5000, tree=True, iterable=True)
    async def get_unread_event_push_actions_by_room_for_user(
        self,
        room_id: str,
        user_id: str,
        last_read_event_id: Optional[str],
    ) -> Dict[Optional[str], NotifCounts]:
        """Get the notification count, the highlight count and the unread message count
        for a given user in a given room after the given read receipt.
@@ -187,7 +194,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
        room_id: str,
        user_id: str,
        last_read_event_id: Optional[str],
    ) -> Dict[Optional[str], NotifCounts]:
        stream_ordering = None
        if last_read_event_id is not None:
@@ -217,49 +224,63 @@ class EventPushActionsWorkerStore(SQLBaseStore):
    def _get_unread_counts_by_pos_txn(
        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
    ) -> Dict[Optional[str], NotifCounts]:
        sql = """
            SELECT
               COUNT(CASE WHEN notif = 1 THEN 1 END),
               COUNT(CASE WHEN highlight = 1 THEN 1 END),
               COUNT(CASE WHEN unread = 1 THEN 1 END),
               thread_id
            FROM event_push_actions ea
            WHERE user_id = ?
               AND room_id = ?
               AND stream_ordering > ?
            GROUP BY thread_id
        """
        txn.execute(sql, (user_id, room_id, stream_ordering))
        rows = txn.fetchall()

        notif_counts: Dict[Optional[str], NotifCounts] = {
            # Ensure the main timeline has notification counts.
            None: NotifCounts(
                notify_count=0,
                unread_count=0,
                highlight_count=0,
            )
        }
        for notif_count, highlight_count, unread_count, thread_id in rows:
            notif_counts[thread_id] = NotifCounts(
                notify_count=notif_count,
                unread_count=unread_count,
                highlight_count=highlight_count,
            )

        txn.execute(
            """
                SELECT notif_count, unread_count, thread_id FROM event_push_summary
                WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
            """,
            (room_id, user_id, stream_ordering),
        )
        rows = txn.fetchall()
        for notif_count, unread_count, thread_id in rows:
            if unread_count is None:
                # The unread_count column of event_push_summary is NULLable.
                unread_count = 0

            if thread_id in notif_counts:
                notif_counts[thread_id].notify_count += notif_count
                notif_counts[thread_id].unread_count += unread_count
            else:
                notif_counts[thread_id] = NotifCounts(
                    notify_count=notif_count,
                    unread_count=unread_count,
                    highlight_count=0,
                )

        return notif_counts
async def get_push_action_users_in_range( async def get_push_action_users_in_range(
self, min_stream_ordering: int, max_stream_ordering: int self, min_stream_ordering: int, max_stream_ordering: int