Include the thread ID in the event push actions.

This commit is contained in:
Patrick Cloke 2022-06-08 13:46:49 -04:00
parent 3ca9a381ab
commit cbbe77f620
7 changed files with 71 additions and 38 deletions

View file

@ -195,7 +195,7 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
async def _get_mutual_relations(
self, event: EventBase, rules: Iterable[Dict[str, Any]]
self, parent_id: str, rules: Iterable[Dict[str, Any]]
) -> Dict[str, Set[Tuple[str, str]]]:
"""
Fetch event metadata for events which related to the same event as the given event.
@ -203,7 +203,7 @@ class BulkPushRuleEvaluator:
If the given event has no relation information, returns an empty dictionary.
Args:
event_id: The event ID which is targeted by relations.
parent_id: The event ID which is targeted by relations.
rules: The push rules which will be processed for this event.
Returns:
@ -217,12 +217,6 @@ class BulkPushRuleEvaluator:
if not self._relations_match_enabled:
return {}
# If the event does not have a relation, then cannot have any mutual
# relations.
relation = relation_from_event(event)
if not relation:
return {}
# Pre-filter to figure out which relation types are interesting.
rel_types = set()
for rule in rules:
@ -244,9 +238,7 @@ class BulkPushRuleEvaluator:
return {}
# If any valid rules were found, fetch the mutual relations.
return await self.store.get_mutual_event_relations(
relation.parent_id, rel_types
)
return await self.store.get_mutual_event_relations(parent_id, rel_types)
@measure_func("action_for_event_by_user")
async def action_for_event_by_user(
@ -272,9 +264,18 @@ class BulkPushRuleEvaluator:
sender_power_level,
) = await self._get_power_levels_and_sender_level(event, context)
relations = await self._get_mutual_relations(
event, itertools.chain(*rules_by_user.values())
)
relation = relation_from_event(event)
# If the event does not have a relation, then cannot have any mutual
# relations or thread ID.
relations = {}
thread_id = None
if relation:
relations = await self._get_mutual_relations(
relation.parent_id, itertools.chain(*rules_by_user.values())
)
# XXX Does this need to point to a valid parent ID or anything?
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
evaluator = PushRuleEvaluatorForEvent(
event,
@ -339,6 +340,7 @@ class BulkPushRuleEvaluator:
event.event_id,
actions_by_user,
count_as_unread,
thread_id,
)

View file

@ -528,6 +528,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
event_id: str,
user_id_actions: Dict[str, List[Union[dict, str]]],
count_as_unread: bool,
thread_id: Optional[str],
) -> None:
"""Add the push actions for the event to the push action staging area.
@ -536,6 +537,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict.
count_as_unread: Whether this event should increment unread counts.
thread_id: The thread this event is parent of, if applicable.
"""
if not user_id_actions:
return
@ -544,7 +546,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(
user_id: str, actions: List[Union[dict, str]]
) -> Tuple[str, str, str, int, int, int]:
) -> Tuple[str, str, str, int, int, int, Optional[str]]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
return (
@ -554,6 +556,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
notif, # notif column
is_highlight, # highlight column
int(count_as_unread), # unread column
thread_id, # thread_id column
)
def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
@ -562,8 +565,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql = """
INSERT INTO event_push_actions_staging
(event_id, user_id, actions, notif, highlight, unread)
VALUES (?, ?, ?, ?, ?, ?)
(event_id, user_id, actions, notif, highlight, unread, thread_id)
VALUES (?, ?, ?, ?, ?, ?, ?)
"""
txn.execute_batch(
@ -810,20 +813,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
SELECT user_id, room_id, thread_id,
coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering,
old.user_id
FROM (
SELECT user_id, room_id, count(*) as cnt,
SELECT user_id, room_id, thread_id, count(*) as cnt,
max(stream_ordering) as stream_ordering
FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0
AND %s = 1
GROUP BY user_id, room_id
GROUP BY user_id, room_id, thread_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
"""
# First get the count of unread messages.
@ -837,12 +840,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
summaries: Dict[Tuple[str, str, Optional[str]], _EventPushSummary] = {}
for row in txn:
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=row[2],
stream_ordering=row[3],
old_user_id=row[4],
summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=row[3],
stream_ordering=row[4],
old_user_id=row[5],
notif_count=0,
)
@ -853,18 +856,18 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
for row in txn:
if (row[0], row[1]) in summaries:
summaries[(row[0], row[1])].notif_count = row[2]
if (row[0], row[1], row[2]) in summaries:
summaries[(row[0], row[1], row[2])].notif_count = row[3]
else:
# Because the rules on notifying are different than the rules on marking
# a message unread, we might end up with messages that notify but aren't
# marked unread, so we might not have a summary for this (user, room)
# tuple to complete.
summaries[(row[0], row[1])] = _EventPushSummary(
summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=0,
stream_ordering=row[3],
old_user_id=row[4],
notif_count=row[2],
stream_ordering=row[4],
old_user_id=row[5],
notif_count=row[3],
)
logger.info("Rotating notifications, handling %d rows", len(summaries))
@ -881,6 +884,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"notif_count",
"unread_count",
"stream_ordering",
"thread_id",
),
values=[
(
@ -889,8 +893,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
summary.notif_count,
summary.unread_count,
summary.stream_ordering,
thread_id,
)
for ((user_id, room_id), summary) in summaries.items()
for ((user_id, room_id, thread_id), summary) in summaries.items()
if summary.old_user_id is None
],
)
@ -899,7 +904,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
WHERE user_id = ? AND room_id = ? AND thread_id = ?
""",
(
(
@ -908,8 +913,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
summary.stream_ordering,
user_id,
room_id,
thread_id,
)
for ((user_id, room_id), summary) in summaries.items()
for ((user_id, room_id, thread_id), summary) in summaries.items()
if summary.old_user_id is not None
),
)

View file

@ -2302,9 +2302,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
topological_ordering, notif, highlight, unread
topological_ordering, notif, highlight, unread, thread_id
)
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
FROM event_push_actions_staging
WHERE event_id = ?
"""

View file

@ -731,7 +731,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
user_id,
receipt_type,
start_topo_ordering - 1 if start_topo_ordering is not None else None,
end_topo_ordering + 1,
end_topo_ordering + 1 if end_topo_ordering is not None else None,
),
)
overlapping_receipts = txn.fetchall()

View file

@ -0,0 +1,23 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE event_push_actions_staging
ADD COLUMN thread_id TEXT DEFAULT NULL;
ALTER TABLE event_push_actions
ADD COLUMN thread_id TEXT DEFAULT NULL;
ALTER TABLE event_push_summary
ADD COLUMN thread_id TEXT DEFAULT NULL;

View file

@ -393,6 +393,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event.event_id,
{user_id: actions for user_id, actions in push_actions},
False,
None,
)
)
return event, context

View file

@ -79,6 +79,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
event.event_id,
{user_id: action},
False,
None,
)
)
self.get_success(