diff --git a/changelog.d/12740.feature b/changelog.d/12740.feature new file mode 100644 index 0000000000..e674c31ae8 --- /dev/null +++ b/changelog.d/12740.feature @@ -0,0 +1 @@ +Experimental support for [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772): Push rule for mutually related events. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b20d949689..cc417e2fbf 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -84,3 +84,6 @@ class ExperimentalConfig(Config): # MSC3786 (Add a default push rule to ignore m.room.server_acl events) self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) + + # MSC3772: A push rule for mutual relations. + self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index a17b35a605..4c7278b5a1 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -139,6 +139,7 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [ { "kind": "event_match", "key": "content.body", + # Match the localpart of the requester's MXID. "pattern_type": "user_localpart", } ], @@ -191,6 +192,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ "pattern": "invite", "_cache_key": "_invite_member", }, + # Match the requester's MXID. {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, ], "actions": [ @@ -350,6 +352,18 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [ {"set_tweak": "highlight", "value": False}, ], }, + { + "rule_id": "global/underride/.org.matrix.msc3772.thread_reply", + "conditions": [ + { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.thread", + # Match the requester's MXID. + "sender_type": "user_id", + } + ], + "actions": ["notify", {"set_tweak": "highlight", "value": False}], + }, { "rule_id": "global/underride/.m.rule.message", "conditions": [ diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 4cc8a2ecca..1a8e7ef3dc 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,8 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union import attr from prometheus_client import Counter @@ -121,6 +122,9 @@ class BulkPushRuleEvaluator: resizable=False, ) + # Whether to support MSC3772 is supported. + self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled + async def _get_rules_for_event( self, event: EventBase, context: EventContext ) -> Dict[str, List[Dict[str, Any]]]: @@ -192,6 +196,60 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level + async def _get_mutual_relations( + self, event: EventBase, rules: Iterable[Dict[str, Any]] + ) -> Dict[str, Set[Tuple[str, str]]]: + """ + Fetch event metadata for events which related to the same event as the given event. + + If the given event has no relation information, returns an empty dictionary. + + Args: + event_id: The event ID which is targeted by relations. + rules: The push rules which will be processed for this event. + + Returns: + A dictionary of relation type to: + A set of tuples of: + The sender + The event type + """ + + # If the experimental feature is not enabled, skip fetching relations. + if not self._relations_match_enabled: + return {} + + # If the event does not have a relation, then cannot have any mutual + # relations. + relation = relation_from_event(event) + if not relation: + return {} + + # Pre-filter to figure out which relation types are interesting. + rel_types = set() + for rule in rules: + # Skip disabled rules. + if "enabled" in rule and not rule["enabled"]: + continue + + for condition in rule["conditions"]: + if condition["kind"] != "org.matrix.msc3772.relation_match": + continue + + # rel_type is required. + rel_type = condition.get("rel_type") + if rel_type: + rel_types.add(rel_type) + + # If no valid rules were found, no mutual relations. + if not rel_types: + return {} + + # If any valid rules were found, fetch the mutual relations. + return await self.store.get_mutual_event_relations( + relation.parent_id, rel_types + ) + @measure_func("action_for_event_by_user") async def action_for_event_by_user( self, event: EventBase, context: EventContext @@ -216,8 +274,17 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) + relations = await self._get_mutual_relations( + event, itertools.chain(*rules_by_user.values()) + ) + evaluator = PushRuleEvaluatorForEvent( - event, len(room_members), sender_power_level, power_levels + event, + len(room_members), + sender_power_level, + power_levels, + relations, + self._relations_match_enabled, ) # If the event is not a state event check if any users ignore the sender. diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 63b22d50ae..5117ef6854 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -48,6 +48,10 @@ def format_push_rules_for_user( elif pattern_type == "user_localpart": c["pattern"] = user.localpart + sender_type = c.pop("sender_type", None) + if sender_type == "user_id": + c["sender"] = user.to_string() + rulearray = rules["global"][template_name] template_rule = _rule_to_template(r) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 54db6b5612..2e8a017add 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -15,7 +15,7 @@ import logging import re -from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union from matrix_common.regex import glob_to_regex, to_word_pattern @@ -120,11 +120,15 @@ class PushRuleEvaluatorForEvent: room_member_count: int, sender_power_level: int, power_levels: Dict[str, Union[int, Dict[str, int]]], + relations: Dict[str, Set[Tuple[str, str]]], + relations_match_enabled: bool, ): self._event = event self._room_member_count = room_member_count self._sender_power_level = sender_power_level self._power_levels = power_levels + self._relations = relations + self._relations_match_enabled = relations_match_enabled # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) @@ -188,7 +192,16 @@ class PushRuleEvaluatorForEvent: return _sender_notification_permission( self._event, condition, self._sender_power_level, self._power_levels ) + elif ( + condition["kind"] == "org.matrix.msc3772.relation_match" + and self._relations_match_enabled + ): + return self._relation_match(condition, user_id) else: + # XXX This looks incorrect -- we have reached an unknown condition + # kind and are unconditionally returning that it matches. Note + # that it seems possible to provide a condition to the /pushrules + # endpoint with an unknown kind, see _rule_tuple_from_request_object. return True def _event_match(self, condition: dict, user_id: str) -> bool: @@ -256,6 +269,41 @@ class PushRuleEvaluatorForEvent: return bool(r.search(body)) + def _relation_match(self, condition: dict, user_id: str) -> bool: + """ + Check an "relation_match" push rule condition. + + Args: + condition: The "event_match" push rule condition to match. + user_id: The user's MXID. + + Returns: + True if the condition matches the event, False otherwise. + """ + rel_type = condition.get("rel_type") + if not rel_type: + logger.warning("relation_match condition missing rel_type") + return False + + sender_pattern = condition.get("sender") + if sender_pattern is None: + sender_type = condition.get("sender_type") + if sender_type == "user_id": + sender_pattern = user_id + type_pattern = condition.get("type") + + # If any other relations matches, return True. + for sender, event_type in self._relations.get(rel_type, ()): + if sender_pattern and not _glob_matches(sender_pattern, sender): + continue + if type_pattern and not _glob_matches(type_pattern, event_type): + continue + # All values must have matched. + return True + + # No relations matched. + return False + # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 0df8ff5395..17e35cf63e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1828,6 +1828,10 @@ class PersistEventsStore: self.store.get_aggregation_groups_for_event.invalidate, (relation.parent_id,), ) + txn.call_after( + self.store.get_mutual_event_relations_for_rel_type.invalidate, + (relation.parent_id,), + ) if relation.rel_type == RelationTypes.REPLACE: txn.call_after( @@ -2004,6 +2008,11 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) + self.store._invalidate_cache_and_stream( + txn, + self.store.get_mutual_event_relations_for_rel_type, + (redacted_relates_to,), + ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index ad67901cc1..4adabc88cc 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -61,6 +61,11 @@ def _is_experimental_rule_enabled( and not experimental_config.msc3786_enabled ): return False + if ( + rule_id == "global/underride/.org.matrix.msc3772.thread_reply" + and not experimental_config.msc3772_enabled + ): + return False return True diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index fe8fded88b..3b1b2ce6cb 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from collections import defaultdict from typing import ( Collection, Dict, @@ -767,6 +768,57 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + @cached(iterable=True) + async def get_mutual_event_relations_for_rel_type( + self, event_id: str, relation_type: str + ) -> Set[Tuple[str, str]]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_mutual_event_relations_for_rel_type", + list_name="relation_types", + ) + async def get_mutual_event_relations( + self, event_id: str, relation_types: Collection[str] + ) -> Dict[str, Set[Tuple[str, str]]]: + """ + Fetch event metadata for events which related to the same event as the given event. + + If the given event has no relation information, returns an empty dictionary. + + Args: + event_id: The event ID which is targeted by relations. + relation_types: The relation types to check for mutual relations. + + Returns: + A dictionary of relation type to: + A set of tuples of: + The sender + The event type + """ + rel_type_sql, rel_type_args = make_in_list_sql_clause( + self.database_engine, "relation_type", relation_types + ) + + sql = f""" + SELECT DISTINCT relation_type, sender, type FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND {rel_type_sql} + """ + + def _get_event_relations( + txn: LoggingTransaction, + ) -> Dict[str, Set[Tuple[str, str]]]: + txn.execute(sql, [event_id] + rel_type_args) + result = defaultdict(set) + for rel_type, sender, type in txn.fetchall(): + result[rel_type].add((sender, type)) + return result + + return await self.db_pool.runInteraction( + "get_event_relations", _get_event_relations + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 5dba187076..9b623d0033 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Union +from typing import Dict, Optional, Set, Tuple, Union import frozendict @@ -26,7 +26,12 @@ from tests import unittest class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent: + def _get_evaluator( + self, + content: JsonDict, + relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, + relations_match_enabled: bool = False, + ) -> PushRuleEvaluatorForEvent: event = FrozenEvent( { "event_id": "$event_id", @@ -42,7 +47,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): sender_power_level = 0 power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluatorForEvent( - event, room_member_count, sender_power_level, power_levels + event, + room_member_count, + sender_power_level, + power_levels, + relations or set(), + relations_match_enabled, ) def test_display_name(self) -> None: @@ -276,3 +286,71 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): push_rule_evaluator.tweaks_for_actions(actions), {"sound": "default", "highlight": True}, ) + + def test_relation_match(self) -> None: + """Test the relation_match push rule kind.""" + + # Check if the experimental feature is disabled. + evaluator = self._get_evaluator( + {}, {"m.annotation": {("@user:test", "m.reaction")}} + ) + condition = {"kind": "relation_match"} + # Oddly, an unknown condition always matches. + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + + # A push rule evaluator with the experimental rule enabled. + evaluator = self._get_evaluator( + {}, {"m.annotation": {("@user:test", "m.reaction")}}, True + ) + + # Check just relation type. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + } + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + + # Check relation type and sender. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + "sender": "@user:test", + } + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + "sender": "@other:test", + } + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + # Check relation type and event type. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + "type": "m.reaction", + } + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + + # Check just sender, this fails since rel_type is required. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "sender": "@user:test", + } + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + # Check sender glob. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + "sender": "@*:test", + } + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + + # Check event type glob. + condition = { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.annotation", + "event_type": "*.reaction", + } + self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))