Support operating on room tags for push rules.

This commit is contained in:
Patrick Cloke 2023-02-27 12:33:43 -05:00
parent e47d971ccb
commit 7fb013adea
8 changed files with 97 additions and 4 deletions

View file

@ -49,6 +49,7 @@ fn bench_match_exact(b: &mut Bencher) {
Some(0),
Default::default(),
Default::default(),
Default::default(),
true,
vec![],
false,
@ -97,6 +98,7 @@ fn bench_match_word(b: &mut Bencher) {
Some(0),
Default::default(),
Default::default(),
Default::default(),
true,
vec![],
false,
@ -145,6 +147,7 @@ fn bench_match_word_miss(b: &mut Bencher) {
Some(0),
Default::default(),
Default::default(),
Default::default(),
true,
vec![],
false,
@ -193,6 +196,7 @@ fn bench_eval_message(b: &mut Bencher) {
Some(0),
Default::default(),
Default::default(),
Default::default(),
true,
vec![],
false,

View file

@ -85,6 +85,9 @@ pub struct PushRuleEvaluator {
/// outlier.
sender_power_level: Option<i64>,
// User's tags for this event's room.
tags_by_user: BTreeMap<String, BTreeSet<String>>,
/// The related events, indexed by relation type. Flattened in the same manner as
/// `flattened_keys`.
related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>,
@ -118,6 +121,7 @@ impl PushRuleEvaluator {
room_member_count: u64,
sender_power_level: Option<i64>,
notification_power_levels: BTreeMap<String, i64>,
tags_by_user: BTreeMap<String, BTreeSet<String>>,
related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>,
related_event_match_enabled: bool,
room_version_feature_flags: Vec<String>,
@ -138,6 +142,7 @@ impl PushRuleEvaluator {
room_member_count,
notification_power_levels,
sender_power_level,
tags_by_user,
related_events_flattened,
related_event_match_enabled,
room_version_feature_flags,
@ -340,6 +345,10 @@ impl PushRuleEvaluator {
false
}
}
KnownCondition::RoomTag { tag } => match (user_id, tag) {
(Some(user_id), tag) => self.match_tag(user_id, tag)?,
_ => false,
},
KnownCondition::SenderNotificationPermission { key } => {
if let Some(sender_power_level) = &self.sender_power_level {
let required_level = self
@ -498,6 +507,15 @@ impl PushRuleEvaluator {
Ok(matches)
}
/// Match if any of the room's tags for the given user exist.
fn match_tag(&self, user_id: &str, tag: &str) -> Result<bool, Error> {
if let Some(tags) = self.tags_by_user.get(user_id) {
Ok(tags.contains(tag))
} else {
Ok(false)
}
}
}
#[test]
@ -515,6 +533,7 @@ fn push_rule_evaluator() {
Some(0),
BTreeMap::new(),
BTreeMap::new(),
BTreeMap::new(),
true,
vec![],
true,
@ -547,6 +566,7 @@ fn test_requires_room_version_supports_condition() {
Some(0),
BTreeMap::new(),
BTreeMap::new(),
BTreeMap::new(),
false,
flags,
true,

View file

@ -347,6 +347,10 @@ pub enum KnownCondition {
#[serde(skip_serializing_if = "Option::is_none")]
is: Option<Cow<'static, str>>,
},
#[serde(rename = "org.matrix.msc3964.room_tag")]
RoomTag {
tag: Cow<'static, str>,
},
SenderNotificationPermission {
key: Cow<'static, str>,
},

View file

@ -12,7 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union
from typing import (
AbstractSet,
Any,
Collection,
Dict,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from synapse.types import JsonDict, JsonValue
@ -62,6 +73,7 @@ class PushRuleEvaluator:
room_member_count: int,
sender_power_level: Optional[int],
notification_power_levels: Mapping[str, int],
tags_by_user: Mapping[str, AbstractSet[str]],
related_events_flattened: Mapping[str, Mapping[str, JsonValue]],
related_event_match_enabled: bool,
room_version_feature_flags: Tuple[str, ...],

View file

@ -194,3 +194,8 @@ class ExperimentalConfig(Config):
self.msc3966_exact_event_property_contains = experimental.get(
"msc3966_exact_event_property_contains", False
)
# MSC3964: Notifications for room tags.
self.msc3964_notifications_for_room_tags = experimental.get(
"msc3964_notifications_for_room_tags", False
)

View file

@ -409,6 +409,12 @@ class BulkPushRuleEvaluator:
filter(lambda item: isinstance(item, str), user_mentions_raw)
)
# Fetch the room's tags for each user.
if self.hs.config.experimental.msc3964_notifications_for_room_tags:
tags_by_user = await self.store.get_all_users_tags_for_room(event.room_id)
else:
tags_by_user = {}
evaluator = PushRuleEvaluator(
_flatten_dict(
event,
@ -419,6 +425,7 @@ class BulkPushRuleEvaluator:
room_member_count,
sender_power_level,
notification_levels,
tags_by_user,
related_events,
self._related_event_match_enabled,
event.room_version.msc3931_push_features,

View file

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast
from typing import AbstractSet, Any, Dict, Iterable, List, Mapping, Set, Tuple, cast
from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore):
@cached()
@cached(iterable=True)
async def get_tags_for_user(
self, user_id: str
) -> Mapping[str, Mapping[str, JsonDict]]:
@ -55,6 +55,27 @@ class TagsWorkerStore(AccountDataWorkerStore):
room_tags[row["tag"]] = db_to_json(row["content"])
return tags_by_room
@cached(iterable=True)
async def get_all_users_tags_for_room(
self, room_id: str
) -> Mapping[str, AbstractSet[str]]:
"""Get all the tags for a room.
Args:
room_id: The room to get the tags for.
Returns:
A mapping from user IDs to a list of room tags.
"""
rows = await self.db_pool.simple_select_list(
"room_tags", {"room_id": room_id}, ["user_id", "tag"]
)
tags_by_user: Dict[str, Set[str]] = {}
for row in rows:
tags_by_user.setdefault(row["user_id"], set()).add(row["tag"])
return tags_by_user
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, str, str]], int, bool]:

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Set, Union, cast
from typing import AbstractSet, Any, Dict, List, Optional, Set, Union, cast
import frozendict
@ -149,6 +149,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
*,
has_mentions: bool = False,
user_mentions: Optional[Set[str]] = None,
tags_by_user: Optional[Dict[str, AbstractSet[str]]] = None,
related_events: Optional[JsonDict] = None,
) -> PushRuleEvaluator:
event = FrozenEvent(
@ -172,6 +173,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
room_member_count,
sender_power_level,
cast(Dict[str, int], power_levels.get("notifications", {})),
tags_by_user or {},
{} if related_events is None else related_events,
related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features,
@ -844,6 +846,24 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
)
)
def test_room_tags(self) -> None:
"""Ensure that matching by room tag works."""
condition = {"kind": "org.matrix.msc3964.room_tag", "tag": "foo"}
# If the user has no tags it should not match.
evaluator = self._get_evaluator({})
self.assertFalse(evaluator.matches(condition, "@user:test", "display_name"))
evaluator = self._get_evaluator({}, tags_by_user={"@user:test": set()})
self.assertFalse(evaluator.matches(condition, "@user:test", "display_name"))
# If the user has *other* tags it should not match.
evaluator = self._get_evaluator({}, tags_by_user={"@user:test": {"bar"}})
self.assertFalse(evaluator.matches(condition, "@user:test", "display_name"))
# If the user has at least the given tag it should match.
evaluator = self._get_evaluator({}, tags_by_user={"@user:test": {"foo", "bar"}})
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"))
class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
"""Tests for the bulk push rule evaluator"""