diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 9a871f5693..a3dcd72ddf 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -49,6 +49,7 @@ fn bench_match_exact(b: &mut Bencher) { Some(0), Default::default(), Default::default(), + Default::default(), true, vec![], false, @@ -97,6 +98,7 @@ fn bench_match_word(b: &mut Bencher) { Some(0), Default::default(), Default::default(), + Default::default(), true, vec![], false, @@ -145,6 +147,7 @@ fn bench_match_word_miss(b: &mut Bencher) { Some(0), Default::default(), Default::default(), + Default::default(), true, vec![], false, @@ -193,6 +196,7 @@ fn bench_eval_message(b: &mut Bencher) { Some(0), Default::default(), Default::default(), + Default::default(), true, vec![], false, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index a65c645caf..86b4f2ea3e 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -85,6 +85,9 @@ pub struct PushRuleEvaluator { /// outlier. sender_power_level: Option, + // User's tags for this event's room. + tags_by_user: BTreeMap>, + /// The related events, indexed by relation type. Flattened in the same manner as /// `flattened_keys`. related_events_flattened: BTreeMap>, @@ -118,6 +121,7 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, + tags_by_user: BTreeMap>, related_events_flattened: BTreeMap>, related_event_match_enabled: bool, room_version_feature_flags: Vec, @@ -138,6 +142,7 @@ impl PushRuleEvaluator { room_member_count, notification_power_levels, sender_power_level, + tags_by_user, related_events_flattened, related_event_match_enabled, room_version_feature_flags, @@ -340,6 +345,10 @@ impl PushRuleEvaluator { false } } + KnownCondition::RoomTag { tag } => match (user_id, tag) { + (Some(user_id), tag) => self.match_tag(user_id, tag)?, + _ => false, + }, KnownCondition::SenderNotificationPermission { key } => { if let Some(sender_power_level) = &self.sender_power_level { let required_level = self @@ -498,6 +507,15 @@ impl PushRuleEvaluator { Ok(matches) } + + /// Match if any of the room's tags for the given user exist. + fn match_tag(&self, user_id: &str, tag: &str) -> Result { + if let Some(tags) = self.tags_by_user.get(user_id) { + Ok(tags.contains(tag)) + } else { + Ok(false) + } + } } #[test] @@ -515,6 +533,7 @@ fn push_rule_evaluator() { Some(0), BTreeMap::new(), BTreeMap::new(), + BTreeMap::new(), true, vec![], true, @@ -547,6 +566,7 @@ fn test_requires_room_version_supports_condition() { Some(0), BTreeMap::new(), BTreeMap::new(), + BTreeMap::new(), false, flags, true, diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index a56c7d36f5..9803e3523c 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -347,6 +347,10 @@ pub enum KnownCondition { #[serde(skip_serializing_if = "Option::is_none")] is: Option>, }, + #[serde(rename = "org.matrix.msc3964.room_tag")] + RoomTag { + tag: Cow<'static, str>, + }, SenderNotificationPermission { key: Cow<'static, str>, }, diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index a8f0ed2435..7bf15b925b 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -12,7 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union +from typing import ( + AbstractSet, + Any, + Collection, + Dict, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, +) from synapse.types import JsonDict, JsonValue @@ -62,6 +73,7 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], + tags_by_user: Mapping[str, AbstractSet[str]], related_events_flattened: Mapping[str, Mapping[str, JsonValue]], related_event_match_enabled: bool, room_version_feature_flags: Tuple[str, ...], diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index bc38fae0b6..51f849f7bd 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -194,3 +194,8 @@ class ExperimentalConfig(Config): self.msc3966_exact_event_property_contains = experimental.get( "msc3966_exact_event_property_contains", False ) + + # MSC3964: Notifications for room tags. + self.msc3964_notifications_for_room_tags = experimental.get( + "msc3964_notifications_for_room_tags", False + ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 3c4a152d6b..cb3a2766bc 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -409,6 +409,12 @@ class BulkPushRuleEvaluator: filter(lambda item: isinstance(item, str), user_mentions_raw) ) + # Fetch the room's tags for each user. + if self.hs.config.experimental.msc3964_notifications_for_room_tags: + tags_by_user = await self.store.get_all_users_tags_for_room(event.room_id) + else: + tags_by_user = {} + evaluator = PushRuleEvaluator( _flatten_dict( event, @@ -419,6 +425,7 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, + tags_by_user, related_events, self._related_event_match_enabled, event.room_version.msc3931_push_features, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index c149a9eacb..663aeec32b 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast +from typing import AbstractSet, Any, Dict, Iterable, List, Mapping, Set, Tuple, cast from synapse.api.constants import AccountDataTypes from synapse.replication.tcp.streams import AccountDataStream @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) class TagsWorkerStore(AccountDataWorkerStore): - @cached() + @cached(iterable=True) async def get_tags_for_user( self, user_id: str ) -> Mapping[str, Mapping[str, JsonDict]]: @@ -55,6 +55,27 @@ class TagsWorkerStore(AccountDataWorkerStore): room_tags[row["tag"]] = db_to_json(row["content"]) return tags_by_room + @cached(iterable=True) + async def get_all_users_tags_for_room( + self, room_id: str + ) -> Mapping[str, AbstractSet[str]]: + """Get all the tags for a room. + + Args: + room_id: The room to get the tags for. + + Returns: + A mapping from user IDs to a list of room tags. + """ + + rows = await self.db_pool.simple_select_list( + "room_tags", {"room_id": room_id}, ["user_id", "tag"] + ) + tags_by_user: Dict[str, Set[str]] = {} + for row in rows: + tags_by_user.setdefault(row["user_id"], set()).add(row["tag"]) + return tags_by_user + async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, str, str]], int, bool]: diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 1d30e3c3e4..ef4e37406c 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Set, Union, cast +from typing import AbstractSet, Any, Dict, List, Optional, Set, Union, cast import frozendict @@ -149,6 +149,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): *, has_mentions: bool = False, user_mentions: Optional[Set[str]] = None, + tags_by_user: Optional[Dict[str, AbstractSet[str]]] = None, related_events: Optional[JsonDict] = None, ) -> PushRuleEvaluator: event = FrozenEvent( @@ -172,6 +173,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, cast(Dict[str, int], power_levels.get("notifications", {})), + tags_by_user or {}, {} if related_events is None else related_events, related_event_match_enabled=True, room_version_feature_flags=event.room_version.msc3931_push_features, @@ -844,6 +846,24 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): ) ) + def test_room_tags(self) -> None: + """Ensure that matching by room tag works.""" + condition = {"kind": "org.matrix.msc3964.room_tag", "tag": "foo"} + + # If the user has no tags it should not match. + evaluator = self._get_evaluator({}) + self.assertFalse(evaluator.matches(condition, "@user:test", "display_name")) + evaluator = self._get_evaluator({}, tags_by_user={"@user:test": set()}) + self.assertFalse(evaluator.matches(condition, "@user:test", "display_name")) + + # If the user has *other* tags it should not match. + evaluator = self._get_evaluator({}, tags_by_user={"@user:test": {"bar"}}) + self.assertFalse(evaluator.matches(condition, "@user:test", "display_name")) + + # If the user has at least the given tag it should match. + evaluator = self._get_evaluator({}, tags_by_user={"@user:test": {"foo", "bar"}}) + self.assertTrue(evaluator.matches(condition, "@user:test", "display_name")) + class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator"""