From cbbe77f62073379da7ca77d5b743f64c1bbc3e82 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Wed, 8 Jun 2022 13:46:49 -0400
Subject: [PATCH] Include the thread ID in the event push actions.

---
 synapse/push/bulk_push_rule_evaluator.py      | 30 ++++++------
 .../databases/main/event_push_actions.py      | 48 +++++++++++--------
 synapse/storage/databases/main/events.py      |  4 +-
 synapse/storage/databases/main/receipts.py    |  2 +-
 .../main/delta/70/03thread_notifications.sql  | 23 +++++++++
 .../replication/slave/storage/test_events.py  |  1 +
 tests/storage/test_event_push_actions.py      |  1 +
 7 files changed, 71 insertions(+), 38 deletions(-)
 create mode 100644 synapse/storage/schema/main/delta/70/03thread_notifications.sql

diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 7791b289e2..d1c929e202 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -195,7 +195,7 @@ class BulkPushRuleEvaluator:
         return pl_event.content if pl_event else {}, sender_level
 
     async def _get_mutual_relations(
-        self, event: EventBase, rules: Iterable[Dict[str, Any]]
+        self, parent_id: str, rules: Iterable[Dict[str, Any]]
     ) -> Dict[str, Set[Tuple[str, str]]]:
         """
         Fetch event metadata for events which related to the same event as the given event.
@@ -203,7 +203,7 @@ class BulkPushRuleEvaluator:
         If the given event has no relation information, returns an empty dictionary.
 
         Args:
-            event_id: The event ID which is targeted by relations.
+            parent_id: The event ID which is targeted by relations.
             rules: The push rules which will be processed for this event.
 
         Returns:
@@ -217,12 +217,6 @@ class BulkPushRuleEvaluator:
         if not self._relations_match_enabled:
             return {}
 
-        # If the event does not have a relation, then cannot have any mutual
-        # relations.
-        relation = relation_from_event(event)
-        if not relation:
-            return {}
-
         # Pre-filter to figure out which relation types are interesting.
         rel_types = set()
         for rule in rules:
@@ -244,9 +238,7 @@ class BulkPushRuleEvaluator:
             return {}
 
         # If any valid rules were found, fetch the mutual relations.
-        return await self.store.get_mutual_event_relations(
-            relation.parent_id, rel_types
-        )
+        return await self.store.get_mutual_event_relations(parent_id, rel_types)
 
     @measure_func("action_for_event_by_user")
     async def action_for_event_by_user(
@@ -272,9 +264,18 @@ class BulkPushRuleEvaluator:
             sender_power_level,
         ) = await self._get_power_levels_and_sender_level(event, context)
 
-        relations = await self._get_mutual_relations(
-            event, itertools.chain(*rules_by_user.values())
-        )
+        relation = relation_from_event(event)
+        # If the event does not have a relation, then it cannot have any mutual
+        # relations or thread ID.
+        relations = {}
+        thread_id = None
+        if relation:
+            relations = await self._get_mutual_relations(
+                relation.parent_id, itertools.chain(*rules_by_user.values())
+            )
+            # XXX Does this need to point to a valid parent ID or anything?
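+            # For a thread relation, parent_id is the thread root's event ID,
+            # which is what gets stored as the thread ID below.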
+            if relation.rel_type == RelationTypes.THREAD:
+                thread_id = relation.parent_id
 
         evaluator = PushRuleEvaluatorForEvent(
             event,
@@ -339,6 +340,7 @@ class BulkPushRuleEvaluator:
             event.event_id,
             actions_by_user,
             count_as_unread,
+            thread_id,
         )
 
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b0b3695012..812ed1a3d4 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -528,6 +528,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         event_id: str,
         user_id_actions: Dict[str, List[Union[dict, str]]],
         count_as_unread: bool,
+        thread_id: Optional[str],
     ) -> None:
         """Add the push actions for the event to the push action staging area.
 
@@ -536,6 +537,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             user_id_actions: A mapping of user_id to list of push actions, where
                 an action can either be a string or dict.
             count_as_unread: Whether this event should increment unread counts.
+            thread_id: The thread this event is part of, if applicable.
         """
         if not user_id_actions:
             return
@@ -544,7 +546,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(
             user_id: str, actions: List[Union[dict, str]]
-        ) -> Tuple[str, str, str, int, int, int]:
+        ) -> Tuple[str, str, str, int, int, int, Optional[str]]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
             return (
@@ -554,6 +556,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 notif,  # notif column
                 is_highlight,  # highlight column
                 int(count_as_unread),  # unread column
+                thread_id,  # thread_id column
             )
 
         def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
@@ -562,8 +565,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
             sql = """
                 INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight, unread)
-                VALUES (?, ?, ?, ?, ?, ?)
+                    (event_id, user_id, actions, notif, highlight, unread, thread_id)
+                VALUES (?, ?, ?, ?, ?, ?, ?)
             """
 
             txn.execute_batch(
@@ -810,20 +813,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
-            SELECT user_id, room_id,
+            SELECT user_id, room_id, thread_id,
                 coalesce(old.%s, 0) + upd.cnt,
                 upd.stream_ordering,
                 old.user_id
             FROM (
-                SELECT user_id, room_id, count(*) as cnt,
+                SELECT user_id, room_id, thread_id, count(*) as cnt,
                     max(stream_ordering) as stream_ordering
                 FROM event_push_actions
                 WHERE ? <= stream_ordering AND stream_ordering < ?
                     AND highlight = 0
                     AND %s = 1
-                GROUP BY user_id, room_id
+                GROUP BY user_id, room_id, thread_id
             ) AS upd
-            LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+            LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
         """
 
         # First get the count of unread messages.
@@ -837,12 +840,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # object because we might not have the same amount of rows in each of them. To do
        # this, we use a dict indexed on the user ID and room ID to make it easier to
         # populate.
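+        # (This change additionally keys the dict on the thread ID, which is
+        # None for events that are not part of a thread.)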
-        summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
+        summaries: Dict[Tuple[str, str, Optional[str]], _EventPushSummary] = {}
         for row in txn:
-            summaries[(row[0], row[1])] = _EventPushSummary(
-                unread_count=row[2],
-                stream_ordering=row[3],
-                old_user_id=row[4],
+            summaries[(row[0], row[1], row[2])] = _EventPushSummary(
+                unread_count=row[3],
+                stream_ordering=row[4],
+                old_user_id=row[5],
                 notif_count=0,
             )
 
@@ -853,18 +856,18 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
         for row in txn:
-            if (row[0], row[1]) in summaries:
-                summaries[(row[0], row[1])].notif_count = row[2]
+            if (row[0], row[1], row[2]) in summaries:
+                summaries[(row[0], row[1], row[2])].notif_count = row[3]
             else:
                 # Because the rules on notifying are different than the rules on marking
                 # a message unread, we might end up with messages that notify but aren't
                 # marked unread, so we might not have a summary for this (user, room)
                 # tuple to complete.
-                summaries[(row[0], row[1])] = _EventPushSummary(
+                summaries[(row[0], row[1], row[2])] = _EventPushSummary(
                     unread_count=0,
-                    stream_ordering=row[3],
-                    old_user_id=row[4],
-                    notif_count=row[2],
+                    stream_ordering=row[4],
+                    old_user_id=row[5],
+                    notif_count=row[3],
                 )
 
         logger.info("Rotating notifications, handling %d rows", len(summaries))
@@ -881,6 +884,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "notif_count",
                 "unread_count",
                 "stream_ordering",
+                "thread_id",
             ),
             values=[
                 (
@@ -889,8 +893,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                     summary.notif_count,
                     summary.unread_count,
                     summary.stream_ordering,
+                    thread_id,
                 )
-                for ((user_id, room_id), summary) in summaries.items()
+                for ((user_id, room_id, thread_id), summary) in summaries.items()
                 if summary.old_user_id is None
             ],
         )
@@ -899,7 +904,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             """
                 UPDATE event_push_summary
                 SET notif_count = ?, unread_count = ?, stream_ordering = ?
-                WHERE user_id = ? AND room_id = ?
+                WHERE user_id = ? AND room_id = ? AND thread_id = ?
             """,
             (
                 (
@@ -908,8 +913,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                     summary.notif_count,
                     summary.unread_count,
                     summary.stream_ordering,
                     user_id,
                     room_id,
+                    thread_id,
                 )
-                for ((user_id, room_id), summary) in summaries.items()
+                for ((user_id, room_id, thread_id), summary) in summaries.items()
                 if summary.old_user_id is not None
             ),
         )
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5b86ac55e9..6a1564349f 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2302,9 +2302,9 @@ class PersistEventsStore:
             sql = """
                 INSERT INTO event_push_actions (
                     room_id, event_id, user_id, actions, stream_ordering,
-                    topological_ordering, notif, highlight, unread
+                    topological_ordering, notif, highlight, unread, thread_id
                 )
-                SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
+                SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
                 FROM event_push_actions_staging
                 WHERE event_id = ?
             """
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 4622e8910e..cece802c6e 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -731,7 +731,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 user_id,
                 receipt_type,
                 start_topo_ordering - 1 if start_topo_ordering is not None else None,
-                end_topo_ordering + 1,
+                end_topo_ordering + 1 if end_topo_ordering is not None else None,
             ),
         )
         overlapping_receipts = txn.fetchall()
diff --git a/synapse/storage/schema/main/delta/70/03thread_notifications.sql b/synapse/storage/schema/main/delta/70/03thread_notifications.sql
new file mode 100644
index 0000000000..6fd444ccc1
--- /dev/null
+++ b/synapse/storage/schema/main/delta/70/03thread_notifications.sql
@@ -0,0 +1,23 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE event_push_actions_staging
+    ADD COLUMN thread_id TEXT DEFAULT NULL;
+
+ALTER TABLE event_push_actions
+    ADD COLUMN thread_id TEXT DEFAULT NULL;
+
+ALTER TABLE event_push_summary
+    ADD COLUMN thread_id TEXT DEFAULT NULL;
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 6d3d4afe52..3ad2b2ad79 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -393,6 +393,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                 event.event_id,
                 {user_id: actions for user_id, actions in push_actions},
                 False,
+                None,
             )
         )
         return event, context
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 0f9add4841..1b52da201c 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -79,6 +79,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
                     event.event_id,
                     {user_id: action},
                     False,
+                    None,
                )
             )
             self.get_success(