Merge branch 'develop' into matthew/brand-from-header

Matthew Hodgson 2016-06-03 12:14:18 +01:00
commit 8d740132f4
35 changed files with 691 additions and 207 deletions

View file

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""This module contains classes for authenticating the user."""
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json, SignatureVerifyException
@@ -42,13 +41,20 @@ AuthEventTypes = (


 class Auth(object):
+    """
+    FIXME: This class contains a mix of functions for authenticating users
+    of our client-server API and authenticating events added to room graphs.
+    """
     def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
+        # Docs for these currently live at
+        # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
+        # In addition, we have type == delete_pusher which grants access only to
+        # delete pushers.
         self._KNOWN_CAVEAT_PREFIXES = set([
             "gen = ",
             "guest = ",
@@ -525,7 +531,7 @@
         return default

     @defer.inlineCallbacks
-    def get_user_by_req(self, request, allow_guest=False):
+    def get_user_by_req(self, request, allow_guest=False, rights="access"):
         """ Get a registered user's ID.

         Args:
@@ -547,7 +553,7 @@
             )

             access_token = request.args["access_token"][0]
-            user_info = yield self.get_user_by_access_token(access_token)
+            user_info = yield self.get_user_by_access_token(access_token, rights)
             user = user_info["user"]
             token_id = user_info["token_id"]
             is_guest = user_info["is_guest"]
@@ -608,7 +614,7 @@
         defer.returnValue(user_id)

     @defer.inlineCallbacks
-    def get_user_by_access_token(self, token):
+    def get_user_by_access_token(self, token, rights="access"):
         """ Get a registered user's ID.

         Args:
@@ -619,7 +625,7 @@
             AuthError if no user by that token exists or the token is invalid.
         """
         try:
-            ret = yield self.get_user_from_macaroon(token)
+            ret = yield self.get_user_from_macaroon(token, rights)
         except AuthError:
             # TODO(daniel): Remove this fallback when all existing access tokens
             # have been re-issued as macaroons.
@@ -627,11 +633,11 @@
         defer.returnValue(ret)

     @defer.inlineCallbacks
-    def get_user_from_macaroon(self, macaroon_str):
+    def get_user_from_macaroon(self, macaroon_str, rights="access"):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
-            self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
+            self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token)

             user_prefix = "user_id = "
             user = None
@@ -654,6 +660,13 @@
                     "is_guest": True,
                     "token_id": None,
                 }
+            elif rights == "delete_pusher":
+                # We don't store these tokens in the database
+                ret = {
+                    "user": user,
+                    "is_guest": False,
+                    "token_id": None,
+                }
             else:
                 # This codepath exists so that we can actually return a
                 # token ID, because we use token IDs in place of device
@@ -685,7 +698,8 @@
         Args:
             macaroon(pymacaroons.Macaroon): The macaroon to validate
-            type_string(str): The kind of token this is (e.g. "access", "refresh")
+            type_string(str): The kind of token required (e.g. "access", "refresh",
+                "delete_pusher")
             verify_expiry(bool): Whether to verify whether the macaroon has expired.
                 This should really always be True, but no clients currently implement
                 token refresh, so we can't enforce expiry yet.
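
Taken together, the `rights` plumbing means one code path can validate both ordinary access tokens and the new single-purpose delete_pusher tokens: the caller names the `type = ...` caveat it will accept. A minimal sketch of that caveat check with pymacaroons (the key, location and caveat values are illustrative, not Synapse's actual configuration):

    import pymacaroons

    def mint_token(key, user_id, token_type):
        # Mirror the caveat layout the Auth class knows about.
        m = pymacaroons.Macaroon(location="example.org", identifier="key1", key=key)
        m.add_first_party_caveat("gen = 1")
        m.add_first_party_caveat("user_id = %s" % (user_id,))
        m.add_first_party_caveat("type = %s" % (token_type,))
        return m.serialize()

    def validate(serialized, key, rights):
        m = pymacaroons.Macaroon.deserialize(serialized)
        v = pymacaroons.Verifier()
        v.satisfy_exact("gen = 1")
        v.satisfy_general(lambda caveat: caveat.startswith("user_id = "))
        v.satisfy_exact("type = %s" % (rights,))  # "access" or "delete_pusher"
        return v.verify(m, key)  # raises on signature/caveat mismatch

    token = mint_token("a secret", "@alice:example.org", "delete_pusher")
    assert validate(token, "a secret", "delete_pusher")

A token minted with `type = delete_pusher` fails validation when `rights="access"`, which is exactly the isolation the pusher-removal endpoint below relies on.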

View file

@@ -21,6 +21,7 @@ from synapse.config._base import ConfigError
 from synapse.config.database import DatabaseConfig
 from synapse.config.logger import LoggingConfig
 from synapse.config.emailconfig import EmailConfig
+from synapse.config.key import KeyConfig
 from synapse.http.site import SynapseSite
 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
 from synapse.storage.roommember import RoomMemberStore
@@ -63,6 +64,26 @@ class SlaveConfig(DatabaseConfig):
         self.pid_file = self.abspath(config.get("pid_file"))
         self.public_baseurl = config["public_baseurl"]

+        # some things used by the auth handler but not actually used in the
+        # pusher codebase
+        self.bcrypt_rounds = None
+        self.ldap_enabled = None
+        self.ldap_server = None
+        self.ldap_port = None
+        self.ldap_tls = None
+        self.ldap_search_base = None
+        self.ldap_search_property = None
+        self.ldap_email_property = None
+        self.ldap_full_name_property = None
+
+        # We would otherwise try to use the registration shared secret as the
+        # macaroon shared secret if there was no macaroon_shared_secret, but
+        # that means pulling in RegistrationConfig too. We don't need to be
+        # backwards compatible in the pusher codebase so just make people set
+        # macaroon_shared_secret. We set this to None to prevent it referencing
+        # an undefined key.
+        self.registration_shared_secret = None
+
     def default_config(self, server_name, **kwargs):
         pid_file = self.abspath("pusher.pid")

         return """\
@@ -95,7 +116,7 @@ class SlaveConfig(DatabaseConfig):
     """ % locals()


-class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig):
+class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig):
     pass

View file

@@ -529,6 +529,11 @@ class AuthHandler(BaseHandler):
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         return macaroon.serialize()

+    def generate_delete_pusher_token(self, user_id):
+        macaroon = self._generate_base_macaroon(user_id)
+        macaroon.add_first_party_caveat("type = delete_pusher")
+        return macaroon.serialize()
+
     def validate_short_term_login_token_and_get_user_id(self, login_token):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(login_token)

View file

@@ -26,9 +26,9 @@ from synapse.types import (
     UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
 )
 from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute
+from synapse.util.async import concurrently_execute, run_on_reactor
 from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.logcontext import preserve_fn
 from synapse.visibility import filter_events_for_client

 from ._base import BaseHandler
@@ -908,13 +908,16 @@ class MessageHandler(BaseHandler):
                     "Failed to get destination from event %s", s.event_id
                 )

-        with PreserveLoggingContext():
-            # Don't block waiting on waking up all the listeners.
+        @defer.inlineCallbacks
+        def _notify():
+            yield run_on_reactor()
             self.notifier.on_new_room_event(
                 event, event_stream_id, max_stream_id,
                 extra_users=extra_users
             )

+        preserve_fn(_notify)()
+
         # If invite, remove room_state from unsigned before sending.
         event.unsigned.pop("invite_room_state", None)
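
The rewritten notification step is a fire-and-forget: preserve_fn keeps the logging context across the un-yielded call, and run_on_reactor lets the caller's stack unwind before the listeners wake. The shape of the pattern, with Synapse's helpers swapped for plain Twisted equivalents (illustrative only, not the real helpers):

    from twisted.internet import defer, reactor
    from twisted.internet.task import deferLater

    def notify_without_blocking(notifier, event):
        @defer.inlineCallbacks
        def _notify():
            # Yield control to the reactor first, then wake the listeners.
            yield deferLater(reactor, 0, lambda: None)
            notifier.on_new_room_event(event)

        # Deliberately not yielded on: the send path does not wait.
        _notify()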

View file

@@ -198,9 +198,8 @@ class SyncHandler(object):
     @defer.inlineCallbacks
     def push_rules_for_user(self, user):
         user_id = user.to_string()
-        rawrules = yield self.store.get_push_rules_for_user(user_id)
-        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
-        rules = format_push_rules_for_user(user, rawrules, enabled_map)
+        rules = yield self.store.get_push_rules_for_user(user_id)
+        rules = format_push_rules_for_user(user, rules)

         defer.returnValue(rules)

     @defer.inlineCallbacks

View file

@@ -33,11 +33,7 @@ from .metric import (

 logger = logging.getLogger(__name__)

-# We'll keep all the available metrics in a single toplevel dict, one shared
-# for the entire process. We don't currently support per-HomeServer instances
-# of metrics, because in practice any one python VM will host only one
-# HomeServer anyway. This makes a lot of implementation neater
-all_metrics = {}
+all_metrics = []


 class Metrics(object):
@@ -53,7 +49,7 @@
         metric = metric_class(full_name, *args, **kwargs)

-        all_metrics[full_name] = metric
+        all_metrics.append(metric)
         return metric

     def register_counter(self, *args, **kwargs):
@@ -84,12 +80,12 @@
     # TODO(paul): Internal hack
     update_resource_metrics()

-    for name in sorted(all_metrics.keys()):
+    for metric in all_metrics:
         try:
-            strs += all_metrics[name].render()
+            strs += metric.render()
         except Exception:
-            strs += ["# FAILED to render %s" % name]
-            logger.exception("Failed to render %s metric", name)
+            strs += ["# FAILED to render"]
+            logger.exception("Failed to render metric")

     strs.append("")  # to generate a final CRLF

View file

@@ -47,9 +47,6 @@
             for k, v in zip(self.labels, values)])
         )

-    def render(self):
-        return map_concat(self.render_item, sorted(self.counts.keys()))
-

 class CounterMetric(BaseMetric):
     """The simplest kind of metric; one that stores a monotonically-increasing
@@ -83,6 +80,9 @@ class CounterMetric(BaseMetric):
     def render_item(self, k):
         return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]

+    def render(self):
+        return map_concat(self.render_item, sorted(self.counts.keys()))
+

 class CallbackMetric(BaseMetric):
     """A metric that returns the numeric value returned by a callback whenever
@@ -126,30 +126,30 @@ class DistributionMetric(object):

 class CacheMetric(object):
-    """A combination of two CounterMetrics, one to count cache hits and one to
-    count a total, and a callback metric to yield the current size.
+    __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")

-    This metric generates standard metric name pairs, so that monitoring rules
-    can easily be applied to measure hit ratio."""
-
-    def __init__(self, name, size_callback, labels=[]):
+    def __init__(self, name, size_callback, cache_name):
         self.name = name
+        self.cache_name = cache_name

-        self.hits = CounterMetric(name + ":hits", labels=labels)
-        self.total = CounterMetric(name + ":total", labels=labels)
-        self.size = CallbackMetric(
-            name + ":size",
-            callback=size_callback,
-            labels=labels,
-        )
+        self.hits = 0
+        self.misses = 0
+        self.size_callback = size_callback

-    def inc_hits(self, *values):
-        self.hits.inc(*values)
-        self.total.inc(*values)
+    def inc_hits(self):
+        self.hits += 1

-    def inc_misses(self, *values):
-        self.total.inc(*values)
+    def inc_misses(self):
+        self.misses += 1

     def render(self):
-        return self.hits.render() + self.total.render() + self.size.render()
+        size = self.size_callback()
+        hits = self.hits
+        total = self.misses + self.hits
+
+        return [
+            """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
+            """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
+            """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
+        ]
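
The inlined render() emits Prometheus text-format samples directly. For a cache metric registered under "cache" with cache_name "get_user_filter", 3 hits, 1 miss and 42 entries, the output would look roughly like this (names and numbers illustrative):

    cache:hits{name="get_user_filter"} 3
    cache:total{name="get_user_filter"} 4
    cache:size{name="get_user_filter"} 42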

View file

@@ -40,7 +40,7 @@ class ActionGenerator:
     def handle_push_actions_for_event(self, event, context):
         with Measure(self.clock, "handle_push_actions_for_event"):
             bulk_evaluator = yield evaluator_for_event(
-                event, self.hs, self.store
+                event, self.hs, self.store, context.current_state
             )

             actions_by_user = yield bulk_evaluator.action_for_event_by_user(

View file

@@ -18,10 +18,9 @@ import ujson as json

 from twisted.internet import defer

-from .baserules import list_with_base_rules
 from .push_rule_evaluator import PushRuleEvaluatorForEvent

-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.visibility import filter_events_for_clients
@@ -38,62 +37,41 @@ def decode_rule_json(rule):
 @defer.inlineCallbacks
 def _get_rules(room_id, user_ids, store):
     rules_by_user = yield store.bulk_get_push_rules(user_ids)
-    rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)

     rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}

-    rules_by_user = {
-        uid: list_with_base_rules([
-            decode_rule_json(rule_list)
-            for rule_list in rules_by_user.get(uid, [])
-        ])
-        for uid in user_ids
-    }
-
-    # We apply the rules-enabled map here: bulk_get_push_rules doesn't
-    # fetch disabled rules, but this won't account for any server default
-    # rules the user has disabled, so we need to do this too.
-    for uid in user_ids:
-        user_enabled_map = rules_enabled_by_user.get(uid)
-        if not user_enabled_map:
-            continue
-
-        for i, rule in enumerate(rules_by_user[uid]):
-            rule_id = rule['rule_id']
-
-            if rule_id in user_enabled_map:
-                if rule.get('enabled', True) != bool(user_enabled_map[rule_id]):
-                    # Rules are cached across users.
-                    rule = dict(rule)
-                    rule['enabled'] = bool(user_enabled_map[rule_id])
-                    rules_by_user[uid][i] = rule
-
     defer.returnValue(rules_by_user)


 @defer.inlineCallbacks
-def evaluator_for_event(event, hs, store):
+def evaluator_for_event(event, hs, store, current_state):
     room_id = event.room_id

-    # users in the room who have pushers need to get push rules run because
-    # that's how their pushers work
-    users_with_pushers = yield store.get_users_with_pushers_in_room(room_id)
-
     # We also will want to generate notifs for other people in the room so
     # their unread counts are correct in the event stream, but to avoid
     # generating them for bot / AS users etc, we only do so for people who've
     # sent a read receipt into the room.

-    all_in_room = yield store.get_users_in_room(room_id)
-    all_in_room = set(all_in_room)
+    local_users_in_room = set(
+        e.state_key for e in current_state.values()
+        if e.type == EventTypes.Member and e.membership == Membership.JOIN
+        and hs.is_mine_id(e.state_key)
+    )

-    receipts = yield store.get_receipts_for_room(room_id, "m.read")
+    # users in the room who have pushers need to get push rules run because
+    # that's how their pushers work
+    if_users_with_pushers = yield store.get_if_users_have_pushers(
+        local_users_in_room
+    )
+    user_ids = set(
+        uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+    )
+
+    users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)

     # any users with pushers must be ours: they have pushers
-    user_ids = set(users_with_pushers)
-    for r in receipts:
-        if hs.is_mine_id(r['user_id']) and r['user_id'] in all_in_room:
-            user_ids.add(r['user_id'])
+    for uid in users_with_receipts:
+        if uid in local_users_in_room:
+            user_ids.add(uid)

     # if this event is an invite event, we may need to run rules for the user
     # who's been invited, otherwise they won't get told they've been invited
@@ -104,8 +82,6 @@ def evaluator_for_event(event, hs, store):
         if has_pusher:
             user_ids.add(invited_user)

-    user_ids = list(user_ids)
-
     rules_by_user = yield _get_rules(room_id, user_ids, store)

     defer.returnValue(BulkPushRuleEvaluator(
@@ -143,7 +119,10 @@ class BulkPushRuleEvaluator:
             self.store, user_tuples, [event], {event.event_id: current_state}
         )

-        room_members = yield self.store.get_users_in_room(self.room_id)
+        room_members = set(
+            e.state_key for e in current_state.values()
+            if e.type == EventTypes.Member and e.membership == Membership.JOIN
+        )

         evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
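
The recurring idiom in this file is deriving room membership from the current_state map (state-key tuple → state event) instead of a storage round-trip. A self-contained sketch of the same filter with stand-in types (the real events come from Synapse's state handler):

    from collections import namedtuple

    StateEvent = namedtuple("StateEvent", ["type", "state_key", "membership"])

    current_state = {
        ("m.room.member", "@alice:hs"): StateEvent("m.room.member", "@alice:hs", "join"),
        ("m.room.member", "@bob:elsewhere"): StateEvent("m.room.member", "@bob:elsewhere", "join"),
        ("m.room.name", ""): StateEvent("m.room.name", "", None),
    }

    def is_mine_id(user_id):            # stand-in for hs.is_mine_id
        return user_id.endswith(":hs")

    local_joined = set(
        e.state_key for e in current_state.values()
        if e.type == "m.room.member" and e.membership == "join"
        and is_mine_id(e.state_key)
    )
    # local_joined == {"@alice:hs"}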

View file

@@ -23,10 +23,7 @@ import copy
 import simplejson as json


-def format_push_rules_for_user(user, rawrules, enabled_map):
-    """Converts a list of raw rules and an enabled map into nested dictionaries
-    to match the Matrix client-server format for push rules"""
-
+def load_rules_for_user(user, rawrules, enabled_map):
     ruleslist = []
     for rawrule in rawrules:
         rule = dict(rawrule)
@@ -35,7 +32,26 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
         ruleslist.append(rule)

     # We're going to be mutating this a lot, so do a deep copy
-    ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
+    rules = list(list_with_base_rules(ruleslist))
+
+    for i, rule in enumerate(rules):
+        rule_id = rule['rule_id']
+        if rule_id in enabled_map:
+            if rule.get('enabled', True) != bool(enabled_map[rule_id]):
+                # Rules are cached across users.
+                rule = dict(rule)
+                rule['enabled'] = bool(enabled_map[rule_id])
+                rules[i] = rule
+
+    return rules
+
+
+def format_push_rules_for_user(user, ruleslist):
+    """Converts a list of raw rules and an enabled map into nested dictionaries
+    to match the Matrix client-server format for push rules"""
+
+    # We're going to be mutating this a lot, so do a deep copy
+    ruleslist = copy.deepcopy(ruleslist)

     rules = {'global': {}, 'device': {}}
@@ -60,9 +76,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
             template_rule = _rule_to_template(r)
             if template_rule:
-                if r['rule_id'] in enabled_map:
-                    template_rule['enabled'] = enabled_map[r['rule_id']]
-                elif 'enabled' in r:
+                if 'enabled' in r:
                     template_rule['enabled'] = r['enabled']
                 else:
                     template_rule['enabled'] = True
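
A worked example of the enabled-map overlay that load_rules_for_user now performs: a disabled server-default rule is replaced by a copied dict, because rule lists are cached across users and must never be mutated in place (illustrative data):

    rules = [
        {'rule_id': '.m.rule.master', 'default': True},   # no 'enabled' key: treated as enabled
        {'rule_id': 'custom.keyword', 'enabled': True},
    ]
    enabled_map = {'.m.rule.master': False}

    for i, rule in enumerate(rules):
        rule_id = rule['rule_id']
        if rule_id in enabled_map:
            if rule.get('enabled', True) != bool(enabled_map[rule_id]):
                rule = dict(rule)       # copy, don't mutate the shared list entry
                rule['enabled'] = bool(enabled_map[rule_id])
                rules[i] = rule

    # rules[0] == {'rule_id': '.m.rule.master', 'default': True, 'enabled': False}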

View file

@@ -279,5 +279,5 @@ class EmailPusher(object):
         logger.info("Sending notif email for user %r", self.user_id)

         yield self.mailer.send_notification_mail(
-            self.user_id, self.email, push_actions, reason
+            self.app_id, self.user_id, self.email, push_actions, reason
         )

View file

@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)

 MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
-    "in the %s room..."
+    "in the %(room)s room..."
 MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
 MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
 MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
@@ -81,6 +81,7 @@ class Mailer(object):
     def __init__(self, hs, app_name):
         self.hs = hs
         self.store = self.hs.get_datastore()
+        self.auth_handler = self.hs.get_auth_handler()
         self.state_handler = self.hs.get_state_handler()
         loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
         self.app_name = app_name
@@ -96,7 +97,8 @@
         )

     @defer.inlineCallbacks
-    def send_notification_mail(self, user_id, email_address, push_actions, reason):
+    def send_notification_mail(self, app_id, user_id, email_address,
+                               push_actions, reason):
         try:
             from_string = self.hs.config.email_notif_from % {
                 "app": self.app_name
@@ -167,7 +169,9 @@
         template_vars = {
             "user_display_name": user_display_name,
-            "unsubscribe_link": self.make_unsubscribe_link(),
+            "unsubscribe_link": self.make_unsubscribe_link(
+                user_id, app_id, email_address
+            ),
             "summary_text": summary_text,
             "app_name": self.app_name,
             "rooms": rooms,
@@ -433,9 +437,18 @@
             notif['room_id'], notif['event_id']
         )

-    def make_unsubscribe_link(self):
-        # XXX: matrix.to
-        return "https://vector.im/#/settings"
+    def make_unsubscribe_link(self, user_id, app_id, email_address):
+        params = {
+            "access_token": self.auth_handler.generate_delete_pusher_token(user_id),
+            "app_id": app_id,
+            "pushkey": email_address,
+        }
+
+        # XXX: make r0 once API is stable
+        return "%s_matrix/client/unstable/pushers/remove?%s" % (
+            self.hs.config.public_baseurl,
+            urllib.urlencode(params),
+        )

     def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
         if value[0:6] != "mxc://":
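
The unsubscribe link is just public_baseurl plus the unstable pushers/remove path with URL-encoded parameters. A sketch of the expected shape (Python 2 urllib, as this codebase used; values illustrative):

    import urllib

    public_baseurl = "https://matrix.example.com/"
    params = {
        "access_token": "MDAxY...",        # delete_pusher-scoped macaroon
        "app_id": "m.email",
        "pushkey": "alice@example.com",
    }
    link = "%s_matrix/client/unstable/pushers/remove?%s" % (
        public_baseurl, urllib.urlencode(params),
    )
    # => https://matrix.example.com/_matrix/client/unstable/pushers/remove?
    #    access_token=MDAxY...&app_id=m.email&pushkey=alice%40example.com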

View file

@@ -15,7 +15,10 @@

 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
 from synapse.storage.account_data import AccountDataStore
+from synapse.storage.tags import TagsStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache


 class SlavedAccountDataStore(BaseSlavedStore):
@@ -25,6 +28,14 @@ class SlavedAccountDataStore(BaseSlavedStore):
         self._account_data_id_gen = SlavedIdTracker(
             db_conn, "account_data_max_stream_id", "stream_id",
         )
+        self._account_data_stream_cache = StreamChangeCache(
+            "AccountDataAndTagsChangeCache",
+            self._account_data_id_gen.get_current_token(),
+        )
+
+    get_account_data_for_user = (
+        AccountDataStore.__dict__["get_account_data_for_user"]
+    )

     get_global_account_data_by_type_for_users = (
         AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
@@ -34,6 +45,16 @@ class SlavedAccountDataStore(BaseSlavedStore):
         AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
     )

+    get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
+
+    get_updated_tags = DataStore.get_updated_tags.__func__
+    get_updated_account_data_for_user = (
+        DataStore.get_updated_account_data_for_user.__func__
+    )
+
+    def get_max_account_data_stream_id(self):
+        return self._account_data_id_gen.get_current_token()
+
     def stream_positions(self):
         result = super(SlavedAccountDataStore, self).stream_positions()
         position = self._account_data_id_gen.get_current_token()
@@ -47,15 +68,33 @@ class SlavedAccountDataStore(BaseSlavedStore):
         if stream:
             self._account_data_id_gen.advance(int(stream["position"]))
             for row in stream["rows"]:
-                user_id, data_type = row[1:3]
+                position, user_id, data_type = row[:3]
                 self.get_global_account_data_by_type_for_user.invalidate(
                     (data_type, user_id,)
                 )
+                self.get_account_data_for_user.invalidate((user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    user_id, position
+                )

         stream = result.get("room_account_data")
         if stream:
             self._account_data_id_gen.advance(int(stream["position"]))
+            for row in stream["rows"]:
+                position, user_id = row[:2]
+                self.get_account_data_for_user.invalidate((user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    user_id, position
+                )

         stream = result.get("tag_account_data")
         if stream:
             self._account_data_id_gen.advance(int(stream["position"]))
+            for row in stream["rows"]:
+                position, user_id = row[:2]
+                self.get_tags_for_user.invalidate((user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    user_id, position
+                )

         return super(SlavedAccountDataStore, self).process_replication(result)
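
A note on the borrowing idiom used throughout these slaved stores: `SomeStore.__dict__["name"]` pulls a cached descriptor straight off the class, so the slave gets its own cache instance, while `.__func__` unwraps a Python 2 unbound method into a plain function that can be assigned onto another class. A toy sketch of the latter:

    class Master(object):
        def greet(self):
            return "hello from %s" % type(self).__name__

    class Slave(object):
        # Borrow the raw function; it binds to Slave instances at lookup time.
        greet = Master.greet.__func__

    print(Slave().greet())   # hello from Slave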

View file

@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage import DataStore
+from synapse.config.appservice import load_appservices
+
+
+class SlavedApplicationServiceStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
+        self.services_cache = load_appservices(
+            hs.config.server_name,
+            hs.config.app_service_config_files
+        )
+
+    get_app_service_by_token = DataStore.get_app_service_by_token.__func__
+    get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__

View file

@@ -23,6 +23,7 @@ from synapse.storage.roommember import RoomMemberStore
 from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.event_push_actions import EventPushActionsStore
 from synapse.storage.state import StateStore
+from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache

 import ujson as json
@@ -57,6 +58,9 @@ class SlavedEventStore(BaseSlavedStore):
             "EventsRoomStreamChangeCache", min_event_val,
             prefilled_cache=event_cache_prefill,
         )
+        self._membership_stream_cache = StreamChangeCache(
+            "MembershipStreamChangeCache", events_max,
+        )

         # Cached functions can't be accessed through a class instance so we need
         # to reach inside the __dict__ to extract them.
@@ -87,6 +91,9 @@ class SlavedEventStore(BaseSlavedStore):
     _get_state_group_from_group = (
         StateStore.__dict__["_get_state_group_from_group"]
     )
+    get_recent_event_ids_for_room = (
+        StreamStore.__dict__["get_recent_event_ids_for_room"]
+    )

     get_unread_push_actions_for_user_in_range = (
         DataStore.get_unread_push_actions_for_user_in_range.__func__
@@ -109,10 +116,16 @@ class SlavedEventStore(BaseSlavedStore):
         DataStore.get_room_events_stream_for_room.__func__
     )
     get_events_around = DataStore.get_events_around.__func__
+    get_state_for_event = DataStore.get_state_for_event.__func__
     get_state_for_events = DataStore.get_state_for_events.__func__
     get_state_groups = DataStore.get_state_groups.__func__
+    get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
+    get_room_events_stream_for_rooms = (
+        DataStore.get_room_events_stream_for_rooms.__func__
+    )
+    get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__

-    _set_before_and_after = DataStore._set_before_and_after
+    _set_before_and_after = staticmethod(DataStore._set_before_and_after)

     _get_events = DataStore._get_events.__func__
     _get_events_from_cache = DataStore._get_events_from_cache.__func__
@@ -220,9 +233,9 @@ class SlavedEventStore(BaseSlavedStore):
             self.get_rooms_for_user.invalidate((event.state_key,))
             # self.get_joined_hosts_for_room.invalidate((event.room_id,))
             self.get_users_in_room.invalidate((event.room_id,))
-            # self._membership_stream_cache.entity_has_changed(
-            #     event.state_key, event.internal_metadata.stream_ordering
-            # )
+            self._membership_stream_cache.entity_has_changed(
+                event.state_key, event.internal_metadata.stream_ordering
+            )
             self.get_invited_rooms_for_user.invalidate((event.state_key,))

         if not event.is_state():

View file

@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage.filtering import FilteringStore
+
+
+class SlavedFilteringStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedFilteringStore, self).__init__(db_conn, hs)
+
+    # Filters are immutable so this cache doesn't need to be expired
+    get_user_filter = FilteringStore.__dict__["get_user_filter"]

View file

@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from synapse.storage import DataStore
+
+
+class SlavedPresenceStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedPresenceStore, self).__init__(db_conn, hs)
+        self._presence_id_gen = SlavedIdTracker(
+            db_conn, "presence_stream", "stream_id",
+        )
+
+        self._presence_on_startup = self._get_active_presence(db_conn)
+
+        self.presence_stream_cache = StreamChangeCache(
+            "PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
+        )
+
+    _get_active_presence = DataStore._get_active_presence.__func__
+    take_presence_startup_info = DataStore.take_presence_startup_info.__func__
+    get_presence_for_users = DataStore.get_presence_for_users.__func__
+
+    def get_current_presence_token(self):
+        return self._presence_id_gen.get_current_token()
+
+    def stream_positions(self):
+        result = super(SlavedPresenceStore, self).stream_positions()
+        position = self._presence_id_gen.get_current_token()
+        result["presence"] = position
+        return result
+
+    def process_replication(self, result):
+        stream = result.get("presence")
+        if stream:
+            self._presence_id_gen.advance(int(stream["position"]))
+            for row in stream["rows"]:
+                position, user_id = row[:2]
+                self.presence_stream_cache.entity_has_changed(
+                    user_id, position
+                )
+
+        return super(SlavedPresenceStore, self).process_replication(result)

View file

@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .events import SlavedEventStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+from synapse.storage.push_rule import PushRuleStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class SlavedPushRuleStore(SlavedEventStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
+        self._push_rules_stream_id_gen = SlavedIdTracker(
+            db_conn, "push_rules_stream", "stream_id",
+        )
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache",
+            self._push_rules_stream_id_gen.get_current_token(),
+        )
+
+    get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
+    get_push_rules_enabled_for_user = (
+        PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
+    )
+    have_push_rules_changed_for_user = (
+        DataStore.have_push_rules_changed_for_user.__func__
+    )
+
+    def get_push_rules_stream_token(self):
+        return (
+            self._push_rules_stream_id_gen.get_current_token(),
+            self._stream_id_gen.get_current_token(),
+        )
+
+    def stream_positions(self):
+        result = super(SlavedPushRuleStore, self).stream_positions()
+        result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
+        return result
+
+    def process_replication(self, result):
+        stream = result.get("push_rules")
+        if stream:
+            for row in stream["rows"]:
+                position = row[0]
+                user_id = row[2]
+                self.get_push_rules_for_user.invalidate((user_id,))
+                self.get_push_rules_enabled_for_user.invalidate((user_id,))
+                self.push_rules_stream_cache.entity_has_changed(
+                    user_id, position
+                )
+
+            self._push_rules_stream_id_gen.advance(int(stream["position"]))
+
+        return super(SlavedPushRuleStore, self).process_replication(result)

View file

@@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker

 from synapse.storage import DataStore
 from synapse.storage.receipts import ReceiptsStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache

 # So, um, we want to borrow a load of functions intended for reading from
 # a DataStore, but we don't want to take functions that either write to the
@@ -37,11 +38,28 @@ class SlavedReceiptsStore(BaseSlavedStore):
             db_conn, "receipts_linearized", "stream_id"
         )

+        self._receipts_stream_cache = StreamChangeCache(
+            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+        )
+
     get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
+    get_linearized_receipts_for_room = (
+        ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
+    )
+    _get_linearized_receipts_for_rooms = (
+        ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
+    )
+    get_last_receipt_event_id_for_user = (
+        ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
+    )

     get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
     get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__

+    get_linearized_receipts_for_rooms = (
+        DataStore.get_linearized_receipts_for_rooms.__func__
+    )
+
     def stream_positions(self):
         result = super(SlavedReceiptsStore, self).stream_positions()
         result["receipts"] = self._receipts_id_gen.get_current_token()
@@ -52,10 +70,15 @@ class SlavedReceiptsStore(BaseSlavedStore):
         if stream:
             self._receipts_id_gen.advance(int(stream["position"]))
             for row in stream["rows"]:
-                room_id, receipt_type, user_id = row[1:4]
+                position, room_id, receipt_type, user_id = row[:4]
                 self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
+                self._receipts_stream_cache.entity_has_changed(room_id, position)

         return super(SlavedReceiptsStore, self).process_replication(result)

     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
+        self.get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self.get_last_receipt_event_id_for_user.invalidate(
+            (user_id, room_id, receipt_type)
+        )

View file

@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage import DataStore
+from synapse.storage.registration import RegistrationStore
+
+
+class SlavedRegistrationStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedRegistrationStore, self).__init__(db_conn, hs)
+
+    # TODO: use the cached version and invalidate deleted tokens
+    get_user_by_access_token = RegistrationStore.__dict__[
+        "get_user_by_access_token"
+    ].orig
+
+    _query_for_auth = DataStore._query_for_auth.__func__

View file

@@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
         # we build up the full structure and then decide which bits of it
         # to send which means doing unnecessary work sometimes but is
         # probably not going to make a whole lot of difference
-        rawrules = yield self.store.get_push_rules_for_user(user_id)
-
-        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
-
-        rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
+        rules = yield self.store.get_push_rules_for_user(user_id)
+        rules = format_push_rules_for_user(requester.user, rules)

         path = request.postpath[1:]

View file

@@ -17,7 +17,11 @@ from twisted.internet import defer

 from synapse.api.errors import SynapseError, Codes
 from synapse.push import PusherConfigException
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import (
+    parse_json_object_from_request, parse_string, RestServlet
+)
+from synapse.http.server import finish_request
+from synapse.api.errors import StoreError

 from .base import ClientV1RestServlet, client_path_patterns
@@ -136,6 +140,57 @@ class PushersSetRestServlet(ClientV1RestServlet):
         return 200, {}


+class PushersRemoveRestServlet(RestServlet):
+    """
+    To allow a pusher to be deleted by clicking a link (i.e. a GET request)
+    """
+    PATTERNS = client_path_patterns("/pushers/remove$")
+    SUCCESS_HTML = "<html><body>You have been unsubscribed</body></html>"
+
+    def __init__(self, hs):
+        super(RestServlet, self).__init__()
+        self.hs = hs
+        self.notifier = hs.get_notifier()
+        self.auth = hs.get_v1auth()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
+        user = requester.user
+
+        app_id = parse_string(request, "app_id", required=True)
+        pushkey = parse_string(request, "pushkey", required=True)
+
+        pusher_pool = self.hs.get_pusherpool()
+
+        try:
+            yield pusher_pool.remove_pusher(
+                app_id=app_id,
+                pushkey=pushkey,
+                user_id=user.to_string(),
+            )
+        except StoreError as se:
+            if se.code != 404:
+                # This is fine: they're already unsubscribed
+                raise
+
+        self.notifier.on_new_replication_data()
+
+        request.setResponseCode(200)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Server", self.hs.version_string)
+        request.setHeader(b"Content-Length", b"%d" % (
+            len(PushersRemoveRestServlet.SUCCESS_HTML),
+        ))
+        request.write(PushersRemoveRestServlet.SUCCESS_HTML)
+        finish_request(request)
+        defer.returnValue(None)
+
+    def on_OPTIONS(self, _):
+        return 200, {}
+
+
 def register_servlets(hs, http_server):
     PushersRestServlet(hs).register(http_server)
     PushersSetRestServlet(hs).register(http_server)
+    PushersRemoveRestServlet(hs).register(http_server)
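
End to end: the email's unsubscribe link carries the delete_pusher macaroon as access_token, and a plain GET removes the pusher and renders the success page. A hedged sketch of exercising the endpoint (requests library, illustrative values):

    import requests

    resp = requests.get(
        "https://matrix.example.com/_matrix/client/unstable/pushers/remove",
        params={
            "access_token": "MDAxY...",   # from generate_delete_pusher_token
            "app_id": "m.email",
            "pushkey": "alice@example.com",
        },
    )
    print(resp.status_code)   # 200
    print(resp.text)          # <html><body>You have been unsubscribed</body></html>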

View file

@@ -149,7 +149,7 @@ class DataStore(RoomMemberStore, RoomStore,
             "AccountDataAndTagsChangeCache", account_max,
         )

-        self.__presence_on_startup = self._get_active_presence(db_conn)
+        self._presence_on_startup = self._get_active_presence(db_conn)

         presence_cache_prefill, min_presence_val = self._get_cache_dict(
             db_conn, "presence_stream",
@@ -190,8 +190,8 @@
         super(DataStore, self).__init__(hs)

     def take_presence_startup_info(self):
-        active_on_startup = self.__presence_on_startup
-        self.__presence_on_startup = None
+        active_on_startup = self._presence_on_startup
+        self._presence_on_startup = None
         return active_on_startup

     def _get_active_presence(self, db_conn):

View file

@@ -342,9 +342,6 @@ class EventsStore(SQLBaseStore):
         txn.call_after(self._get_current_state_for_key.invalidate_all)
         txn.call_after(self.get_rooms_for_user.invalidate_all)
         txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
-        txn.call_after(
-            self.get_users_with_pushers_in_room.invalidate, (event.room_id,)
-        )
         txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
         txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))

View file

@@ -15,6 +15,7 @@

 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.push.baserules import list_with_base_rules
 from twisted.internet import defer

 import logging
@@ -23,6 +24,29 @@ import simplejson as json
 logger = logging.getLogger(__name__)


+def _load_rules(rawrules, enabled_map):
+    ruleslist = []
+    for rawrule in rawrules:
+        rule = dict(rawrule)
+        rule["conditions"] = json.loads(rawrule["conditions"])
+        rule["actions"] = json.loads(rawrule["actions"])
+        ruleslist.append(rule)
+
+    # We're going to be mutating this a lot, so do a deep copy
+    rules = list(list_with_base_rules(ruleslist))
+
+    for i, rule in enumerate(rules):
+        rule_id = rule['rule_id']
+        if rule_id in enabled_map:
+            if rule.get('enabled', True) != bool(enabled_map[rule_id]):
+                # Rules are cached across users.
+                rule = dict(rule)
+                rule['enabled'] = bool(enabled_map[rule_id])
+                rules[i] = rule
+
+    return rules
+
+
 class PushRuleStore(SQLBaseStore):
     @cachedInlineCallbacks(lru=True)
     def get_push_rules_for_user(self, user_id):
@@ -42,7 +66,11 @@ class PushRuleStore(SQLBaseStore):
             key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
         )

-        defer.returnValue(rows)
+        enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+
+        rules = _load_rules(rows, enabled_map)
+
+        defer.returnValue(rules)

     @cachedInlineCallbacks(lru=True)
     def get_push_rules_enabled_for_user(self, user_id):
@@ -85,6 +113,14 @@ class PushRuleStore(SQLBaseStore):
         for row in rows:
             results.setdefault(row['user_name'], []).append(row)

+        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+
+        for user_id, rules in results.items():
+            results[user_id] = _load_rules(
+                rules, enabled_map_by_user.get(user_id, {})
+            )
+
         defer.returnValue(results)

     @cachedList(cached_method_name="get_push_rules_enabled_for_user",

View file

@@ -18,7 +18,7 @@ from twisted.internet import defer

 from canonicaljson import encode_canonical_json

-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList

 import logging
 import simplejson as json
@@ -135,19 +135,35 @@ class PusherStore(SQLBaseStore):
             "get_all_updated_pushers", get_all_updated_pushers_txn
         )

-    @cachedInlineCallbacks(num_args=1)
-    def get_users_with_pushers_in_room(self, room_id):
-        users = yield self.get_users_in_room(room_id)
-
+    @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
+    def get_if_user_has_pusher(self, user_id):
         result = yield self._simple_select_many_batch(
             table='pushers',
-            column='user_name',
-            iterable=users,
-            retcols=['user_name'],
-            desc='get_users_with_pushers_in_room'
+            keyvalues={
+                'user_name': 'user_id',
+            },
+            retcol='user_name',
+            desc='get_if_user_has_pusher',
+            allow_none=True,
         )

-        defer.returnValue([r['user_name'] for r in result])
+        defer.returnValue(bool(result))
+
+    @cachedList(cached_method_name="get_if_user_has_pusher",
+                list_name="user_ids", num_args=1, inlineCallbacks=True)
+    def get_if_users_have_pushers(self, user_ids):
+        rows = yield self._simple_select_many_batch(
+            table='pushers',
+            column='user_name',
+            iterable=user_ids,
+            retcols=['user_name'],
+            desc='get_if_users_have_pushers'
+        )
+
+        result = {user_id: False for user_id in user_ids}
+        result.update({r['user_name']: True for r in rows})
+
+        defer.returnValue(result)

 @@ -178,16 +194,16 @@ class PusherStore(SQLBaseStore):
                 },
             )

             if newly_inserted:
-                # get_users_with_pushers_in_room only cares if the user has
+                # get_if_user_has_pusher only cares if the user has
                 # at least *one* pusher.
-                txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
+                txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))

         yield self.runInteraction("add_pusher", f)

     @defer.inlineCallbacks
     def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
         def delete_pusher_txn(txn, stream_id):
-            txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
+            txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))

             self._simple_delete_one_txn(
                 txn,
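
get_if_users_have_pushers shows the cachedList contract: one batched query backs the per-user get_if_user_has_pusher cache, and the returned dict must contain every requested key so negative results are cacheable too. The dict-building step in isolation (illustrative data):

    user_ids = ["@alice:hs", "@bob:hs", "@carol:hs"]
    rows = [{"user_name": "@bob:hs"}]      # only users with a pusher row come back

    result = {user_id: False for user_id in user_ids}
    result.update({r["user_name"]: True for r in rows})
    # {'@alice:hs': False, '@bob:hs': True, '@carol:hs': False}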

View file

@@ -34,6 +34,26 @@ class ReceiptsStore(SQLBaseStore):
             "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
         )

+    @cachedInlineCallbacks()
+    def get_users_with_read_receipts_in_room(self, room_id):
+        receipts = yield self.get_receipts_for_room(room_id, "m.read")
+        defer.returnValue(set(r['user_id'] for r in receipts))
+
+    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+                                                    user_id):
+        if receipt_type != "m.read":
+            return
+
+        # Returns an ObservableDeferred
+        res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
+
+        if res and res.called and user_id in res.result:
+            # We'd only be adding to the set, so no point invalidating if the
+            # user is already there
+            return
+
+        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
         return self._simple_select_list(
@@ -228,6 +248,10 @@ class ReceiptsStore(SQLBaseStore):
         txn.call_after(
             self.get_receipts_for_room.invalidate, (room_id, receipt_type)
         )
+        txn.call_after(
+            self._invalidate_get_users_with_receipts_in_room,
+            room_id, receipt_type, user_id,
+        )
         txn.call_after(
             self.get_receipts_for_user.invalidate, (user_id, receipt_type)
         )
@@ -373,6 +397,10 @@ class ReceiptsStore(SQLBaseStore):
         txn.call_after(
             self.get_receipts_for_room.invalidate, (room_id, receipt_type)
         )
+        txn.call_after(
+            self._invalidate_get_users_with_receipts_in_room,
+            room_id, receipt_type, user_id,
+        )
         txn.call_after(
             self.get_receipts_for_user.invalidate, (user_id, receipt_type)
         )

View file

@@ -58,9 +58,6 @@ class RoomMemberStore(SQLBaseStore):
         txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
         txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
         txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
-        txn.call_after(
-            self.get_users_with_pushers_in_room.invalidate, (event.room_id,)
-        )
         txn.call_after(
             self._membership_stream_cache.entity_has_changed,
             event.state_key, event.internal_metadata.stream_ordering
@@ -241,23 +238,10 @@ class RoomMemberStore(SQLBaseStore):
         return results

-    @cached(max_entries=5000)
+    @cachedInlineCallbacks(max_entries=5000)
     def get_joined_hosts_for_room(self, room_id):
-        return self.runInteraction(
-            "get_joined_hosts_for_room",
-            self._get_joined_hosts_for_room_txn,
-            room_id,
-        )
-
-    def _get_joined_hosts_for_room_txn(self, txn, room_id):
-        rows = self._get_members_rows_txn(
-            txn,
-            room_id, membership=Membership.JOIN
-        )
-
-        joined_domains = set(get_domain_from_id(r["user_id"]) for r in rows)
-
-        return joined_domains
+        user_ids = yield self.get_users_in_room(room_id)
+        defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))

     def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
         rows = self._get_members_rows_txn(

View file

@@ -102,6 +102,15 @@ class ObservableDeferred(object):
     def observers(self):
         return self._observers

+    def has_called(self):
+        return self._result is not None
+
+    def has_succeeded(self):
+        return self._result is not None and self._result[0] is True
+
+    def get_result(self):
+        return self._result[1]
+
     def __getattr__(self, name):
         return getattr(self._deferred, name)
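Note: these accessors let callers inspect an already-completed ObservableDeferred synchronously instead of always paying for observe() plus a callback chain; they imply that _result is stored as a (success, value) pair when the wrapped Deferred fires. A minimal runnable sketch under that assumption (Observable is illustrative and ignores the failure path):

    from twisted.internet import defer

    class Observable(object):
        def __init__(self, d):
            self._result = None
            d.addCallback(lambda v: self._record(True, v))

        def _record(self, success, value):
            self._result = (success, value)
            return value

        def has_succeeded(self):
            return self._result is not None and self._result[0] is True

        def get_result(self):
            return self._result[1]

    obs = Observable(defer.succeed(42))
    if obs.has_succeeded():
        print(obs.get_result())  # fast path: 42, no extra Deferred created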

---- next file ----

@@ -24,12 +24,22 @@ DEBUG_CACHES = False
 metrics = synapse.metrics.get_metrics_for("synapse.util.caches")

 caches_by_name = {}
-cache_counter = metrics.register_cache(
-    "cache",
-    lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
-    labels=["name"],
-)
+# cache_counter = metrics.register_cache(
+#     "cache",
+#     lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
+#     labels=["name"],
+# )
+
+
+def register_cache(name, cache):
+    caches_by_name[name] = cache
+    return metrics.register_cache(
+        "cache",
+        lambda: len(cache),
+        name,
+    )

 _string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
 caches_by_name["string_cache"] = _string_cache
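Note: instead of one global cache_counter that walked every registered cache on each scrape, each cache now registers itself once and gets back its own metric handle, which the rest of this commit uses as self.metrics.inc_hits() / self.metrics.inc_misses(). Illustrative usage with a plain dict (pending_cache and compute are hypothetical names, not Synapse APIs):

    pending_cache = {}
    pending_metrics = register_cache("pending_cache", pending_cache)

    def lookup(key, compute):
        try:
            value = pending_cache[key]
            pending_metrics.inc_hits()    # counted against this cache only
        except KeyError:
            pending_metrics.inc_misses()
            value = pending_cache[key] = compute(key)
        return value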

---- next file ----

@@ -22,7 +22,7 @@ from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
 )

-from . import caches_by_name, DEBUG_CACHES, cache_counter
+from . import DEBUG_CACHES, register_cache

 from twisted.internet import defer
@@ -33,6 +33,7 @@ import functools
 import inspect
 import threading

 logger = logging.getLogger(__name__)
@@ -43,6 +44,15 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))

 class Cache(object):
+    __slots__ = (
+        "cache",
+        "max_entries",
+        "name",
+        "keylen",
+        "sequence",
+        "thread",
+        "metrics",
+    )

     def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
         if lru:
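Note: giving Cache a __slots__ declaration drops the per-instance __dict__ and fixes the attribute set, which both saves memory and turns typo'd attribute writes into immediate AttributeErrors. A quick standalone illustration (sizes vary by interpreter):

    import sys

    class WithDict(object):
        def __init__(self):
            self.a = 1

    class WithSlots(object):
        __slots__ = ("a",)

        def __init__(self):
            self.a = 1

    print(sys.getsizeof(WithDict().__dict__))  # slotless: a dict per instance
    print(sys.getsizeof(WithSlots()))          # slotted: no __dict__ at all
    # WithSlots().b = 2 would raise AttributeError: typos are caught early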
@@ -59,7 +69,7 @@ class Cache(object):
         self.keylen = keylen
         self.sequence = 0
         self.thread = None
-        caches_by_name[name] = self.cache
+        self.metrics = register_cache(name, self.cache)

     def check_thread(self):
         expected_thread = self.thread
@@ -74,10 +84,10 @@ class Cache(object):
     def get(self, key, default=_CacheSentinel):
         val = self.cache.get(key, _CacheSentinel)
         if val is not _CacheSentinel:
-            cache_counter.inc_hits(self.name)
+            self.metrics.inc_hits()
             return val

-        cache_counter.inc_misses(self.name)
+        self.metrics.inc_misses()

         if default is _CacheSentinel:
             raise KeyError()
@@ -293,16 +303,21 @@ class CacheListDescriptor(object):
             # cached is a dict arg -> deferred, where deferred results in a
             # 2-tuple (`arg`, `result`)
-            cached = {}
+            results = {}
+            cached_defers = {}
             missing = []
             for arg in list_args:
                 key = list(keyargs)
                 key[self.list_pos] = arg
                 try:
-                    res = cache.get(tuple(key)).observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-                    cached[arg] = res
+                    res = cache.get(tuple(key))
+                    if not res.has_succeeded():
+                        res = res.observe()
+                        res.addCallback(lambda r, arg: (arg, r), arg)
+                        cached_defers[arg] = res
+                    else:
+                        results[arg] = res.get_result()
                 except KeyError:
                     missing.append(arg)
@@ -340,12 +355,21 @@ class CacheListDescriptor(object):
                     res = observer.observe()
                     res.addCallback(lambda r, arg: (arg, r), arg)
-                    cached[arg] = res
+                    cached_defers[arg] = res

-            return preserve_context_over_deferred(defer.gatherResults(
-                cached.values(),
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
+            if cached_defers:
+                def update_results_dict(res):
+                    results.update(res)
+                    return results
+
+                return preserve_context_over_deferred(defer.gatherResults(
+                    cached_defers.values(),
+                    consumeErrors=True,
+                ).addCallback(update_results_dict).addErrback(
+                    unwrapFirstError
+                ))
+            else:
+                return results

         obj.__dict__[self.orig.__name__] = wrapped
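Note: the batch descriptor now splits cache hits into two buckets: entries whose ObservableDeferred has already succeeded go straight into results, and only genuinely pending entries are gathered; when nothing is pending the wrapped function can return a plain dict without touching the Deferred machinery at all. A self-contained sketch of that split (fetch_many, ready and pending are illustrative names, not the real descriptor API):

    from twisted.internet import defer

    def fetch_many(keys, ready, pending):
        # resolved entries bypass the Deferred machinery entirely
        results = {k: ready[k] for k in keys if k in ready}
        defers = [pending[k].addCallback(lambda v, k=k: (k, v))
                  for k in keys if k in pending]
        if not defers:
            return results  # synchronous fast path: plain dict
        d = defer.gatherResults(defers, consumeErrors=True)
        return d.addCallback(lambda pairs: dict(results, **dict(pairs)))

    # fetch_many(["a", "b"], ready={"a": 1}, pending={"b": defer.succeed(2)})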

---- next file ----

@@ -15,7 +15,7 @@
 from synapse.util.caches.lrucache import LruCache
 from collections import namedtuple
-from . import caches_by_name, cache_counter
+from . import register_cache
 import threading
 import logging
@@ -43,7 +43,7 @@ class DictionaryCache(object):
             __slots__ = []

         self.sentinel = Sentinel()
-        caches_by_name[name] = self.cache
+        self.metrics = register_cache(name, self.cache)

     def check_thread(self):
         expected_thread = self.thread
@@ -58,7 +58,7 @@ class DictionaryCache(object):
     def get(self, key, dict_keys=None):
         entry = self.cache.get(key, self.sentinel)
         if entry is not self.sentinel:
-            cache_counter.inc_hits(self.name)
+            self.metrics.inc_hits()

             if dict_keys is None:
                 return DictionaryEntry(entry.full, dict(entry.value))
@@ -69,7 +69,7 @@ class DictionaryCache(object):
                 if k in entry.value
             })

-        cache_counter.inc_misses(self.name)
+        self.metrics.inc_misses()
         return DictionaryEntry(False, {})

     def invalidate(self, key):

---- next file ----

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache

 import logging
@@ -49,7 +49,7 @@ class ExpiringCache(object):
         self._cache = {}

-        caches_by_name[cache_name] = self._cache
+        self.metrics = register_cache(cache_name, self._cache)

     def start(self):
         if not self._expiry_ms:
@@ -78,9 +78,9 @@ class ExpiringCache(object):
     def __getitem__(self, key):
         try:
             entry = self._cache[key]
-            cache_counter.inc_hits(self._cache_name)
+            self.metrics.inc_hits()
         except KeyError:
-            cache_counter.inc_misses(self._cache_name)
+            self.metrics.inc_misses()
             raise

         if self._reset_expiry_on_get:

---- next file ----

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache

 from blist import sorteddict
@@ -42,7 +42,7 @@ class StreamChangeCache(object):
         self._cache = sorteddict()
         self._earliest_known_stream_pos = current_stream_pos
         self.name = name
-        caches_by_name[self.name] = self._cache
+        self.metrics = register_cache(self.name, self._cache)

         for entity, stream_pos in prefilled_cache.items():
             self.entity_has_changed(entity, stream_pos)
@@ -53,19 +53,19 @@ class StreamChangeCache(object):
         assert type(stream_pos) is int

         if stream_pos < self._earliest_known_stream_pos:
-            cache_counter.inc_misses(self.name)
+            self.metrics.inc_misses()
             return True

         latest_entity_change_pos = self._entity_to_key.get(entity, None)
         if latest_entity_change_pos is None:
-            cache_counter.inc_hits(self.name)
+            self.metrics.inc_hits()
             return False

         if stream_pos < latest_entity_change_pos:
-            cache_counter.inc_misses(self.name)
+            self.metrics.inc_misses()
             return True

-        cache_counter.inc_hits(self.name)
+        self.metrics.inc_hits()
         return False

     def get_entities_changed(self, entities, stream_pos):
@@ -82,10 +82,10 @@ class StreamChangeCache(object):
                 self._cache[k] for k in keys[i:]
             ).intersection(entities)

-            cache_counter.inc_hits(self.name)
+            self.metrics.inc_hits()
         else:
             result = entities
-            cache_counter.inc_misses(self.name)
+            self.metrics.inc_misses()

         return result

---- next file ----

@@ -61,9 +61,6 @@ class CounterMetricTestCase(unittest.TestCase):
             'vector{method="PUT"} 1',
         ])

-        # Check that passing too few values errors
-        self.assertRaises(ValueError, counter.inc)
-

 class CallbackMetricTestCase(unittest.TestCase):
@@ -138,27 +135,27 @@ class CacheMetricTestCase(unittest.TestCase):
     def test_cache(self):
         d = dict()

-        metric = CacheMetric("cache", lambda: len(d))
+        metric = CacheMetric("cache", lambda: len(d), "cache_name")

         self.assertEquals(metric.render(), [
-            'cache:hits 0',
-            'cache:total 0',
-            'cache:size 0',
+            'cache:hits{name="cache_name"} 0',
+            'cache:total{name="cache_name"} 0',
+            'cache:size{name="cache_name"} 0',
         ])

         metric.inc_misses()
         d["key"] = "value"

         self.assertEquals(metric.render(), [
-            'cache:hits 0',
-            'cache:total 1',
-            'cache:size 1',
+            'cache:hits{name="cache_name"} 0',
+            'cache:total{name="cache_name"} 1',
+            'cache:size{name="cache_name"} 1',
         ])

         metric.inc_hits()

         self.assertEquals(metric.render(), [
-            'cache:hits 1',
-            'cache:total 2',
-            'cache:size 1',
+            'cache:hits{name="cache_name"} 1',
+            'cache:total{name="cache_name"} 2',
+            'cache:size{name="cache_name"} 1',
         ])
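Note: the updated test pins down the new render format: CacheMetric now takes the cache name as a third constructor argument and emits it as a Prometheus-style {name="..."} label, so every cache gets its own hits/total/size series. A simplified stand-in (not the real synapse.metrics class) that reproduces exactly the strings the test expects:

    class CacheMetric(object):
        def __init__(self, name, size_callback, cache_name):
            self.name = name
            self.size_callback = size_callback
            self.cache_name = cache_name
            self.hits = 0
            self.total = 0

        def inc_hits(self):
            self.hits += 1
            self.total += 1

        def inc_misses(self):
            self.total += 1

        def render(self):
            label = '{name="%s"}' % (self.cache_name,)
            return [
                "%s:hits%s %d" % (self.name, label, self.hits),
                "%s:total%s %d" % (self.name, label, self.total),
                "%s:size%s %d" % (self.name, label, self.size_callback()),
            ]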