Don't split at word boundaries, actually use regex

This commit is contained in:
Erik Johnston 2016-01-18 16:48:17 +00:00
parent d16dcf642e
commit 29c353c553
2 changed files with 46 additions and 61 deletions

View file

@ -81,7 +81,7 @@ class BulkPushRuleEvaluator:
users_dict.items(), [event]
)
evaluator = PushRuleEvaluatorForEvent.create(event, len(self.users_in_room))
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
condition_cache = {}

View file

@ -127,7 +127,7 @@ class PushRuleEvaluator:
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
evaluator = PushRuleEvaluatorForEvent.create(ev, room_member_count)
evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
for r in self.rules:
if self.enabled_map.get(r['rule_id'], None) is False:
@ -180,33 +180,13 @@ class PushRuleEvaluator:
class PushRuleEvaluatorForEvent(object):
WORD_BOUNDARY = re.compile(r'\b')
def __init__(self, event, body_parts, room_member_count):
def __init__(self, event, room_member_count):
self._event = event
# This is a list of words of the content.body (if event has one). Each
# word has been converted to lower case.
self._body_parts = body_parts
self._room_member_count = room_member_count
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
@staticmethod
def create(event, room_member_count):
body = event.get("content", {}).get("body", None)
if body:
body_parts = PushRuleEvaluatorForEvent.WORD_BOUNDARY.split(body)
body_parts[:] = [
part.lower() for part in body_parts
]
else:
body_parts = []
return PushRuleEvaluatorForEvent(event, body_parts, room_member_count)
def matches(self, condition, user_id, display_name, profile_tag):
if condition['kind'] == 'event_match':
return self._event_match(condition, user_id)
@ -239,67 +219,72 @@ class PushRuleEvaluatorForEvent(object):
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
matcher = _glob_to_matcher(pattern)
body = self._event["content"].get("body", None)
if not body:
return False
for part in self._body_parts:
if matcher(part):
return True
return False
return _glob_matches(pattern, body, word_boundary=True)
else:
haystack = self._get_value(condition['key'])
if haystack is None:
return False
matcher = _glob_to_matcher(pattern)
return matcher(haystack.lower())
return _glob_matches(pattern, haystack)
def _contains_display_name(self, display_name):
if not display_name:
return False
lower_display_name = display_name.lower()
for part in self._body_parts:
if part == lower_display_name:
return True
body = self._event["content"].get("body", None)
if not body:
return False
return False
return _glob_matches(display_name, body, word_boundary=True)
def _get_value(self, dotted_key):
return self._value_cache.get(dotted_key, None)
def _glob_to_matcher(glob):
"""Takes a glob and returns a `func(string) -> bool`, which returns if the
string matches the glob. Assumes given string is lower case.
def _glob_matches(glob, value, word_boundary=False):
"""Tests if value matches glob.
The matcher returned is either a simple string comparison for globs without
wildcards, or a regex matcher for globs with wildcards.
Args:
glob (string)
value (string): String to test against glob.
word_boundary (bool): Whether to match against word boundaries or entire
string. Defaults to False.
Returns:
bool
"""
glob = glob.lower()
if IS_GLOB.search(glob):
r = re.escape(glob)
if not IS_GLOB.search(glob):
return lambda value: value == glob
r = r.replace(r'\*', '.*?')
r = r.replace(r'\?', '.')
r = re.escape(glob)
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
'[%s%s]' % (
x.group(1) and '^' or '',
x.group(2).replace(r'\\\-', '-')
)
),
r,
)
r = r + "$"
r = re.compile(r, flags=re.IGNORECASE)
r = r.replace(r'\*', '.*?')
r = r.replace(r'\?', '.')
return r.match(value)
elif word_boundary:
r = re.escape(glob)
r = "\b%s\b" % (r,)
r = re.compile(r, flags=re.IGNORECASE)
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
'[%s%s]' % (
x.group(1) and '^' or '',
x.group(2).replace(r'\\\-', '-')
)
),
r,
)
r = r + "$"
r = re.compile(r)
return lambda value: r.match(value)
return r.search(value)
else:
return value.lower() == glob.lower()
def _flatten_dict(d, prefix=[], result={}):