Support operating on room tags for push rules.

This commit is contained in:
Patrick Cloke 2023-02-27 12:33:43 -05:00
parent e47d971ccb
commit 7fb013adea
8 changed files with 97 additions and 4 deletions

View file

@ -49,6 +49,7 @@ fn bench_match_exact(b: &mut Bencher) {
Some(0), Some(0),
Default::default(), Default::default(),
Default::default(), Default::default(),
Default::default(),
true, true,
vec![], vec![],
false, false,
@ -97,6 +98,7 @@ fn bench_match_word(b: &mut Bencher) {
Some(0), Some(0),
Default::default(), Default::default(),
Default::default(), Default::default(),
Default::default(),
true, true,
vec![], vec![],
false, false,
@ -145,6 +147,7 @@ fn bench_match_word_miss(b: &mut Bencher) {
Some(0), Some(0),
Default::default(), Default::default(),
Default::default(), Default::default(),
Default::default(),
true, true,
vec![], vec![],
false, false,
@ -193,6 +196,7 @@ fn bench_eval_message(b: &mut Bencher) {
Some(0), Some(0),
Default::default(), Default::default(),
Default::default(), Default::default(),
Default::default(),
true, true,
vec![], vec![],
false, false,

View file

@ -85,6 +85,9 @@ pub struct PushRuleEvaluator {
/// outlier. /// outlier.
sender_power_level: Option<i64>, sender_power_level: Option<i64>,
// User's tags for this event's room.
tags_by_user: BTreeMap<String, BTreeSet<String>>,
/// The related events, indexed by relation type. Flattened in the same manner as /// The related events, indexed by relation type. Flattened in the same manner as
/// `flattened_keys`. /// `flattened_keys`.
related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>, related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>,
@ -118,6 +121,7 @@ impl PushRuleEvaluator {
room_member_count: u64, room_member_count: u64,
sender_power_level: Option<i64>, sender_power_level: Option<i64>,
notification_power_levels: BTreeMap<String, i64>, notification_power_levels: BTreeMap<String, i64>,
tags_by_user: BTreeMap<String, BTreeSet<String>>,
related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>, related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>,
related_event_match_enabled: bool, related_event_match_enabled: bool,
room_version_feature_flags: Vec<String>, room_version_feature_flags: Vec<String>,
@ -138,6 +142,7 @@ impl PushRuleEvaluator {
room_member_count, room_member_count,
notification_power_levels, notification_power_levels,
sender_power_level, sender_power_level,
tags_by_user,
related_events_flattened, related_events_flattened,
related_event_match_enabled, related_event_match_enabled,
room_version_feature_flags, room_version_feature_flags,
@ -340,6 +345,10 @@ impl PushRuleEvaluator {
false false
} }
} }
KnownCondition::RoomTag { tag } => match (user_id, tag) {
(Some(user_id), tag) => self.match_tag(user_id, tag)?,
_ => false,
},
KnownCondition::SenderNotificationPermission { key } => { KnownCondition::SenderNotificationPermission { key } => {
if let Some(sender_power_level) = &self.sender_power_level { if let Some(sender_power_level) = &self.sender_power_level {
let required_level = self let required_level = self
@ -498,6 +507,15 @@ impl PushRuleEvaluator {
Ok(matches) Ok(matches)
} }
/// Match if any of the room's tags for the given user exist.
fn match_tag(&self, user_id: &str, tag: &str) -> Result<bool, Error> {
if let Some(tags) = self.tags_by_user.get(user_id) {
Ok(tags.contains(tag))
} else {
Ok(false)
}
}
} }
#[test] #[test]
@ -515,6 +533,7 @@ fn push_rule_evaluator() {
Some(0), Some(0),
BTreeMap::new(), BTreeMap::new(),
BTreeMap::new(), BTreeMap::new(),
BTreeMap::new(),
true, true,
vec![], vec![],
true, true,
@ -547,6 +566,7 @@ fn test_requires_room_version_supports_condition() {
Some(0), Some(0),
BTreeMap::new(), BTreeMap::new(),
BTreeMap::new(), BTreeMap::new(),
BTreeMap::new(),
false, false,
flags, flags,
true, true,

View file

@ -347,6 +347,10 @@ pub enum KnownCondition {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
is: Option<Cow<'static, str>>, is: Option<Cow<'static, str>>,
}, },
#[serde(rename = "org.matrix.msc3964.room_tag")]
RoomTag {
tag: Cow<'static, str>,
},
SenderNotificationPermission { SenderNotificationPermission {
key: Cow<'static, str>, key: Cow<'static, str>,
}, },

View file

@ -12,7 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union from typing import (
AbstractSet,
Any,
Collection,
Dict,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from synapse.types import JsonDict, JsonValue from synapse.types import JsonDict, JsonValue
@ -62,6 +73,7 @@ class PushRuleEvaluator:
room_member_count: int, room_member_count: int,
sender_power_level: Optional[int], sender_power_level: Optional[int],
notification_power_levels: Mapping[str, int], notification_power_levels: Mapping[str, int],
tags_by_user: Mapping[str, AbstractSet[str]],
related_events_flattened: Mapping[str, Mapping[str, JsonValue]], related_events_flattened: Mapping[str, Mapping[str, JsonValue]],
related_event_match_enabled: bool, related_event_match_enabled: bool,
room_version_feature_flags: Tuple[str, ...], room_version_feature_flags: Tuple[str, ...],

View file

@ -194,3 +194,8 @@ class ExperimentalConfig(Config):
self.msc3966_exact_event_property_contains = experimental.get( self.msc3966_exact_event_property_contains = experimental.get(
"msc3966_exact_event_property_contains", False "msc3966_exact_event_property_contains", False
) )
# MSC3964: Notifications for room tags.
self.msc3964_notifications_for_room_tags = experimental.get(
"msc3964_notifications_for_room_tags", False
)

View file

@ -409,6 +409,12 @@ class BulkPushRuleEvaluator:
filter(lambda item: isinstance(item, str), user_mentions_raw) filter(lambda item: isinstance(item, str), user_mentions_raw)
) )
# Fetch the room's tags for each user.
if self.hs.config.experimental.msc3964_notifications_for_room_tags:
tags_by_user = await self.store.get_all_users_tags_for_room(event.room_id)
else:
tags_by_user = {}
evaluator = PushRuleEvaluator( evaluator = PushRuleEvaluator(
_flatten_dict( _flatten_dict(
event, event,
@ -419,6 +425,7 @@ class BulkPushRuleEvaluator:
room_member_count, room_member_count,
sender_power_level, sender_power_level,
notification_levels, notification_levels,
tags_by_user,
related_events, related_events,
self._related_event_match_enabled, self._related_event_match_enabled,
event.room_version.msc3931_push_features, event.room_version.msc3931_push_features,

View file

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast from typing import AbstractSet, Any, Dict, Iterable, List, Mapping, Set, Tuple, cast
from synapse.api.constants import AccountDataTypes from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream from synapse.replication.tcp.streams import AccountDataStream
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore): class TagsWorkerStore(AccountDataWorkerStore):
@cached() @cached(iterable=True)
async def get_tags_for_user( async def get_tags_for_user(
self, user_id: str self, user_id: str
) -> Mapping[str, Mapping[str, JsonDict]]: ) -> Mapping[str, Mapping[str, JsonDict]]:
@ -55,6 +55,27 @@ class TagsWorkerStore(AccountDataWorkerStore):
room_tags[row["tag"]] = db_to_json(row["content"]) room_tags[row["tag"]] = db_to_json(row["content"])
return tags_by_room return tags_by_room
@cached(iterable=True)
async def get_all_users_tags_for_room(
self, room_id: str
) -> Mapping[str, AbstractSet[str]]:
"""Get all the tags for a room.
Args:
room_id: The room to get the tags for.
Returns:
A mapping from user IDs to a list of room tags.
"""
rows = await self.db_pool.simple_select_list(
"room_tags", {"room_id": room_id}, ["user_id", "tag"]
)
tags_by_user: Dict[str, Set[str]] = {}
for row in rows:
tags_by_user.setdefault(row["user_id"], set()).add(row["tag"])
return tags_by_user
async def get_all_updated_tags( async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, str, str]], int, bool]: ) -> Tuple[List[Tuple[int, str, str]], int, bool]:

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, List, Optional, Set, Union, cast from typing import AbstractSet, Any, Dict, List, Optional, Set, Union, cast
import frozendict import frozendict
@ -149,6 +149,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
*, *,
has_mentions: bool = False, has_mentions: bool = False,
user_mentions: Optional[Set[str]] = None, user_mentions: Optional[Set[str]] = None,
tags_by_user: Optional[Dict[str, AbstractSet[str]]] = None,
related_events: Optional[JsonDict] = None, related_events: Optional[JsonDict] = None,
) -> PushRuleEvaluator: ) -> PushRuleEvaluator:
event = FrozenEvent( event = FrozenEvent(
@ -172,6 +173,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
room_member_count, room_member_count,
sender_power_level, sender_power_level,
cast(Dict[str, int], power_levels.get("notifications", {})), cast(Dict[str, int], power_levels.get("notifications", {})),
tags_by_user or {},
{} if related_events is None else related_events, {} if related_events is None else related_events,
related_event_match_enabled=True, related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features, room_version_feature_flags=event.room_version.msc3931_push_features,
@ -844,6 +846,24 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
) )
) )
def test_room_tags(self) -> None:
"""Ensure that matching by room tag works."""
condition = {"kind": "org.matrix.msc3964.room_tag", "tag": "foo"}
# If the user has no tags it should not match.
evaluator = self._get_evaluator({})
self.assertFalse(evaluator.matches(condition, "@user:test", "display_name"))
evaluator = self._get_evaluator({}, tags_by_user={"@user:test": set()})
self.assertFalse(evaluator.matches(condition, "@user:test", "display_name"))
# If the user has *other* tags it should not match.
evaluator = self._get_evaluator({}, tags_by_user={"@user:test": {"bar"}})
self.assertFalse(evaluator.matches(condition, "@user:test", "display_name"))
# If the user has at least the given tag it should match.
evaluator = self._get_evaluator({}, tags_by_user={"@user:test": {"foo", "bar"}})
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"))
class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
"""Tests for the bulk push rule evaluator""" """Tests for the bulk push rule evaluator"""