Merge branch 'release-v0.18.0' of github.com:matrix-org/synapse

Erik Johnston 2016-09-19 17:20:25 +01:00
commit 88acb99747
80 changed files with 3093 additions and 781 deletions

View file

@ -1,3 +1,57 @@
Changes in synapse v0.18.0 (2016-09-19)
=======================================
This release includes major changes to the state storage database schemas,
which significantly reduce database size. Synapse will attempt to upgrade the
current data in the background. Servers with a large SQLite database may see
degraded performance while this upgrade is in progress; you may therefore want
to consider migrating to Postgres before upgrading very large SQLite
databases.

Changes:

* Make public room search case insensitive (PR #1127)

Bug fixes:

* Fix and clean up publicRooms pagination (PR #1129)

Changes in synapse v0.18.0-rc1 (2016-09-16)
===========================================
Features:

* Add ``only=highlight`` on ``/notifications`` (PR #1081)
* Add server param to /publicRooms (PR #1082)
* Allow clients to ask for the whole of a single state event (PR #1094)
* Add is_direct param to /createRoom (PR #1108)
* Add pagination support to publicRooms (PR #1121)
* Add very basic filter API to /publicRooms (PR #1126)
* Add basic direct to device messaging support for E2E (PR #1074, #1084, #1104,
  #1111)

Changes:

* Move to storing state_groups_state as deltas, greatly reducing DB size (PR
  #1065)
* Reduce amount of state pulled out of the DB during common requests (PR #1069)
* Allow PDF to be rendered from media repo (PR #1071)
* Reindex state_groups_state after pruning (PR #1085)
* Clobber EDUs in send queue (PR #1095)
* Conform better to the CAS protocol specification (PR #1100)
* Limit how often we ask for keys from dead servers (PR #1114)

Bug fixes:

* Fix /notifications API when used with ``from`` param (PR #1080)
* Fix backfill when cannot find an event. (PR #1107)

Changes in synapse v0.17.3 (2016-09-09)
=======================================
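As an illustration of the new room directory pagination, a client can now walk
/publicRooms with ``limit`` and ``since`` parameters. A minimal sketch
(hypothetical homeserver URL and access token; the endpoint and parameter
names are the ones added in this release):

    import requests

    # Illustrative values only.
    BASE = "https://matrix.example.com/_matrix/client/r0"
    TOKEN = "syt_example_token"

    def fetch_public_rooms():
        rooms, since = [], None
        while True:
            params = {"access_token": TOKEN, "limit": 100}
            if since:
                params["since"] = since
            page = requests.get(BASE + "/publicRooms", params=params).json()
            rooms.extend(page.get("chunk", []))
            since = page.get("next_batch")
            if not since:
                return rooms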

View file

@ -42,6 +42,7 @@ The current available worker applications are:
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
* synapse.app.client_reader - handles client API endpoints like /publicRooms

Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker,

View file

@ -20,3 +20,5 @@ export SYNAPSE_CACHE_FACTOR=1
    --pusher \
    --synchrotron \
    --federation-reader \
    --client-reader \
    --appservice \

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.17.3" __version__ = "0.18.0"

View file

@ -583,12 +583,15 @@ class Auth(object):
""" """
# Can optionally look elsewhere in the request (e.g. headers) # Can optionally look elsewhere in the request (e.g. headers)
try: try:
user_id = yield self._get_appservice_user_id(request.args) user_id = yield self._get_appservice_user_id(request)
if user_id: if user_id:
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue(synapse.types.create_requester(user_id)) defer.returnValue(synapse.types.create_requester(user_id))
access_token = request.args["access_token"][0] access_token = get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
user_info = yield self.get_user_by_access_token(access_token, rights) user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
@ -629,17 +632,19 @@ class Auth(object):
        )

    @defer.inlineCallbacks
-   def _get_appservice_user_id(self, request_args):
+   def _get_appservice_user_id(self, request):
        app_service = yield self.store.get_app_service_by_token(
-           request_args["access_token"][0]
+           get_access_token_from_request(
+               request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+           )
        )
        if app_service is None:
            defer.returnValue(None)

-       if "user_id" not in request_args:
+       if "user_id" not in request.args:
            defer.returnValue(app_service.sender)

-       user_id = request_args["user_id"][0]
+       user_id = request.args["user_id"][0]
        if app_service.sender == user_id:
            defer.returnValue(app_service.sender)
@ -833,7 +838,9 @@ class Auth(object):
    @defer.inlineCallbacks
    def get_appservice_by_req(self, request):
        try:
-           token = request.args["access_token"][0]
+           token = get_access_token_from_request(
+               request, self.TOKEN_NOT_FOUND_HTTP_STATUS
+           )
            service = yield self.store.get_app_service_by_token(token)
            if not service:
                logger.warn("Unrecognised appservice access token: %s" % (token,))
@ -1142,3 +1149,40 @@ class Auth(object):
"This server requires you to be a moderator in the room to" "This server requires you to be a moderator in the room to"
" edit its room list entry" " edit its room list entry"
) )
def has_access_token(request):
"""Checks if the request has an access_token.
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
return bool(query_params)
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
query_params = request.args.get("access_token")
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
return query_params[0]
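A usage sketch (hypothetical servlet code, not part of this commit): a legacy
endpoint that must keep returning 403 to old clients can pass the status
explicitly:

    # Raises AuthError(403, ...) instead of the default 401 when the
    # access_token query parameter is missing.
    token = get_access_token_from_request(
        request, token_not_found_http_status=403
    )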

View file

@ -187,6 +187,7 @@ def start(config_options):
    def start():
        ps.replicate()
        ps.get_datastore().start_profiling()
        ps.get_state_handler().start_caching()

    reactor.callWhenRunning(start)

View file

@ -0,0 +1,216 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import synapse

from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.crypto import context_factory

from twisted.internet import reactor, defer
from twisted.web.resource import Resource

from daemonize import Daemonize

import sys
import logging
import gc

logger = logging.getLogger("synapse.app.client_reader")


class ClientReaderSlavedStore(
    SlavedEventStore,
    SlavedKeyStore,
    RoomStore,
    DirectoryStore,
    SlavedApplicationServiceStore,
    SlavedRegistrationStore,
    BaseSlavedStore,
    ClientIpStore,  # After BaseSlavedStore because the constructor is different
):
    pass


class ClientReaderServer(HomeServer):
    def get_db_conn(self, run_new_connection=True):
        # Any param beginning with cp_ is a parameter for adbapi, and should
        # not be passed to the database engine.
        db_params = {
            k: v for k, v in self.db_config.get("args", {}).items()
            if not k.startswith("cp_")
        }
        db_conn = self.database_engine.module.connect(**db_params)

        if run_new_connection:
            self.database_engine.on_new_connection(db_conn)
        return db_conn

    def setup(self):
        logger.info("Setting up.")
        self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
        logger.info("Finished setting up.")

    def _listen_http(self, listener_config):
        port = listener_config["port"]
        bind_address = listener_config.get("bind_address", "")
        site_tag = listener_config.get("tag", port)
        resources = {}
        for res in listener_config["resources"]:
            for name in res["names"]:
                if name == "metrics":
                    resources[METRICS_PREFIX] = MetricsResource(self)
                elif name == "client":
                    resource = JsonResource(self, canonical_json=False)
                    PublicRoomListRestServlet(self).register(resource)
                    resources.update({
                        "/_matrix/client/r0": resource,
                        "/_matrix/client/unstable": resource,
                        "/_matrix/client/v2_alpha": resource,
                        "/_matrix/client/api/v1": resource,
                    })

        root_resource = create_resource_tree(resources, Resource())
        reactor.listenTCP(
            port,
            SynapseSite(
                "synapse.access.http.%s" % (site_tag,),
                site_tag,
                listener_config,
                root_resource,
            ),
            interface=bind_address
        )
        logger.info("Synapse client reader now listening on port %d", port)

    def start_listening(self, listeners):
        for listener in listeners:
            if listener["type"] == "http":
                self._listen_http(listener)
            elif listener["type"] == "manhole":
                reactor.listenTCP(
                    listener["port"],
                    manhole(
                        username="matrix",
                        password="rabbithole",
                        globals={"hs": self},
                    ),
                    interface=listener.get("bind_address", '127.0.0.1')
                )
            else:
                logger.warn("Unrecognized listener type: %s", listener["type"])

    @defer.inlineCallbacks
    def replicate(self):
        http_client = self.get_simple_http_client()
        store = self.get_datastore()
        replication_url = self.config.worker_replication_url

        while True:
            try:
                args = store.stream_positions()
                args["timeout"] = 30000
                result = yield http_client.get_json(replication_url, args=args)
                yield store.process_replication(result)
            except:
                logger.exception("Error replicating from %r", replication_url)
                yield sleep(5)


def start(config_options):
    try:
        config = HomeServerConfig.load_config(
            "Synapse client reader", config_options
        )
    except ConfigError as e:
        sys.stderr.write("\n" + e.message + "\n")
        sys.exit(1)

    assert config.worker_app == "synapse.app.client_reader"

    setup_logging(config.worker_log_config, config.worker_log_file)

    database_engine = create_engine(config.database_config)

    tls_server_context_factory = context_factory.ServerContextFactory(config)

    ss = ClientReaderServer(
        config.server_name,
        db_config=config.database_config,
        tls_server_context_factory=tls_server_context_factory,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    ss.setup()
    ss.get_handlers()
    ss.start_listening(config.worker_listeners)

    def run():
        with LoggingContext("run"):
            logger.info("Running")
            change_resource_limit(config.soft_file_limit)
            if config.gc_thresholds:
                gc.set_threshold(*config.gc_thresholds)
            reactor.run()

    def start():
        ss.get_state_handler().start_caching()
        ss.get_datastore().start_profiling()
        ss.replicate()

    reactor.callWhenRunning(start)

    if config.worker_daemonize:
        daemon = Daemonize(
            app="synapse-client-reader",
            pid=config.worker_pid_file,
            action=run,
            auto_close_fds=False,
            verbose=True,
            logger=logger,
        )
        daemon.start()
    else:
        run()


if __name__ == '__main__':
    with LoggingContext("main"):
        start(sys.argv[1:])
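Like the other worker entry points, the module can be run directly; an
illustrative invocation (the config file paths are examples only):

    python -m synapse.app.client_reader -c homeserver.yaml -c client_reader.yaml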

View file

@ -182,6 +182,7 @@ def start(config_options):
            reactor.run()

    def start():
        ss.get_state_handler().start_caching()
        ss.get_datastore().start_profiling()
        ss.replicate()

View file

@ -188,6 +188,7 @@ def start(config_options):
            reactor.run()

    def start():
        ss.get_state_handler().start_caching()
        ss.get_datastore().start_profiling()
        ss.replicate()

View file

@ -276,6 +276,7 @@ def start(config_options):
        ps.replicate()
        ps.get_pusherpool().start()
        ps.get_datastore().start_profiling()
        ps.get_state_handler().start_caching()

    reactor.callWhenRunning(start)

View file

@ -242,6 +242,9 @@ class SynchrotronTyping(object):
        self._room_typing = {}

    def stream_positions(self):
        # We must update this typing token from the response of the previous
        # sync. In particular, the stream id may "reset" back to zero/a low
        # value which we *must* use for the next replication request.
        return {"typing": self._latest_room_serial}

    def process_replication(self, result):
@ -462,6 +465,7 @@ def start(config_options):
    def start():
        ss.get_datastore().start_profiling()
        ss.replicate()
        ss.get_state_handler().start_caching()

    reactor.callWhenRunning(start)

View file

@ -32,6 +32,14 @@ HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable"


def _is_valid_3pe_metadata(info):
    if "instances" not in info:
        return False
    if not isinstance(info["instances"], list):
        return False
    return True


def _is_valid_3pe_result(r, field):
    if not isinstance(r, dict):
        return False
@ -162,11 +170,18 @@ class ApplicationServiceApi(SimpleHttpClient):
                urllib.quote(protocol)
            )
            try:
-               defer.returnValue((yield self.get_json(uri, {})))
+               info = yield self.get_json(uri, {})
+               if not _is_valid_3pe_metadata(info):
+                   logger.warning("query_3pe_protocol to %s did not return a"
+                                  " valid result", uri)
+                   defer.returnValue(None)
+               defer.returnValue(info)
            except Exception as ex:
                logger.warning("query_3pe_protocol to %s threw exception %s",
                               uri, ex)
-               defer.returnValue({})
+               defer.returnValue(None)

        key = (service.id, protocol)
        return self.protocol_meta_cache.get(key) or (
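For reference, a metadata response passes the new _is_valid_3pe_metadata check
as long as it carries an ``instances`` list; an illustrative example (field
values made up):

    valid_info = {
        "user_fields": ["network", "nickname"],
        "location_fields": ["network", "channel"],
        "icon": "mxc://example.org/abc123",
        "instances": [
            {"desc": "Example network", "fields": {"network": "example"}},
        ],
    }
    assert _is_valid_3pe_metadata(valid_info)
    assert not _is_valid_3pe_metadata({"instances": "not-a-list"})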

View file

@ -29,7 +29,6 @@ class ServerConfig(Config):
        self.user_agent_suffix = config.get("user_agent_suffix")
        self.use_frozen_dicts = config.get("use_frozen_dicts", False)
        self.public_baseurl = config.get("public_baseurl")
-       self.secondary_directory_servers = config.get("secondary_directory_servers", [])

        if self.public_baseurl is not None:
            if self.public_baseurl[-1] != '/':
@ -142,14 +141,6 @@ class ServerConfig(Config):
        # The GC threshold parameters to pass to `gc.set_threshold`, if defined
        # gc_thresholds: [700, 10, 10]

-       # A list of other Home Servers to fetch the public room directory from
-       # and include in the public room directory of this home server
-       # This is a temporary stopgap solution to populate new server with a
-       # list of rooms until there exists a good solution of a decentralized
-       # room directory.
-       # secondary_directory_servers:
-       #     - matrix.org

        # List of ports that Synapse should listen on, their purpose and their
        # configuration.
        listeners:

View file

@ -15,9 +15,30 @@
class EventContext(object):
    __slots__ = [
        "current_state_ids",
        "prev_state_ids",
        "state_group",
        "rejected",
        "push_actions",
        "prev_group",
        "delta_ids",
        "prev_state_events",
    ]

    def __init__(self):
        # The current state including the current event
        self.current_state_ids = None
        # The current state excluding the current event
        self.prev_state_ids = None
        self.state_group = None

        self.rejected = False
        self.push_actions = []

        # A previously persisted state group and a delta between that
        # and this state.
        self.prev_group = None
        self.delta_ids = None

        self.prev_state_events = None
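Conceptually, prev_group and delta_ids are what allow state groups to be
stored as deltas (PR #1065): the state at this context can be recovered from a
previously persisted group plus a small patch. A sketch of the relationship,
using a hypothetical helper not present in this commit:

    def state_ids_for(context, state_ids_of_group):
        # Fall back to the fully stored state when there is no parent group.
        if context.prev_group is None:
            return context.current_state_ids
        state = dict(state_ids_of_group(context.prev_group))
        state.update(context.delta_ids)  # (event_type, state_key) -> event_id
        return state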

View file

@ -24,7 +24,6 @@ from synapse.api.errors import (
    CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@ -122,8 +121,12 @@ class FederationClient(FederationBase):
            pdu.event_id
        )

    def send_presence(self, destination, states):
        if destination != self.server_name:
            self._transaction_queue.enqueue_presence(destination, states)

    @log_function
-   def send_edu(self, destination, edu_type, content):
+   def send_edu(self, destination, edu_type, content, key=None):
        edu = Edu(
            origin=self.server_name,
            destination=destination,
@ -134,9 +137,15 @@ class FederationClient(FederationBase):
        sent_edus_counter.inc()

        # TODO, add errback, etc.
-       self._transaction_queue.enqueue_edu(edu)
+       self._transaction_queue.enqueue_edu(edu, key=key)
        return defer.succeed(None)

    @log_function
    def send_device_messages(self, destination):
        """Sends the device messages in the local database to the remote
        destination"""
        self._transaction_queue.enqueue_device_messages(destination)

    @log_function
    def send_failure(self, failure, destination):
        self._transaction_queue.enqueue_failure(failure, destination)
@ -166,7 +175,7 @@ class FederationClient(FederationBase):
        )

    @log_function
-   def query_client_keys(self, destination, content):
+   def query_client_keys(self, destination, content, timeout):
        """Query device keys for a device hosted on a remote server.

        Args:
@ -178,10 +187,12 @@ class FederationClient(FederationBase):
            response
        """
        sent_queries_counter.inc("client_device_keys")
-       return self.transport_layer.query_client_keys(destination, content)
+       return self.transport_layer.query_client_keys(
+           destination, content, timeout
+       )

    @log_function
-   def claim_client_keys(self, destination, content):
+   def claim_client_keys(self, destination, content, timeout):
        """Claims one-time keys for a device hosted on a remote server.

        Args:
@ -193,7 +204,9 @@ class FederationClient(FederationBase):
            response
        """
        sent_queries_counter.inc("client_one_time_keys")
-       return self.transport_layer.claim_client_keys(destination, content)
+       return self.transport_layer.claim_client_keys(
+           destination, content, timeout
+       )

    @defer.inlineCallbacks
    @log_function
@ -471,7 +484,7 @@ class FederationClient(FederationBase):
            defer.DeferredList(deferreds, consumeErrors=True)
        )

        for success, result in res:
-           if success:
+           if success and result:
                signed_events.append(result)
                batch.discard(result.event_id)
@ -705,24 +718,14 @@ class FederationClient(FederationBase):
            raise RuntimeError("Failed to send to any server.")

-   @defer.inlineCallbacks
-   def get_public_rooms(self, destinations):
-       results_by_server = {}
-
-       @defer.inlineCallbacks
-       def _get_result(s):
-           if s == self.server_name:
-               defer.returnValue()
-
-           try:
-               result = yield self.transport_layer.get_public_rooms(s)
-               results_by_server[s] = result
-           except:
-               logger.exception("Error getting room list from server %r", s)
-
-       yield concurrently_execute(_get_result, destinations, 3)
-
-       defer.returnValue(results_by_server)
+   def get_public_rooms(self, destination, limit=None, since_token=None,
+                        search_filter=None):
+       if destination == self.server_name:
+           return
+
+       return self.transport_layer.get_public_rooms(
+           destination, limit, since_token, search_filter
+       )

    @defer.inlineCallbacks
    def query_auth(self, destination, room_id, event_id, local_auth):

View file

@ -188,7 +188,7 @@ class FederationServer(FederationBase):
                except SynapseError as e:
                    logger.info("Failed to handle edu %r: %r", edu_type, e)
                except Exception as e:
-                   logger.exception("Failed to handle edu %r", edu_type, e)
+                   logger.exception("Failed to handle edu %r", edu_type)
            else:
                logger.warn("Received EDU of type %s with no handler", edu_type)

View file

@ -17,7 +17,7 @@
from twisted.internet import defer

from .persistence import TransactionActions
-from .units import Transaction
+from .units import Transaction, Edu

from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
@ -26,6 +26,7 @@ from synapse.util.retryutils import (
    get_retry_limiter, NotRetryingDestination,
)
from synapse.util.metrics import measure_func
from synapse.handlers.presence import format_user_presence_state
import synapse.metrics

import logging
@ -69,18 +70,28 @@ class TransactionQueue(object):
        # destination -> list of tuple(edu, deferred)
        self.pending_edus_by_dest = edus = {}

        # Presence needs to be separate as we send single aggragate EDUs
        self.pending_presence_by_dest = presence = {}
        self.pending_edus_keyed_by_dest = edus_keyed = {}

        metrics.register_callback(
            "pending_pdus",
            lambda: sum(map(len, pdus.values())),
        )
        metrics.register_callback(
            "pending_edus",
-           lambda: sum(map(len, edus.values())),
+           lambda: (
+               sum(map(len, edus.values()))
+               + sum(map(len, presence.values()))
+               + sum(map(len, edus_keyed.values()))
+           ),
        )

        # destination -> list of tuple(failure, deferred)
        self.pending_failures_by_dest = {}

        self.last_device_stream_id_by_dest = {}

        # HACK to get unique tx id
        self._next_txn_id = int(self.clock.time_msec())
@ -128,12 +139,26 @@ class TransactionQueue(object):
            self._attempt_new_transaction, destination
        )

-   def enqueue_edu(self, edu):
+   def enqueue_presence(self, destination, states):
+       self.pending_presence_by_dest.setdefault(destination, {}).update({
+           state.user_id: state for state in states
+       })
+
+       preserve_context_over_fn(
+           self._attempt_new_transaction, destination
+       )
+
+   def enqueue_edu(self, edu, key=None):
        destination = edu.destination
        if not self.can_send_to(destination):
            return

+       if key:
+           self.pending_edus_keyed_by_dest.setdefault(
+               destination, {}
+           )[(edu.edu_type, key)] = edu
+       else:
            self.pending_edus_by_dest.setdefault(destination, []).append(edu)

        preserve_context_over_fn(
@ -155,10 +180,19 @@ class TransactionQueue(object):
            self._attempt_new_transaction, destination
        )

    def enqueue_device_messages(self, destination):
        if destination == self.server_name or destination == "localhost":
            return

        if not self.can_send_to(destination):
            return

        preserve_context_over_fn(
            self._attempt_new_transaction, destination
        )
    @defer.inlineCallbacks
    def _attempt_new_transaction(self, destination):
-       yield run_on_reactor()
-       while True:
        # list of (pending_pdu, deferred, order)
        if destination in self.pending_transactions:
            # XXX: pending_transactions can get stuck on by a never-ending
@ -171,39 +205,20 @@ class TransactionQueue(object):
            )
            return

-       pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-       pending_edus = self.pending_edus_by_dest.pop(destination, [])
-       pending_failures = self.pending_failures_by_dest.pop(destination, [])
-
-       if pending_pdus:
-           logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                        destination, len(pending_pdus))
-
-       if not pending_pdus and not pending_edus and not pending_failures:
-           logger.debug("TX [%s] Nothing to send", destination)
-           return
-
-       yield self._send_new_transaction(
-           destination, pending_pdus, pending_edus, pending_failures
-       )
-   @measure_func("_send_new_transaction")
-   @defer.inlineCallbacks
-   def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                             pending_failures):
-
-       # Sort based on the order field
-       pending_pdus.sort(key=lambda t: t[1])
-       pdus = [x[0] for x in pending_pdus]
-       edus = pending_edus
-       failures = [x.get_dict() for x in pending_failures]
        try:
            self.pending_transactions[destination] = 1

-           logger.debug("TX [%s] _attempt_new_transaction", destination)
-           txn_id = str(self._next_txn_id)
+           yield run_on_reactor()
+           while True:
+               pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+               pending_edus = self.pending_edus_by_dest.pop(destination, [])
+               pending_presence = self.pending_presence_by_dest.pop(destination, {})
+               pending_failures = self.pending_failures_by_dest.pop(destination, [])
+
+               pending_edus.extend(
+                   self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+               )

                limiter = yield get_retry_limiter(
                    destination,
@ -211,13 +226,101 @@ class TransactionQueue(object):
                    self.store,
                )

+               device_message_edus, device_stream_id = (
+                   yield self._get_new_device_messages(destination)
+               )
+
+               pending_edus.extend(device_message_edus)
+               if pending_presence:
+                   pending_edus.append(
+                       Edu(
+                           origin=self.server_name,
+                           destination=destination,
+                           edu_type="m.presence",
+                           content={
+                               "push": [
+                                   format_user_presence_state(
+                                       presence, self.clock.time_msec()
+                                   )
+                                   for presence in pending_presence.values()
+                               ]
+                           },
+                       )
+                   )
+
+               if pending_pdus:
+                   logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                                destination, len(pending_pdus))
+
+               if not pending_pdus and not pending_edus and not pending_failures:
+                   logger.debug("TX [%s] Nothing to send", destination)
+                   self.last_device_stream_id_by_dest[destination] = (
+                       device_stream_id
+                   )
+                   return
+
+               success = yield self._send_new_transaction(
+                   destination, pending_pdus, pending_edus, pending_failures,
+                   device_stream_id,
+                   should_delete_from_device_stream=bool(device_message_edus),
+                   limiter=limiter,
+               )
+               if not success:
+                   break
+       except NotRetryingDestination:
+           logger.info(
+               "TX [%s] not ready for retry yet - "
+               "dropping transaction for now",
+               destination,
+           )
+       finally:
+           # We want to be *very* sure we delete this after we stop processing
+           self.pending_transactions.pop(destination, None)
    @defer.inlineCallbacks
    def _get_new_device_messages(self, destination):
        last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
        to_device_stream_id = self.store.get_to_device_stream_token()
        contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
            destination, last_device_stream_id, to_device_stream_id
        )
        edus = [
            Edu(
                origin=self.server_name,
                destination=destination,
                edu_type="m.direct_to_device",
                content=content,
            )
            for content in contents
        ]
        defer.returnValue((edus, stream_id))

    @measure_func("_send_new_transaction")
    @defer.inlineCallbacks
    def _send_new_transaction(self, destination, pending_pdus, pending_edus,
                              pending_failures, device_stream_id,
                              should_delete_from_device_stream, limiter):

        # Sort based on the order field
        pending_pdus.sort(key=lambda t: t[1])
        pdus = [x[0] for x in pending_pdus]
        edus = pending_edus
        failures = [x.get_dict() for x in pending_failures]
+       success = True
        try:
+           logger.debug("TX [%s] _attempt_new_transaction", destination)
+
+           txn_id = str(self._next_txn_id)
+
            logger.debug(
                "TX [%s] {%s} Attempting new transaction"
                " (pdus: %d, edus: %d, failures: %d)",
                destination, txn_id,
-               len(pending_pdus),
-               len(pending_edus),
-               len(pending_failures)
+               len(pdus),
+               len(edus),
+               len(failures)
            )

            logger.debug("TX [%s] Persisting transaction...", destination)
@ -242,9 +345,9 @@ class TransactionQueue(object):
" (PDUs: %d, EDUs: %d, failures: %d)", " (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id, destination, txn_id,
transaction.transaction_id, transaction.transaction_id,
len(pending_pdus), len(pdus),
len(pending_edus), len(edus),
len(pending_failures), len(failures),
) )
with limiter: with limiter:
@ -299,12 +402,14 @@ class TransactionQueue(object):
                    logger.info(
                        "Failed to send event %s to %s", p.event_id, destination
                    )
-       except NotRetryingDestination:
-           logger.info(
-               "TX [%s] not ready for retry yet - "
-               "dropping transaction for now",
-               destination,
-           )
+               success = False
+           else:
+               # Remove the acknowledged device messages from the database
+               if should_delete_from_device_stream:
+                   yield self.store.delete_device_msgs_for_remote(
+                       destination, device_stream_id
+                   )
+               self.last_device_stream_id_by_dest[destination] = device_stream_id
        except RuntimeError as e:
            # We capture this here as there as nothing actually listens
            # for this finishing functions deferred.
@ -314,6 +419,8 @@ class TransactionQueue(object):
                e,
            )

+           success = False
+
            for p in pdus:
                logger.info("Failed to send event %s to %s", p.event_id, destination)
        except Exception as e:
@ -325,9 +432,9 @@ class TransactionQueue(object):
                e,
            )

+           success = False
+
            for p in pdus:
                logger.info("Failed to send event %s to %s", p.event_id, destination)

-       finally:
-           # We want to be *very* sure we delete this after we stop processing
-           self.pending_transactions.pop(destination, None)
+       defer.returnValue(success)
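To illustrate the keyed-EDU clobbering above (PR #1095): a caller that sends
frequently superseded EDUs, such as typing notifications, can pass a key so
that a newer EDU replaces any queued-but-unsent one for the same
(edu_type, key) pair. A hypothetical caller:

    # Only the latest typing EDU per room matters, so key on the room ID;
    # an older unsent EDU for this room is simply overwritten in the queue.
    self.federation.send_edu(
        destination="remote.example.com",
        edu_type="m.typing",
        content={"room_id": room_id, "user_id": user_id, "typing": True},
        key=room_id,
    )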

View file

@ -248,12 +248,22 @@ class TransportLayerClient(object):
    @defer.inlineCallbacks
    @log_function
-   def get_public_rooms(self, remote_server):
+   def get_public_rooms(self, remote_server, limit, since_token,
+                        search_filter=None):
        path = PREFIX + "/publicRooms"

+       args = {}
+       if limit:
+           args["limit"] = [str(limit)]
+       if since_token:
+           args["since"] = [since_token]
+
+       # TODO(erikj): Actually send the search_filter across federation.
+
        response = yield self.client.get_json(
            destination=remote_server,
            path=path,
+           args=args,
        )

        defer.returnValue(response)
@ -298,7 +308,7 @@ class TransportLayerClient(object):
    @defer.inlineCallbacks
    @log_function
-   def query_client_keys(self, destination, query_content):
+   def query_client_keys(self, destination, query_content, timeout):
        """Query the device keys for a list of user ids hosted on a remote
        server.
@ -327,12 +337,13 @@ class TransportLayerClient(object):
            destination=destination,
            path=path,
            data=query_content,
+           timeout=timeout,
        )
        defer.returnValue(content)

    @defer.inlineCallbacks
    @log_function
-   def claim_client_keys(self, destination, query_content):
+   def claim_client_keys(self, destination, query_content, timeout):
        """Claim one-time keys for a list of devices hosted on a remote server.

        Request:
@ -363,6 +374,7 @@ class TransportLayerClient(object):
            destination=destination,
            path=path,
            data=query_content,
+           timeout=timeout,
        )
        defer.returnValue(content)

View file

@ -18,7 +18,9 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import (
+    parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
+)
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
@ -554,7 +556,11 @@ class PublicRoomList(BaseFederationServlet):
    @defer.inlineCallbacks
    def on_GET(self, origin, content, query):
-       data = yield self.room_list_handler.get_local_public_room_list()
+       limit = parse_integer_from_args(query, "limit", 0)
+       since_token = parse_string_from_args(query, "since", None)
+       data = yield self.room_list_handler.get_local_public_room_list(
+           limit, since_token
+       )
        defer.returnValue((200, data))

View file

@ -176,12 +176,41 @@ class ApplicationServicesHandler(object):
        defer.returnValue(ret)

    @defer.inlineCallbacks
-   def get_3pe_protocols(self):
+   def get_3pe_protocols(self, only_protocol=None):
        services = yield self.store.get_app_services()
        protocols = {}

+       # Collect up all the individual protocol responses out of the ASes
        for s in services:
            for p in s.protocols:
-               protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
+               if only_protocol is not None and p != only_protocol:
+                   continue
+
+               if p not in protocols:
+                   protocols[p] = []
+
+               info = yield self.appservice_api.get_3pe_protocol(s, p)
+               if info is not None:
+                   protocols[p].append(info)
+
+       def _merge_instances(infos):
+           if not infos:
+               return {}
+
+           # Merge the 'instances' lists of multiple results, but just take
+           # the other fields from the first as they ought to be identical
+           # copy the result so as not to corrupt the cached one
+           combined = dict(infos[0])
+           combined["instances"] = list(combined["instances"])
+
+           for info in infos[1:]:
+               combined["instances"].extend(info["instances"])
+
+           return combined
+
+       for p in protocols.keys():
+           protocols[p] = _merge_instances(protocols[p])

        defer.returnValue(protocols)
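As a worked example of the merge (values made up): if two application services
expose the same protocol, their instances are concatenated while the remaining
fields come from the first response:

    infos = [
        {"icon": "mxc://a/1", "instances": [{"desc": "network A"}]},
        {"icon": "mxc://b/2", "instances": [{"desc": "network B"}]},
    ]
    # _merge_instances(infos) ==
    # {"icon": "mxc://a/1",
    #  "instances": [{"desc": "network A"}, {"desc": "network B"}]}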

View file

@ -58,7 +58,7 @@ class DeviceHandler(BaseHandler):
        attempts = 0
        while attempts < 5:
            try:
-               device_id = stringutils.random_string_with_symbols(16)
+               device_id = stringutils.random_string(10).upper()
                yield self.store.store_device(
                    user_id=user_id,
                    device_id=device_id,

View file

@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from twisted.internet import defer

from synapse.types import get_domain_from_id
from synapse.util.stringutils import random_string

logger = logging.getLogger(__name__)


class DeviceMessageHandler(object):
    def __init__(self, hs):
        """
        Args:
            hs (synapse.server.HomeServer): server
        """
        self.store = hs.get_datastore()
        self.notifier = hs.get_notifier()
        self.is_mine_id = hs.is_mine_id

        self.federation = hs.get_replication_layer()

        self.federation.register_edu_handler(
            "m.direct_to_device", self.on_direct_to_device_edu
        )

    @defer.inlineCallbacks
    def on_direct_to_device_edu(self, origin, content):
        local_messages = {}
        sender_user_id = content["sender"]
        if origin != get_domain_from_id(sender_user_id):
            logger.warn(
                "Dropping device message from %r with spoofed sender %r",
                origin, sender_user_id
            )
        message_type = content["type"]
        message_id = content["message_id"]
        for user_id, by_device in content["messages"].items():
            messages_by_device = {
                device_id: {
                    "content": message_content,
                    "type": message_type,
                    "sender": sender_user_id,
                }
                for device_id, message_content in by_device.items()
            }
            if messages_by_device:
                local_messages[user_id] = messages_by_device

        stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
            origin, message_id, local_messages
        )

        self.notifier.on_new_event(
            "to_device_key", stream_id, users=local_messages.keys()
        )

    @defer.inlineCallbacks
    def send_device_message(self, sender_user_id, message_type, messages):
        local_messages = {}
        remote_messages = {}
        for user_id, by_device in messages.items():
            if self.is_mine_id(user_id):
                messages_by_device = {
                    device_id: {
                        "content": message_content,
                        "type": message_type,
                        "sender": sender_user_id,
                    }
                    for device_id, message_content in by_device.items()
                }
                if messages_by_device:
                    local_messages[user_id] = messages_by_device
            else:
                destination = get_domain_from_id(user_id)
                remote_messages.setdefault(destination, {})[user_id] = by_device

        message_id = random_string(16)

        remote_edu_contents = {}
        for destination, messages in remote_messages.items():
            remote_edu_contents[destination] = {
                "messages": messages,
                "sender": sender_user_id,
                "type": message_type,
                "message_id": message_id,
            }

        stream_id = yield self.store.add_messages_to_device_inbox(
            local_messages, remote_edu_contents
        )

        self.notifier.on_new_event(
            "to_device_key", stream_id, users=local_messages.keys()
        )

        for destination in remote_messages.keys():
            # Enqueue a new federation transaction to send the new
            # device messages to each remote destination.
            self.federation.send_device_messages(destination)
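For reference, the messages argument is a user -> device -> content map; an
illustrative call (user IDs, device ID and event type are made up):

    yield device_message_handler.send_device_message(
        sender_user_id="@alice:example.com",
        message_type="m.example.ping",
        messages={
            "@bob:example.com": {
                "BOBDEVICE": {"body": "hello"},
            },
        },
    )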

View file

@ -13,14 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import collections
-import json
+import ujson as json
import logging

+from canonicaljson import encode_canonical_json
from twisted.internet import defer

-from synapse.api import errors
-import synapse.types
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination

logger = logging.getLogger(__name__)
@ -29,8 +31,9 @@ class E2eKeysHandler(object):
    def __init__(self, hs):
        self.store = hs.get_datastore()
        self.federation = hs.get_replication_layer()
+       self.device_handler = hs.get_device_handler()
        self.is_mine_id = hs.is_mine_id
-       self.server_name = hs.hostname
+       self.clock = hs.get_clock()

        # doesn't really work as part of the generic query API, because the
        # query request requires an object POST, but we abuse the
@ -40,7 +43,7 @@ class E2eKeysHandler(object):
        )

    @defer.inlineCallbacks
-   def query_devices(self, query_body):
+   def query_devices(self, query_body, timeout):
        """ Handle a device key query from a client

        {
@ -63,27 +66,60 @@ class E2eKeysHandler(object):
        # separate users by domain.
        # make a map from domain to user_id to device_ids
-       queries_by_domain = collections.defaultdict(dict)
-       for user_id, device_ids in device_keys_query.items():
-           user = synapse.types.UserID.from_string(user_id)
-           queries_by_domain[user.domain][user_id] = device_ids
-
-       # do the queries
-       # TODO: do these in parallel
-       results = {}
-       for destination, destination_query in queries_by_domain.items():
-           if destination == self.server_name:
-               res = yield self.query_local_devices(destination_query)
-           else:
-               res = yield self.federation.query_client_keys(
-                   destination, {"device_keys": destination_query}
-               )
-               res = res["device_keys"]
-           for user_id, keys in res.items():
-               if user_id in destination_query:
-                   results[user_id] = keys
-
-       defer.returnValue((200, {"device_keys": results}))
+       local_query = {}
+       remote_queries = {}
+       for user_id, device_ids in device_keys_query.items():
+           if self.is_mine_id(user_id):
+               local_query[user_id] = device_ids
+           else:
+               domain = get_domain_from_id(user_id)
+               remote_queries.setdefault(domain, {})[user_id] = device_ids
+
+       # do the queries
+       failures = {}
+       results = {}
+       if local_query:
+           local_result = yield self.query_local_devices(local_query)
+           for user_id, keys in local_result.items():
+               if user_id in local_query:
+                   results[user_id] = keys
+
+       @defer.inlineCallbacks
+       def do_remote_query(destination):
+           destination_query = remote_queries[destination]
+           try:
+               limiter = yield get_retry_limiter(
+                   destination, self.clock, self.store
+               )
+               with limiter:
+                   remote_result = yield self.federation.query_client_keys(
+                       destination,
+                       {"device_keys": destination_query},
+                       timeout=timeout
+                   )
+
+               for user_id, keys in remote_result["device_keys"].items():
+                   if user_id in destination_query:
+                       results[user_id] = keys
+           except CodeMessageException as e:
+               failures[destination] = {
+                   "status": e.code, "message": e.message
+               }
+           except NotRetryingDestination as e:
+               failures[destination] = {
+                   "status": 503, "message": "Not ready for retry",
+               }
+
+       yield preserve_context_over_deferred(defer.gatherResults([
+           preserve_fn(do_remote_query)(destination)
+           for destination in remote_queries
+       ]))
+
+       defer.returnValue({
+           "device_keys": results, "failures": failures,
+       })
    @defer.inlineCallbacks
    def query_local_devices(self, query):
@ -104,7 +140,7 @@ class E2eKeysHandler(object):
            if not self.is_mine_id(user_id):
                logger.warning("Request for keys for non-local user %s",
                               user_id)
-               raise errors.SynapseError(400, "Not a user here")
+               raise SynapseError(400, "Not a user here")

            if not device_ids:
                local_query.append((user_id, None))
@ -137,3 +173,107 @@ class E2eKeysHandler(object):
        device_keys_query = query_body.get("device_keys", {})
        res = yield self.query_local_devices(device_keys_query)
        defer.returnValue({"device_keys": res})

    @defer.inlineCallbacks
    def claim_one_time_keys(self, query, timeout):
        local_query = []
        remote_queries = {}

        for user_id, device_keys in query.get("one_time_keys", {}).items():
            if self.is_mine_id(user_id):
                for device_id, algorithm in device_keys.items():
                    local_query.append((user_id, device_id, algorithm))
            else:
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        @defer.inlineCallbacks
        def claim_client_keys(destination):
            device_keys = remote_queries[destination]
            try:
                limiter = yield get_retry_limiter(
                    destination, self.clock, self.store
                )
                with limiter:
                    remote_result = yield self.federation.claim_client_keys(
                        destination,
                        {"one_time_keys": device_keys},
                        timeout=timeout
                    )
                    for user_id, keys in remote_result["one_time_keys"].items():
                        if user_id in device_keys:
                            json_result[user_id] = keys
            except CodeMessageException as e:
                failures[destination] = {
                    "status": e.code, "message": e.message
                }
            except NotRetryingDestination as e:
                failures[destination] = {
                    "status": 503, "message": "Not ready for retry",
                }

        yield preserve_context_over_deferred(defer.gatherResults([
            preserve_fn(claim_client_keys)(destination)
            for destination in remote_queries
        ]))

        defer.returnValue({
            "one_time_keys": json_result,
            "failures": failures
        })

    @defer.inlineCallbacks
    def upload_keys_for_user(self, user_id, device_id, keys):
        time_now = self.clock.time_msec()

        # TODO: Validate the JSON to make sure it has the right keys.
        device_keys = keys.get("device_keys", None)
        if device_keys:
            logger.info(
                "Updating device_keys for device %r for user %s at %d",
                device_id, user_id, time_now
            )
            # TODO: Sign the JSON with the server key
            yield self.store.set_e2e_device_keys(
                user_id, device_id, time_now,
                encode_canonical_json(device_keys)
            )

        one_time_keys = keys.get("one_time_keys", None)
        if one_time_keys:
            logger.info(
                "Adding %d one_time_keys for device %r for user %r at %d",
                len(one_time_keys), device_id, user_id, time_now
            )
            key_list = []
            for key_id, key_json in one_time_keys.items():
                algorithm, key_id = key_id.split(":")
                key_list.append((
                    algorithm, key_id, encode_canonical_json(key_json)
                ))

            yield self.store.add_e2e_one_time_keys(
                user_id, device_id, time_now, key_list
            )

        # the device should have been registered already, but it may have been
        # deleted due to a race with a DELETE request. Or we may be using an
        # old access_token without an associated device_id. Either way, we
        # need to double-check the device is registered to avoid ending up with
        # keys without a corresponding device.
        self.device_handler.check_device_registered(user_id, device_id)

        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)

        defer.returnValue({"one_time_key_counts": result})

View file

@ -832,11 +832,13 @@ class FederationHandler(BaseHandler):
        new_pdu = event

-       message_handler = self.hs.get_handlers().message_handler
-       destinations = yield message_handler.get_joined_hosts_for_room_from_state(
-           context
-       )
-       destinations = set(destinations)
+       users_in_room = yield self.store.get_joined_users_from_context(event, context)
+
+       destinations = set(
+           get_domain_from_id(user_id) for user_id in users_in_room
+           if not self.hs.is_mine_id(user_id)
+       )
        destinations.discard(origin)

        logger.debug(
@ -1055,11 +1057,12 @@ class FederationHandler(BaseHandler):
        new_pdu = event

-       message_handler = self.hs.get_handlers().message_handler
-       destinations = yield message_handler.get_joined_hosts_for_room_from_state(
-           context
-       )
-       destinations = set(destinations)
+       users_in_room = yield self.store.get_joined_users_from_context(event, context)
+
+       destinations = set(
+           get_domain_from_id(user_id) for user_id in users_in_room
+           if not self.hs.is_mine_id(user_id)
+       )
        destinations.discard(origin)

        logger.debug(
@ -1582,10 +1585,12 @@ class FederationHandler(BaseHandler):
        current_state = set(e.event_id for e in auth_events.values())
        different_auth = event_auth_events - current_state

+       context.current_state_ids = dict(context.current_state_ids)
        context.current_state_ids.update({
            k: a.event_id for k, a in auth_events.items()
            if k != event_key
        })
+       context.prev_state_ids = dict(context.prev_state_ids)
        context.prev_state_ids.update({
            k: a.event_id for k, a in auth_events.items()
        })
@ -1667,10 +1672,12 @@ class FederationHandler(BaseHandler):
        # 4. Look at rejects and their proofs.
        # TODO.

+       context.current_state_ids = dict(context.current_state_ids)
        context.current_state_ids.update({
            k: a.event_id for k, a in auth_events.items()
            if k != event_key
        })
+       context.prev_state_ids = dict(context.prev_state_ids)
        context.prev_state_ids.update({
            k: a.event_id for k, a in auth_events.items()
        })

View file

@ -30,7 +30,6 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func
-from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.visibility import filter_events_for_client

from ._base import BaseHandler
@ -945,7 +944,12 @@ class MessageHandler(BaseHandler):
            event_stream_id, max_stream_id
        )

-       destinations = yield self.get_joined_hosts_for_room_from_state(context)
+       users_in_room = yield self.store.get_joined_users_from_context(event, context)
+       destinations = [
+           get_domain_from_id(user_id) for user_id in users_in_room
+           if not self.hs.is_mine_id(user_id)
+       ]

        @defer.inlineCallbacks
        def _notify():
@ -963,39 +967,3 @@ class MessageHandler(BaseHandler):
            preserve_fn(federation_handler.handle_new_event)(
                event, destinations=destinations,
            )
-   def get_joined_hosts_for_room_from_state(self, context):
-       state_group = context.state_group
-       if not state_group:
-           # If state_group is None it means it has yet to be assigned a
-           # state group, i.e. we need to make sure that calls with a state_group
-           # of None don't hit previous cached calls with a None state_group.
-           # To do this we set the state_group to a new object as object() != object()
-           state_group = object()
-
-       return self._get_joined_hosts_for_room_from_state(
-           state_group, context.current_state_ids
-       )
-
-   @cachedInlineCallbacks(num_args=1, cache_context=True)
-   def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
-                                             cache_context):
-
-       # Don't bother getting state for people on the same HS
-       current_state = yield self.store.get_events([
-           e_id for key, e_id in current_state_ids.items()
-           if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
-       ])
-
-       destinations = set()
-       for e in current_state.itervalues():
-           try:
-               if e.type == EventTypes.Member:
-                   if e.content["membership"] == Membership.JOIN:
-                       destinations.add(get_domain_from_id(e.state_key))
-           except SynapseError:
-               logger.warn(
-                   "Failed to get destination from event %s", e.event_id
-               )
-
-       defer.returnValue(destinations)

View file

@ -52,6 +52,11 @@ bump_active_time_counter = metrics.register_counter("bump_active_time")
get_updates_counter = metrics.register_counter("get_updates", labels=["type"])

notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"])
state_transition_counter = metrics.register_counter(
    "state_transition", labels=["from", "to"]
)


# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
# "currently_active"
@ -212,7 +217,7 @@ class PresenceHandler(object):
        is some spurious presence changes that will self-correct.
        """
        logger.info(
-           "Performing _on_shutdown. Persiting %d unpersisted changes",
+           "Performing _on_shutdown. Persisting %d unpersisted changes",
            len(self.user_to_current_state)
        )
@ -229,7 +234,7 @@ class PresenceHandler(object):
        may stack up and slow down shutdown times.
        """
        logger.info(
-           "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
+           "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
            len(self.unpersisted_users_changes)
        )
@ -260,6 +265,12 @@ class PresenceHandler(object):
            to_notify = {}  # Changes we want to notify everyone about
            to_federation_ping = {}  # These need sending keep-alives

+           # Only bother handling the last presence change for each user
+           new_states_dict = {}
+           for new_state in new_states:
+               new_states_dict[new_state.user_id] = new_state
+           new_state = new_states_dict.values()

            for new_state in new_states:
                user_id = new_state.user_id
@ -614,18 +625,8 @@ class PresenceHandler(object):
Args: Args:
hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
""" """
now = self.clock.time_msec()
for host, states in hosts_to_states.items(): for host, states in hosts_to_states.items():
self.federation.send_edu( self.federation.send_presence(host, states)
destination=host,
edu_type="m.presence",
content={
"push": [
_format_user_presence_state(state, now)
for state in states
]
}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def incoming_presence(self, origin, content): def incoming_presence(self, origin, content):
@ -646,6 +647,13 @@ class PresenceHandler(object):
) )
continue continue
if get_domain_from_id(user_id) != origin:
logger.info(
"Got presence update from %r with bad 'user_id': %r",
origin, user_id,
)
continue
presence_state = push.get("presence", None) presence_state = push.get("presence", None)
if not presence_state: if not presence_state:
logger.info( logger.info(
@ -705,13 +713,13 @@ class PresenceHandler(object):
defer.returnValue([ defer.returnValue([
{ {
"type": "m.presence", "type": "m.presence",
"content": _format_user_presence_state(state, now), "content": format_user_presence_state(state, now),
} }
for state in updates for state in updates
]) ])
else: else:
defer.returnValue([ defer.returnValue([
_format_user_presence_state(state, now) for state in updates format_user_presence_state(state, now) for state in updates
]) ])
@defer.inlineCallbacks @defer.inlineCallbacks
@ -939,33 +947,38 @@ class PresenceHandler(object):
def should_notify(old_state, new_state): def should_notify(old_state, new_state):
"""Decides if a presence state change should be sent to interested parties. """Decides if a presence state change should be sent to interested parties.
""" """
if old_state == new_state:
return False
if old_state.status_msg != new_state.status_msg: if old_state.status_msg != new_state.status_msg:
notify_reason_counter.inc("status_msg_change")
return True
if old_state.state != new_state.state:
notify_reason_counter.inc("state_change")
state_transition_counter.inc(old_state.state, new_state.state)
return True return True
if old_state.state == PresenceState.ONLINE: if old_state.state == PresenceState.ONLINE:
if new_state.state != PresenceState.ONLINE:
# Always notify for online -> anything
return True
if new_state.currently_active != old_state.currently_active: if new_state.currently_active != old_state.currently_active:
notify_reason_counter.inc("current_active_change")
return True return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Only notify about last active bumps if we're not currently active # Only notify about last active bumps if we're not currently active
if not (old_state.currently_active and new_state.currently_active): if not new_state.currently_active:
notify_reason_counter.inc("last_active_change_online")
return True return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped. # Always notify for a transition where last active gets bumped.
return True notify_reason_counter.inc("last_active_change_not_online")
if old_state.state != new_state.state:
return True return True
return False return False
def _format_user_presence_state(state, now): def format_user_presence_state(state, now):
"""Convert UserPresenceState to a format that can be sent down to clients """Convert UserPresenceState to a format that can be sent down to clients
and to other servers. and to other servers.
""" """
@ -1078,7 +1091,7 @@ class PresenceEventSource(object):
defer.returnValue(([ defer.returnValue(([
{ {
"type": "m.presence", "type": "m.presence",
"content": _format_user_presence_state(s, now), "content": format_user_presence_state(s, now),
} }
for s in updates.values() for s in updates.values()
if include_offline or s.state != PresenceState.OFFLINE if include_offline or s.state != PresenceState.OFFLINE
View file
@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler):
} }
}, },
}, },
key=(room_id, receipt_type, user_id),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
View file
@ -20,12 +20,10 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, JoinRules, RoomCreationPreset, Membership, EventTypes, JoinRules, RoomCreationPreset
) )
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from collections import OrderedDict from collections import OrderedDict
@ -36,8 +34,6 @@ import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
id_server_scheme = "https://" id_server_scheme = "https://"
@ -196,6 +192,11 @@ class RoomCreationHandler(BaseHandler):
}, },
ratelimit=False) ratelimit=False)
content = {}
is_direct = config.get("is_direct", None)
if is_direct:
content["is_direct"] = is_direct
for invitee in invite_list: for invitee in invite_list:
yield room_member_handler.update_membership( yield room_member_handler.update_membership(
requester, requester,
@ -203,6 +204,7 @@ class RoomCreationHandler(BaseHandler):
room_id, room_id,
"invite", "invite",
ratelimit=False, ratelimit=False,
content=content,
) )
for invite_3pid in invite_3pid_list: for invite_3pid in invite_3pid_list:
@ -342,149 +344,6 @@ class RoomCreationHandler(BaseHandler):
) )
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache(hs)
self.remote_list_request_cache = ResponseCache(hs)
self.remote_list_cache = {}
self.fetch_looping_call = hs.get_clock().looping_call(
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
)
self.fetch_all_remote_lists()
def get_local_public_room_list(self):
result = self.response_cache.get(())
if not result:
result = self.response_cache.set((), self._get_public_room_list())
return result
@defer.inlineCallbacks
def _get_public_room_list(self):
room_ids = yield self.store.get_public_room_ids()
results = []
@defer.inlineCallbacks
def handle_room(room_id):
current_state = yield self.state_handler.get_current_state(room_id)
# Double check that this is actually a public room.
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
result = {"room_id": room_id}
num_joined_users = len([
1 for _, event in current_state.items()
if event.type == EventTypes.Member
and event.membership == Membership.JOIN
])
if num_joined_users == 0:
return
result["num_joined_members"] = num_joined_users
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
result["aliases"] = aliases
name_event = current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
topic_event = current_state.get((EventTypes.Topic, ""))
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
result["avatar_url"] = avatar_url
results.append(result)
yield concurrently_execute(handle_room, room_ids, 10)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": results})
@defer.inlineCallbacks
def fetch_all_remote_lists(self):
deferred = self.hs.get_replication_layer().get_public_rooms(
self.hs.config.secondary_directory_servers
)
self.remote_list_request_cache.set((), deferred)
self.remote_list_cache = yield deferred
@defer.inlineCallbacks
def get_aggregated_public_room_list(self):
"""
Get the public room list from this server and the servers
specified in the secondary_directory_servers config option.
XXX: Pagination...
"""
# We return the results from our cache, which is updated by a looping call,
# unless we're missing a cache entry, in which case wait for the result
# of the fetch if there's one in progress. If not, omit that server.
wait = False
for s in self.hs.config.secondary_directory_servers:
if s not in self.remote_list_cache:
logger.warn("No cached room list from %s: waiting for fetch", s)
wait = True
break
if wait and self.remote_list_request_cache.get(()):
yield self.remote_list_request_cache.get(())
public_rooms = yield self.get_local_public_room_list()
# keep track of which room IDs we've seen so we can de-dup
room_ids = set()
# tag all the ones in our list with our server name.
# Also add them to the de-duping set
for room in public_rooms['chunk']:
room["server_name"] = self.hs.hostname
room_ids.add(room["room_id"])
# Now add the results from federation
for server_name, server_result in self.remote_list_cache.items():
for room in server_result["chunk"]:
if room["room_id"] not in room_ids:
room["server_name"] = server_name
public_rooms["chunk"].append(room)
room_ids.add(room["room_id"])
defer.returnValue(public_rooms)
class RoomContextHandler(BaseHandler): class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, is_guest): def get_event_context(self, user, room_id, event_id, limit, is_guest):
View file
@ -0,0 +1,400 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import (
EventTypes, JoinRules,
)
from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from collections import namedtuple
from unpaddedbase64 import encode_base64, decode_base64
import logging
import msgpack
logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache(hs)
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None):
if search_filter:
# We explicitly don't bother caching searches.
return self._get_public_room_list(limit, since_token, search_filter)
result = self.response_cache.get((limit, since_token))
if not result:
result = self.response_cache.set(
(limit, since_token),
self._get_public_room_list(limit, since_token)
)
return result
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
since_token = None
rooms_to_order_value = {}
rooms_to_num_joined = {}
rooms_to_latest_event_ids = {}
newly_visible = []
newly_unpublished = []
if since_token:
stream_token = since_token.stream_ordering
current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
public_room_stream_id, current_public_id
)
else:
stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id(
public_room_stream_id
)
# We want to return rooms ordered by the number of joined users. We
# then arbitrarily use the room_id as a tie breaker.
@defer.inlineCallbacks
def get_order_for_room(room_id):
latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
if not latest_event_ids:
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_token
)
rooms_to_latest_event_ids[room_id] = latest_event_ids
if not latest_event_ids:
return
joined_users = yield self.state_handler.get_current_user_in_room(
room_id, latest_event_ids,
)
num_joined_users = len(joined_users)
rooms_to_num_joined[room_id] = num_joined_users
if num_joined_users == 0:
return
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
sorted_rooms = [room_id for room_id, _ in sorted_entries]
# `sorted_rooms` should now be a list of all public room ids that is
# stable across pagination. Therefore, we can use indices into this
# list as our pagination tokens.
# Filter out rooms that we don't want to return
rooms_to_scan = [
r for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[r] > 0
]
if since_token:
# Filter out rooms we've already returned previously
# `since_token.current_limit` is the index of the last room we
# sent down, so we exclude it and everything before/after it.
if since_token.direction_is_forward:
rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:]
else:
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
rooms_to_scan.reverse()
# Actually generate the entries. _generate_room_entry will append to
# chunk but will stop if len(chunk) > limit
chunk = []
if limit and not search_filter:
step = limit + 1
for i in xrange(0, len(rooms_to_scan), step):
# We iterate here because in the vast majority of cases we'll stop
# at the first iteration, but occasionally _generate_room_entry
# won't append to the chunk and so we need to loop again.
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
yield concurrently_execute(
lambda r: self._generate_room_entry(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan[i:i + step], 10
)
if len(chunk) >= limit + 1:
break
else:
yield concurrently_execute(
lambda r: self._generate_room_entry(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan, 5
)
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
# Work out the new limit of the batch for pagination, or None if we
# know there are no more results that would be returned.
# i.e., [since_token.current_limit..new_limit] is the batch of rooms
# we've returned (or the reverse if we paginated backwards)
# We tried to pull out limit + 1 rooms above, so if we have <= limit
# then we know there are no more results to return
new_limit = None
if chunk and (not limit or len(chunk) > limit):
if not since_token or since_token.direction_is_forward:
if limit:
chunk = chunk[:limit]
last_room_id = chunk[-1]["room_id"]
else:
if limit:
chunk = chunk[-limit:]
last_room_id = chunk[0]["room_id"]
new_limit = sorted_rooms.index(last_room_id)
results = {
"chunk": chunk,
}
if since_token:
results["new_rooms"] = bool(newly_visible)
if not since_token or since_token.direction_is_forward:
if new_limit is not None:
results["next_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=True,
).to_token()
if since_token:
results["prev_batch"] = since_token.copy_and_replace(
direction_is_forward=False,
current_limit=since_token.current_limit + 1,
).to_token()
else:
if new_limit is not None:
results["prev_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=False,
).to_token()
if since_token:
results["next_batch"] = since_token.copy_and_replace(
direction_is_forward=True,
current_limit=since_token.current_limit - 1,
).to_token()
defer.returnValue(results)
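The ordering and pagination above boil down to two ideas: sort rooms by (-num_joined_members, room_id) so larger rooms come first with a deterministic tie-break, and paginate by remembering an index into that stable ordering. A stripped-down sketch on toy data, not the handler itself:

    rooms = {"!a": 3, "!b": 10, "!c": 3}  # room_id -> num_joined_members
    sorted_rooms = [r for r, _ in sorted(rooms.items(), key=lambda e: (-e[1], e[0]))]
    # ["!b", "!a", "!c"]: biggest room first, room_id breaks the tie

    limit, last_index = 1, 0  # pretend the previous batch ended at index 0
    next_page = sorted_rooms[last_index + 1:last_index + 1 + limit]  # forwards
    prev_page = sorted_rooms[max(0, last_index - limit):last_index]  # backwards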
@defer.inlineCallbacks
def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
search_filter):
if limit and len(chunk) > limit + 1:
# We've already got enough, so lets just drop it.
return
result = {
"room_id": room_id,
"num_joined_members": num_joined_users,
}
current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
event_map = yield self.store.get_events([
event_id for key, event_id in current_state_ids.items()
if key[0] in (
EventTypes.JoinRules,
EventTypes.Name,
EventTypes.Topic,
EventTypes.CanonicalAlias,
EventTypes.RoomHistoryVisibility,
EventTypes.GuestAccess,
"m.room.avatar",
)
])
current_state = {
(ev.type, ev.state_key): ev
for ev in event_map.values()
}
# Double check that this is actually a public room.
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
result["aliases"] = aliases
name_event = current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
topic_event = current_state.get((EventTypes.Topic, ""))
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
result["avatar_url"] = avatar_url
if _matches_room_entry(result, search_filter):
chunk.append(result)
@defer.inlineCallbacks
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
search_filter=None):
if search_filter:
# We currently don't support searching across federation, so we have
# to do it manually without pagination
limit = None
since_token = None
res = yield self._get_remote_list_cached(
server_name, limit=limit, since_token=since_token,
)
if search_filter:
res = {"chunk": [
entry
for entry in list(res.get("chunk", []))
if _matches_room_entry(entry, search_filter)
]}
defer.returnValue(res)
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None):
repl_layer = self.hs.get_replication_layer()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token,
search_filter=search_filter,
)
result = self.remote_response_cache.get((server_name, limit, since_token))
if not result:
result = self.remote_response_cache.set(
(server_name, limit, since_token),
repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token,
search_filter=search_filter,
)
)
return result
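Both caches follow the same get-or-populate idiom: look the key up, and on a miss store the in-flight request so concurrent callers share a single fetch. Schematically, assuming ResponseCache.get/set behave as they are used here:

    def get_or_fetch(cache, key, fetch):
        result = cache.get(key)
        if not result:
            # Store the Deferred itself so that concurrent requests for the
            # same key attach to the in-flight fetch rather than starting
            # their own.
            result = cache.set(key, fetch())
        return result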
class RoomListNextBatch(namedtuple("RoomListNextBatch", (
"stream_ordering", # stream_ordering of the first public room list
"public_room_stream_id", # public room stream id for first public room list
"current_limit", # The number of previous rooms returned
"direction_is_forward", # Bool if this is a next_batch, false if prev_batch
))):
KEY_DICT = {
"stream_ordering": "s",
"public_room_stream_id": "p",
"current_limit": "n",
"direction_is_forward": "d",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@classmethod
def from_token(cls, token):
return RoomListNextBatch(**{
cls.REVERSE_KEY_DICT[key]: val
for key, val in msgpack.loads(decode_base64(token)).items()
})
def to_token(self):
return encode_base64(msgpack.dumps({
self.KEY_DICT[key]: val
for key, val in self._asdict().items()
}))
def copy_and_replace(self, **kwds):
return self._replace(
**kwds
)
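The batch token is simply the namedtuple msgpack-encoded under single-character keys and then unpadded-base64'd, keeping it short and opaque to clients. A round-trip sketch, assuming the msgpack and unpaddedbase64 packages (msgpack is added to the dependencies later in this commit):

    import msgpack
    from unpaddedbase64 import encode_base64, decode_base64

    KEY_DICT = {"stream_ordering": "s", "public_room_stream_id": "p",
                "current_limit": "n", "direction_is_forward": "d"}
    REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}

    batch = {"stream_ordering": 1234, "public_room_stream_id": 5,
             "current_limit": 20, "direction_is_forward": True}

    token = encode_base64(msgpack.dumps({KEY_DICT[k]: v for k, v in batch.items()}))
    decoded = {REVERSE_KEY_DICT[k]: v
               for k, v in msgpack.loads(decode_base64(token)).items()}
    assert decoded == batch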
def _matches_room_entry(room_entry, search_filter):
if search_filter and search_filter.get("generic_search_term", None):
generic_search_term = search_filter["generic_search_term"].upper()
if generic_search_term in room_entry.get("name", "").upper():
return True
elif generic_search_term in room_entry.get("topic", "").upper():
return True
elif generic_search_term in room_entry.get("canonical_alias", "").upper():
return True
else:
return True
return False
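The filter is a case-insensitive substring match across name, topic and canonical alias, and the absence of a search term matches everything. For example:

    entry = {"name": "Synapse Admins", "topic": "ops chat"}
    assert _matches_room_entry(entry, {"generic_search_term": "synapse"})
    assert not _matches_room_entry(entry, {"generic_search_term": "python"})
    assert _matches_room_entry(entry, None)  # no filter matches everything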
View file
@ -187,6 +187,7 @@ class TypingHandler(object):
"user_id": user_id, "user_id": user_id,
"typing": typing, "typing": typing,
}, },
key=(room_id, user_id),
)) ))
yield preserve_context_over_deferred( yield preserve_context_over_deferred(
@ -199,7 +200,14 @@ class TypingHandler(object):
user_id = content["user_id"] user_id = content["user_id"]
# Check that the string is a valid user id # Check that the string is a valid user id
UserID.from_string(user_id) user = UserID.from_string(user_id)
if user.domain != origin:
logger.info(
"Got typing update from %r with bad 'user_id': %r",
origin, user_id,
)
return
users = yield self.state.get_current_user_in_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users) domains = set(get_domain_from_id(u) for u in users)
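The typing handler now applies the same federation hygiene rule as the presence handler above: an EDU from origin may only describe users whose server is origin. In essence (the domain split mirrors what get_domain_from_id does):

    def get_domain(user_id):
        return user_id.split(":", 1)[1]

    def is_spoofed(origin, user_id):
        # e.g. matrix.org must not send typing updates for @bob:example.com
        return get_domain(user_id) != origin

    assert is_spoofed("matrix.org", "@bob:example.com")
    assert not is_spoofed("example.com", "@bob:example.com")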
View file
@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None, def put_json(self, destination, path, data={}, json_data_callback=None,
long_retries=False): long_retries=False, timeout=None):
""" Sends the specifed json data using PUT """ Sends the specifed json data using PUT
Args: Args:
@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object):
use as the request body. use as the request body.
long_retries (bool): A boolean that indicates whether we should long_retries (bool): A boolean that indicates whether we should
retry for a short or long time. retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object):
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries, long_retries=long_retries,
timeout=timeout,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=True): def post_json(self, destination, path, data={}, long_retries=True,
timeout=None):
""" Sends the specifed json data using POST """ Sends the specifed json data using POST
Args: Args:
@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON. the request body. This will be encoded as JSON.
long_retries (bool): A boolean that indicates whether we should long_retries (bool): A boolean that indicates whether we should
retry for a short or long time. retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object):
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=True, long_retries=True,
timeout=timeout,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
View file
@ -41,9 +41,13 @@ def parse_integer(request, name, default=None, required=False):
SynapseError: if the parameter is absent and required, or if the SynapseError: if the parameter is absent and required, or if the
parameter is present and not an integer. parameter is present and not an integer.
""" """
if name in request.args: return parse_integer_from_args(request.args, name, default, required)
def parse_integer_from_args(args, name, default=None, required=False):
if name in args:
try: try:
return int(request.args[name][0]) return int(args[name][0])
except: except:
message = "Query parameter %r must be an integer" % (name,) message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message) raise SynapseError(400, message)
@ -116,9 +120,15 @@ def parse_string(request, name, default=None, required=False,
parameter is present, must be one of a list of allowed values and parameter is present, must be one of a list of allowed values and
is not one of those allowed values. is not one of those allowed values.
""" """
return parse_string_from_args(
request.args, name, default, required, allowed_values, param_type,
)
if name in request.args:
value = request.args[name][0] def parse_string_from_args(args, name, default=None, required=False,
allowed_values=None, param_type="string"):
if name in args:
value = args[name][0]
if allowed_values is not None and value not in allowed_values: if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % ( message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values) name, ", ".join(repr(v) for v in allowed_values)
View file
@ -263,6 +263,8 @@ BASE_APPEND_UNDERRIDE_RULES = [
} }
] ]
}, },
# XXX: once m.direct is standardised everywhere, we should use it to detect
# a DM from the user's perspective rather than this heuristic.
{ {
'rule_id': 'global/underride/.m.rule.room_one_to_one', 'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [ 'conditions': [
@ -289,6 +291,34 @@ BASE_APPEND_UNDERRIDE_RULES = [
} }
] ]
}, },
# XXX: this is going to fire for events which aren't m.room.message
# but are encrypted (e.g. m.call.*)...
{
'rule_id': 'global/underride/.m.rule.encrypted_room_one_to_one',
'conditions': [
{
'kind': 'room_member_count',
'is': '2',
'_id': 'member_count',
},
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.encrypted',
'_id': '_encrypted',
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
{ {
'rule_id': 'global/underride/.m.rule.message', 'rule_id': 'global/underride/.m.rule.message',
'conditions': [ 'conditions': [
@ -305,6 +335,25 @@ BASE_APPEND_UNDERRIDE_RULES = [
'value': False 'value': False
} }
] ]
},
# XXX: this is going to fire for events which aren't m.room.message
# but are encrypted (e.g. m.call.*)...
{
'rule_id': 'global/underride/.m.rule.encrypted',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.encrypted',
'_id': '_encrypted',
}
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': False
}
]
} }
] ]
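A push rule fires only if all of its conditions match. A toy evaluator for the two condition kinds used above (event_match against a literal pattern, and room_member_count) shows how the encrypted one-to-one rule selects an event; real rule evaluation also handles glob patterns and dotted keys:

    def rule_matches(rule, event, member_count):
        for cond in rule['conditions']:
            if cond['kind'] == 'event_match':
                if event.get(cond['key']) != cond['pattern']:
                    return False
            elif cond['kind'] == 'room_member_count':
                if str(member_count) != cond['is']:
                    return False
        return True

    event = {'type': 'm.room.encrypted'}
    # Matches in a two-person room, so the rule's notify + sound actions apply.
    assert rule_matches({'conditions': [
        {'kind': 'room_member_count', 'is': '2'},
        {'kind': 'event_match', 'key': 'type', 'pattern': 'm.room.encrypted'},
    ]}, event, 2)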
View file
@ -26,15 +26,6 @@ from synapse.visibility import filter_events_for_clients_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user)
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_event(event, hs, store, context): def evaluator_for_event(event, hs, store, context):
rules_by_user = yield store.bulk_get_push_rules_for_room( rules_by_user = yield store.bulk_get_push_rules_for_room(
@ -48,6 +39,7 @@ def evaluator_for_event(event, hs, store, context):
if invited_user and hs.is_mine_id(invited_user): if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user) has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher: if has_pusher:
rules_by_user = dict(rules_by_user)
rules_by_user[invited_user] = yield store.get_push_rules_for_user( rules_by_user[invited_user] = yield store.get_push_rules_for_user(
invited_user invited_user
) )
View file
@ -36,6 +36,7 @@ REQUIREMENTS = {
"blist": ["blist"], "blist": ["blist"],
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"], "pymacaroons-pynacl": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {
View file
@ -42,6 +42,7 @@ STREAM_NAMES = (
("pushers",), ("pushers",),
("caches",), ("caches",),
("to_device",), ("to_device",),
("public_rooms",),
) )
@ -131,6 +132,7 @@ class ReplicationResource(Resource):
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
caches_token = self.store.get_cache_stream_token() caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -144,6 +146,7 @@ class ReplicationResource(Resource):
0, # State stream is no longer a thing 0, # State stream is no longer a thing
caches_token, caches_token,
int(stream_token.to_device_key), int(stream_token.to_device_key),
int(public_rooms_token),
)) ))
@request_handler() @request_handler()
@ -181,7 +184,7 @@ class ReplicationResource(Resource):
def replicate(self, request_streams, limit): def replicate(self, request_streams, limit):
writer = _Writer() writer = _Writer()
current_token = yield self.current_replication_token() current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token) logger.debug("Replicating up to %r", current_token)
yield self.account_data(writer, current_token, limit, request_streams) yield self.account_data(writer, current_token, limit, request_streams)
yield self.events(writer, current_token, limit, request_streams) yield self.events(writer, current_token, limit, request_streams)
@ -193,9 +196,10 @@ class ReplicationResource(Resource):
yield self.pushers(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total) logger.debug("Replicated %d rows", writer.total)
defer.returnValue(writer.finish()) defer.returnValue(writer.finish())
def streams(self, writer, current_token, request_streams): def streams(self, writer, current_token, request_streams):
@ -274,11 +278,18 @@ class ReplicationResource(Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def typing(self, writer, current_token, request_streams): def typing(self, writer, current_token, request_streams):
current_position = current_token.presence current_position = current_token.typing
request_typing = request_streams.get("typing") request_typing = request_streams.get("typing")
if request_typing is not None: if request_typing is not None:
# If they have a higher token than current max, we can assume that
# they had been talking to a previous instance of the master. Since
# we reset the token on restart, the best (but hacky) thing we can
# do is to simply resend down all the typing notifications.
if request_typing > current_position:
request_typing = 0
typing_rows = yield self.typing_handler.get_all_typing_updates( typing_rows = yield self.typing_handler.get_all_typing_updates(
request_typing, current_position request_typing, current_position
) )
@ -393,6 +404,20 @@ class ReplicationResource(Resource):
"position", "user_id", "device_id", "message_json" "position", "user_id", "device_id", "message_json"
)) ))
@defer.inlineCallbacks
def public_rooms(self, writer, current_token, limit, request_streams):
current_position = current_token.public_rooms
public_rooms = request_streams.get("public_rooms")
if public_rooms is not None:
public_rooms_rows = yield self.store.get_all_new_public_rooms(
public_rooms, current_position, limit
)
writer.write_header_and_rows("public_rooms", public_rooms_rows, (
"position", "room_id", "visibility"
))
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -421,7 +446,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
))): ))):
__slots__ = [] __slots__ = []
View file
@ -16,13 +16,18 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(BaseSlavedStore): class SlavedDeviceInboxStore(BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker( self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_inbox", "stream_id", db_conn, "device_max_stream_id", "stream_id",
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
self._device_inbox_id_gen.get_current_token()
) )
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
@ -38,5 +43,11 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
stream = result.get("to_device") stream = result.get("to_device")
if stream: if stream:
self._device_inbox_id_gen.advance(int(stream["position"])) self._device_inbox_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
self._device_inbox_stream_cache.entity_has_changed(
user_id, stream_id
)
return super(SlavedDeviceInboxStore, self).process_replication(result) return super(SlavedDeviceInboxStore, self).process_replication(result)
View file
@ -61,6 +61,9 @@ class SlavedEventStore(BaseSlavedStore):
"MembershipStreamChangeCache", events_max, "MembershipStreamChangeCache", events_max,
) )
self.stream_ordering_month_ago = 0
self._stream_order_on_start = self.get_room_max_stream_ordering()
# Cached functions can't be accessed through a class instance so we need # Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them. # to reach inside the __dict__ to extract them.
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
@ -86,6 +89,9 @@ class SlavedEventStore(BaseSlavedStore):
_get_state_groups_from_groups = ( _get_state_groups_from_groups = (
StateStore.__dict__["_get_state_groups_from_groups"] StateStore.__dict__["_get_state_groups_from_groups"]
) )
_get_state_groups_from_groups_txn = (
DataStore._get_state_groups_from_groups_txn.__func__
)
_get_state_group_from_group = ( _get_state_group_from_group = (
StateStore.__dict__["_get_state_group_from_group"] StateStore.__dict__["_get_state_group_from_group"]
) )
@ -165,6 +171,15 @@ class SlavedEventStore(BaseSlavedStore):
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__ get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__ _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__
get_forward_extremeties_for_room = (
DataStore.get_forward_extremeties_for_room.__func__
)
_get_forward_extremeties_for_room = (
EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
)
def stream_positions(self): def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions() result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token() result["events"] = self._stream_id_gen.get_current_token()
View file
@ -15,7 +15,39 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage import DataStore
from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(BaseSlavedStore): class RoomStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(RoomStore, self).__init__(db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
)
get_public_room_ids = DataStore.get_public_room_ids.__func__ get_public_room_ids = DataStore.get_public_room_ids.__func__
get_current_public_room_stream_id = (
DataStore.get_current_public_room_stream_id.__func__
)
get_public_room_ids_at_stream_id = (
DataStore.get_public_room_ids_at_stream_id.__func__
)
get_public_room_ids_at_stream_id_txn = (
DataStore.get_public_room_ids_at_stream_id_txn.__func__
)
get_published_at_stream_id_txn = (
DataStore.get_published_at_stream_id_txn.__func__
)
get_public_room_changes = DataStore.get_public_room_changes.__func__
def stream_positions(self):
result = super(RoomStore, self).stream_positions()
result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("public_rooms")
if stream:
self._public_room_id_gen.advance(int(stream["position"]))
return super(RoomStore, self).process_replication(result)
View file
@ -318,7 +318,7 @@ class CasRedirectServlet(ClientV1RestServlet):
service_param = urllib.urlencode({ service_param = urllib.urlencode({
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
}) })
request.redirect("%s?%s" % (self.cas_server_url, service_param)) request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
finish_request(request) finish_request(request)
@ -385,7 +385,7 @@ class CasTicketServlet(ClientV1RestServlet):
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):
user = None user = None
attributes = None attributes = {}
try: try:
root = ET.fromstring(cas_response_body) root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"): if not root.tag.endswith("serviceResponse"):
@ -395,7 +395,6 @@ class CasTicketServlet(ClientV1RestServlet):
if child.tag.endswith("user"): if child.tag.endswith("user"):
user = child.text user = child.text
if child.tag.endswith("attributes"): if child.tag.endswith("attributes"):
attributes = {}
for attribute in child: for attribute in child:
# ElementTree library expands the namespace in # ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace. # attribute tags to the full URL of the namespace.
@ -407,8 +406,6 @@ class CasTicketServlet(ClientV1RestServlet):
attributes[tag] = attribute.text attributes[tag] = attribute.text
if user is None: if user is None:
raise Exception("CAS response does not contain user") raise Exception("CAS response does not contain user")
if attributes is None:
raise Exception("CAS response does not contain attributes")
except Exception: except Exception:
logger.error("Error parsing CAS response", exc_info=1) logger.error("Error parsing CAS response", exc_info=1)
raise LoginError(401, "Invalid CAS response", raise LoginError(401, "Invalid CAS response",
View file
@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
try: access_token = get_access_token_from_request(request)
access_token = request.args["access_token"][0]
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
)
yield self.store.delete_access_token(access_token) yield self.store.delete_access_token(access_token)
defer.returnValue((200, {})) defer.returnValue((200, {}))
View file
@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
@ -296,12 +297,11 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_app_service(self, request, register_json, session): def _do_app_service(self, request, register_json, session):
if "access_token" not in request.args: as_token = get_access_token_from_request(request)
raise SynapseError(400, "Expected application service token.")
if "user" not in register_json: if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.") raise SynapseError(400, "Expected 'user' key.")
as_token = request.args["access_token"][0]
user_localpart = register_json["user"].encode("utf-8") user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
@ -390,11 +390,9 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
user_json = parse_json_object_from_request(request) user_json = parse_json_object_from_request(request)
if "access_token" not in request.args: access_token = get_access_token_from_request(request)
raise SynapseError(400, "Expected application service token.")
app_service = yield self.store.get_app_service_by_token( app_service = yield self.store.get_app_service_by_token(
request.args["access_token"][0] access_token
) )
if not app_service: if not app_service:
raise SynapseError(403, "Invalid application service token.") raise SynapseError(403, "Invalid application service token.")
View file
@ -22,8 +22,10 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer
)
import logging import logging
import urllib import urllib
@ -120,6 +122,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(request, "format", default="content",
allowed_values=["content", "event"])
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -134,6 +138,11 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
raise SynapseError( raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND 404, "Event not found.", errcode=Codes.NOT_FOUND
) )
if format == "event":
event = format_event_for_client_v2(data.get_dict())
defer.returnValue((200, event))
elif format == "content":
defer.returnValue((200, data.get_dict()["content"])) defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks @defer.inlineCallbacks
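From a client's point of view, the new parameter means the same state endpoint can return either just the content (the default) or the whole event. A hypothetical call, with the server name and access token as placeholders:

    import requests

    resp = requests.get(
        "https://example.com/_matrix/client/r0/rooms/!room:example.com/state/m.room.name",
        params={"format": "event", "access_token": "..."},
    )
    # with format=event the body is the full event, not just its content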
@ -295,15 +304,64 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
server = parse_string(request, "server", default=None)
try: try:
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request, allow_guest=True)
except AuthError: except AuthError as e:
# This endpoint isn't authed, but it's useful to know who's hitting # We allow people to not be authed if they're just looking at our
# it if they *do* supply an access token # room list, but require auth when we proxy the request.
# In both cases we call the auth function, as that has the side
# effect of logging who issued this request if an access token was
# provided.
if server:
raise e
else:
pass pass
limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None)
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
data = yield handler.get_aggregated_public_room_list() if server:
data = yield handler.get_remote_public_room_list(
server,
limit=limit,
since_token=since_token,
)
else:
data = yield handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
)
defer.returnValue((200, data))
@defer.inlineCallbacks
def on_POST(self, request):
yield self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)
limit = int(content.get("limit", 100))
since_token = content.get("since", None)
search_filter = content.get("filter", None)
handler = self.hs.get_room_list_handler()
if server:
data = yield handler.get_remote_public_room_list(
server,
limit=limit,
since_token=since_token,
search_filter=search_filter,
)
else:
data = yield handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
search_filter=search_filter,
)
defer.returnValue((200, data)) defer.returnValue((200, data))
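Together with the GET changes this gives clients three knobs: a server parameter to proxy the request to another homeserver, limit/since for pagination, and a POST body filter for searching. A hypothetical client call (token and server names are placeholders):

    import requests

    resp = requests.post(
        "https://example.com/_matrix/client/r0/publicRooms",
        params={"access_token": "...", "server": "matrix.org"},
        json={"limit": 20, "filter": {"generic_search_term": "synapse"}},
    )
    rooms = resp.json()["chunk"]  # plus optional next_batch / prev_batch tokens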
View file
@ -17,6 +17,8 @@
to ensure idempotency when performing PUTs using the REST API.""" to ensure idempotency when performing PUTs using the REST API."""
import logging import logging
from synapse.api.auth import get_access_token_from_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -90,6 +92,6 @@ class HttpTransactionStore(object):
return response return response
def _get_key(self, request): def _get_key(self, request):
token = request.args["access_token"][0] token = get_access_token_from_request(request)
path_without_txn_id = request.path.rsplit("/", 1)[0] path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token return path_without_txn_id + "/" + token
View file
@ -15,15 +15,12 @@
import logging import logging
import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
import synapse.api.errors from synapse.api.errors import SynapseError
import synapse.server from synapse.http.servlet import (
import synapse.types RestServlet, parse_json_object_from_request, parse_integer
from synapse.http.servlet import RestServlet, parse_json_object_from_request )
from synapse.types import UserID
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,17 +60,13 @@ class KeyUploadServlet(RestServlet):
hs (synapse.server.HomeServer): server hs (synapse.server.HomeServer): server
""" """
super(KeyUploadServlet, self).__init__() super(KeyUploadServlet, self).__init__()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, device_id): def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
if device_id is not None: if device_id is not None:
@ -88,52 +81,15 @@ class KeyUploadServlet(RestServlet):
device_id = requester.device_id device_id = requester.device_id
if device_id is None: if device_id is None:
raise synapse.api.errors.SynapseError( raise SynapseError(
400, 400,
"To upload keys, you must pass device_id when authenticating" "To upload keys, you must pass device_id when authenticating"
) )
time_now = self.clock.time_msec() result = yield self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = body.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
) )
# TODO: Sign the JSON with the server key defer.returnValue((200, result))
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
one_time_keys = body.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
class KeyQueryServlet(RestServlet):

@ -195,20 +151,23 @@ class KeyQueryServlet(RestServlet):

     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = yield self.e2e_keys_handler.query_devices(body)
-        defer.returnValue(result)
+        result = yield self.e2e_keys_handler.query_devices(body, timeout)
+        defer.returnValue((200, result))

     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id):
         requester = yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         auth_user_id = requester.user.to_string()
         user_id = user_id if user_id else auth_user_id
         device_ids = [device_id] if device_id else []
         result = yield self.e2e_keys_handler.query_devices(
-            {"device_keys": {user_id: device_ids}}
+            {"device_keys": {user_id: device_ids}},
+            timeout,
         )
-        defer.returnValue(result)
+        defer.returnValue((200, result))


class OneTimeKeyServlet(RestServlet):

@ -240,59 +199,29 @@ class OneTimeKeyServlet(RestServlet):

     def __init__(self, hs):
         super(OneTimeKeyServlet, self).__init__()
-        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.federation = hs.get_replication_layer()
-        self.is_mine = hs.is_mine
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()

     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
-        result = yield self.handle_request(
-            {"one_time_keys": {user_id: {device_id: algorithm}}}
-        )
-        defer.returnValue(result)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
+        result = yield self.e2e_keys_handler.claim_one_time_keys(
+            {"one_time_keys": {user_id: {device_id: algorithm}}},
+            timeout,
+        )
+        defer.returnValue((200, result))

     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = yield self.handle_request(body)
-        defer.returnValue(result)
-
-    @defer.inlineCallbacks
-    def handle_request(self, body):
-        local_query = []
-        remote_queries = {}
-        for user_id, device_keys in body.get("one_time_keys", {}).items():
-            user = UserID.from_string(user_id)
-            if self.is_mine(user):
-                for device_id, algorithm in device_keys.items():
-                    local_query.append((user_id, device_id, algorithm))
-            else:
-                remote_queries.setdefault(user.domain, {})[user_id] = (
-                    device_keys
-                )
-        results = yield self.store.claim_e2e_one_time_keys(local_query)
-        json_result = {}
-        for user_id, device_keys in results.items():
-            for device_id, keys in device_keys.items():
-                for key_id, json_bytes in keys.items():
-                    json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json.loads(json_bytes)
-                    }
-        for destination, device_keys in remote_queries.items():
-            remote_result = yield self.federation.claim_client_keys(
-                destination, {"one_time_keys": device_keys}
-            )
-            for user_id, keys in remote_result["one_time_keys"].items():
-                if user_id in device_keys:
-                    json_result[user_id] = keys
-        defer.returnValue((200, {"one_time_keys": json_result}))
+        result = yield self.e2e_keys_handler.claim_one_time_keys(
+            body,
+            timeout,
+        )
+        defer.returnValue((200, result))


def register_servlets(hs, http_server):
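Both endpoints now also take a federation timeout. For orientation, here is a small self-contained sketch (not the actual E2eKeysHandler) of the first step such a handler performs: partitioning a /keys/query body into local and remote queries by server name. The helper name and the user-ID parsing are illustrative assumptions.

    # Illustrative sketch only: split a /keys/query body into local and
    # remote halves, the shape a handler like query_devices works with.
    def split_key_query(body, local_domain):
        local_query = {}
        remote_queries = {}
        for user_id, device_ids in body.get("device_keys", {}).items():
            # Matrix user IDs look like "@localpart:domain".
            domain = user_id.split(":", 1)[1]
            if domain == local_domain:
                local_query[user_id] = device_ids
            else:
                remote_queries.setdefault(domain, {})[user_id] = device_ids
        return local_query, remote_queries

    local, remote = split_key_query(
        {"device_keys": {"@alice:example.com": [], "@bob:elsewhere.org": ["D1"]}},
        "example.com",
    )
    assert "@alice:example.com" in local and "elsewhere.org" in remote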

View file

@ -45,11 +45,12 @@ class NotificationsServlet(RestServlet):

         from_token = parse_string(request, "from", required=False)
         limit = parse_integer(request, "limit", default=50)
+        only = parse_string(request, "only", required=False)

         limit = min(limit, 500)

         push_actions = yield self.store.get_push_actions_for_user(
-            user_id, from_token, limit
+            user_id, from_token, limit, only_highlight=(only == "highlight")
         )

         receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(

View file

@ -15,6 +15,7 @@

 from twisted.internet import defer

+from synapse.api.auth import get_access_token_from_request, has_access_token
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
 from synapse.http.servlet import RestServlet, parse_json_object_from_request

@ -131,7 +132,7 @@ class RegisterRestServlet(RestServlet):

         desired_username = body['username']

         appservice = None
-        if 'access_token' in request.args:
+        if has_access_token(request):
             appservice = yield self.auth.get_appservice_by_req(request)

         # fork off as soon as possible for ASes and shared secret auth which

@ -143,10 +144,11 @@ class RegisterRestServlet(RestServlet):

             # 'user' key not 'username'). Since this is a new addition, we'll
             # fallback to 'username' if they gave one.
             desired_username = body.get("user", desired_username)
+            access_token = get_access_token_from_request(request)

             if isinstance(desired_username, basestring):
                 result = yield self._do_appservice_registration(
-                    desired_username, request.args["access_token"][0], body
+                    desired_username, access_token, body
                 )
             defer.returnValue((200, result))  # we throw for non 200 responses
             return
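For context, has_access_token and get_access_token_from_request centralise the token extraction that servlets previously did by poking request.args directly. A hedged sketch of the query-parameter case only (the real helpers in synapse.api.auth may accept other transports as well):

    # Hedged sketch of a centralised token-extraction helper for the
    # query-parameter case; `request.args` is assumed to be a dict of
    # lists, as in Twisted.
    class MissingTokenError(Exception):
        pass

    def has_access_token(request):
        return "access_token" in request.args

    def get_access_token_from_request(request):
        tokens = request.args.get("access_token")
        if not tokens:
            raise MissingTokenError("Missing access token")
        return tokens[0]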

View file

@ -16,10 +16,11 @@

 import logging

 from twisted.internet import defer

-from synapse.http.servlet import parse_json_object_from_request
 from synapse.http import servlet
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.rest.client.v1.transactions import HttpTransactionStore
 from ._base import client_v2_patterns

 logger = logging.getLogger(__name__)

@ -39,10 +40,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):

         super(SendToDeviceRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self.is_mine_id = hs.is_mine_id
         self.txns = HttpTransactionStore()
+        self.device_message_handler = hs.get_device_message_handler()

     @defer.inlineCallbacks
     def on_PUT(self, request, message_type, txn_id):

@ -57,28 +56,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):

         content = parse_json_object_from_request(request)

-        # TODO: Prod the notifier to wake up sync streams.
-        # TODO: Implement replication for the messages.
-        # TODO: Send the messages to remote servers if needed.
-
-        local_messages = {}
-        for user_id, by_device in content["messages"].items():
-            if self.is_mine_id(user_id):
-                messages_by_device = {
-                    device_id: {
-                        "content": message_content,
-                        "type": message_type,
-                        "sender": requester.user.to_string(),
-                    }
-                    for device_id, message_content in by_device.items()
-                }
-                if messages_by_device:
-                    local_messages[user_id] = messages_by_device
-
-        stream_id = yield self.store.add_messages_to_device_inbox(local_messages)
-
-        self.notifier.on_new_event(
-            "to_device_key", stream_id, users=local_messages.keys()
+        sender_user_id = requester.user.to_string()
+
+        yield self.device_message_handler.send_device_message(
+            sender_user_id, message_type, content["messages"]
         )

         response = (200, {})
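The new DeviceMessageHandler centralises the fan-out this servlet used to do inline: local messages go straight to the device inbox, remote ones are bundled into per-destination EDUs for federation. A rough sketch of that split (names and EDU shape are assumptions based on the surrounding diff, not the actual handler):

    # Rough sketch of the local/remote fan-out a DeviceMessageHandler
    # performs before hitting the storage layer shown later in this diff.
    def split_device_messages(messages, message_type, sender, local_domain):
        local_messages = {}
        remote_edus = {}
        for user_id, by_device in messages.items():
            domain = user_id.split(":", 1)[1]
            if domain == local_domain:
                local_messages[user_id] = {
                    device_id: {
                        "content": content,
                        "type": message_type,
                        "sender": sender,
                    }
                    for device_id, content in by_device.items()
                }
            else:
                edu = remote_edus.setdefault(domain, {
                    "type": message_type, "sender": sender, "messages": {},
                })
                edu["messages"][user_id] = by_device
        return local_messages, remote_edus

    local, remote = split_device_messages(
        {"@bob:elsewhere.org": {"DEV1": {"hello": 1}}},
        "m.new_device", "@me:example.com", "example.com",
    )
    assert "elsewhere.org" in remote and not local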

View file

@ -42,6 +42,29 @@ class ThirdPartyProtocolsServlet(RestServlet):

         defer.returnValue((200, protocols))


+class ThirdPartyProtocolServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
+                                  releases=())
+
+    def __init__(self, hs):
+        super(ThirdPartyProtocolServlet, self).__init__()
+
+        self.auth = hs.get_auth()
+        self.appservice_handler = hs.get_application_service_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, protocol):
+        yield self.auth.get_user_by_req(request)
+
+        protocols = yield self.appservice_handler.get_3pe_protocols(
+            only_protocol=protocol,
+        )
+        if protocol in protocols:
+            defer.returnValue((200, protocols[protocol]))
+        else:
+            defer.returnValue((404, {"error": "Unknown protocol"}))
+
+
 class ThirdPartyUserServlet(RestServlet):
     PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
                                   releases=())

@ -57,7 +80,7 @@ class ThirdPartyUserServlet(RestServlet):

         yield self.auth.get_user_by_req(request)

         fields = request.args
-        del fields["access_token"]
+        fields.pop("access_token", None)

         results = yield self.appservice_handler.query_3pe(
             ThirdPartyEntityKind.USER, protocol, fields

@ -81,7 +104,7 @@ class ThirdPartyLocationServlet(RestServlet):

         yield self.auth.get_user_by_req(request)

         fields = request.args
-        del fields["access_token"]
+        fields.pop("access_token", None)

         results = yield self.appservice_handler.query_3pe(
             ThirdPartyEntityKind.LOCATION, protocol, fields

@ -92,5 +115,6 @@ class ThirdPartyLocationServlet(RestServlet):

 def register_servlets(hs, http_server):
     ThirdPartyProtocolsServlet(hs).register(http_server)
+    ThirdPartyProtocolServlet(hs).register(http_server)
     ThirdPartyUserServlet(hs).register(http_server)
     ThirdPartyLocationServlet(hs).register(http_server)

View file

@ -45,7 +45,14 @@ class DownloadResource(Resource):

     @request_handler()
     @defer.inlineCallbacks
     def _async_render_GET(self, request):
-        request.setHeader("Content-Security-Policy", "sandbox")
+        request.setHeader(
+            "Content-Security-Policy",
+            "default-src 'none';"
+            " script-src 'none';"
+            " plugin-types application/pdf;"
+            " style-src 'unsafe-inline';"
+            " object-src 'self';"
+        )
         server_name, media_id, name = parse_media_id(request)
         if server_name == self.server_name:
             yield self._respond_local_file(request, media_id, name)
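The explicit policy replaces the bare "sandbox" directive: default-src 'none' denies everything by default, script-src 'none' keeps scripts disabled, and the plugin-types/object-src/style-src exceptions allow same-origin PDF rendering (PR #1071) without letting uploaded HTML execute. A quick, illustrative way to check the header against a running server (homeserver URL and media ID are placeholders, not real values):

    # Placeholder URL/media ID; standard library only (Python 2, matching
    # the codebase; use urllib.request on Python 3).
    import urllib2

    resp = urllib2.urlopen(
        "https://example.com/_matrix/media/v1/download/example.com/abc123"
    )
    print(resp.headers.get("Content-Security-Policy"))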

View file

@ -35,10 +35,11 @@ from synapse.federation import initialize_http_replication

 from synapse.handlers import Handlers
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.handlers.auth import AuthHandler
+from synapse.handlers.devicemessage import DeviceMessageHandler
 from synapse.handlers.device import DeviceHandler
 from synapse.handlers.e2e_keys import E2eKeysHandler
 from synapse.handlers.presence import PresenceHandler
-from synapse.handlers.room import RoomListHandler
+from synapse.handlers.room_list import RoomListHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import TypingHandler
 from synapse.handlers.events import EventHandler, EventStreamHandler

@ -100,6 +101,7 @@ class HomeServer(object):

         'application_service_api',
         'application_service_scheduler',
         'application_service_handler',
+        'device_message_handler',
         'notifier',
         'distributor',
         'client_resource',

@ -205,6 +207,9 @@ class HomeServer(object):

     def build_device_handler(self):
         return DeviceHandler(self)

+    def build_device_message_handler(self):
+        return DeviceMessageHandler(self)
+
     def build_e2e_keys_handler(self):
         return E2eKeysHandler(self)
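The 'device_message_handler' entry works because HomeServer generates a memoised get_<name>() accessor for every name in that dependency list, backed by the matching build_<name>() method. A minimal sketch of that pattern (a stand-in, not HomeServer itself):

    # Minimal sketch of the build_*/get_* dependency-injection pattern:
    # each declared dependency gets a cached getter that invokes the
    # matching builder on first use.
    class DependencyHost(object):
        DEPENDENCIES = ["device_message_handler"]

        def __init__(self):
            self._cache = {}

        def get(self, name):
            if name not in self._cache:
                self._cache[name] = getattr(self, "build_%s" % name)()
            return self._cache[name]

        def build_device_message_handler(self):
            return object()  # stand-in for DeviceMessageHandler(self)

    host = DependencyHost()
    assert host.get("device_message_handler") is host.get("device_message_handler")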

View file

@ -26,6 +26,7 @@ from synapse.events.snapshot import EventContext

 from synapse.util.async import Linearizer

 from collections import namedtuple
+from frozendict import frozendict

 import logging
 import hashlib

@ -55,12 +56,15 @@ def _gen_state_id():

 class _StateCacheEntry(object):
-    __slots__ = ["state", "state_group", "state_id"]
+    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]

-    def __init__(self, state, state_group):
-        self.state = state
+    def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+        self.state = frozendict(state)
         self.state_group = state_group

+        self.prev_group = prev_group
+        self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
+
         # The `state_id` is a unique ID we generate that can be used as ID for
         # this collection of state. Usually this would be the same as the
         # state group, but on worker instances we can't generate a new state

@ -153,7 +157,8 @@ class StateHandler(object):

         defer.returnValue(state)

     @defer.inlineCallbacks
-    def get_current_user_in_room(self, room_id):
-        latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+    def get_current_user_in_room(self, room_id, latest_event_ids=None):
+        if not latest_event_ids:
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
         entry = yield self.resolve_state_groups(room_id, latest_event_ids)
         joined_users = yield self.store.get_joined_users_from_state(

@ -234,21 +239,29 @@ class StateHandler(object):

         context.prev_state_ids = curr_state
         if event.is_state():
             context.state_group = self.store.get_next_state_group()
-        else:
-            if entry.state_group is None:
-                entry.state_group = self.store.get_next_state_group()
-                entry.state_id = entry.state_group
-            context.state_group = entry.state_group

-        if event.is_state():
             key = (event.type, event.state_key)
             if key in context.prev_state_ids:
                 replaces = context.prev_state_ids[key]
                 event.unsigned["replaces_state"] = replaces
             context.current_state_ids = dict(context.prev_state_ids)
             context.current_state_ids[key] = event.event_id
+
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+            if context.delta_ids is not None:
+                context.delta_ids = dict(context.delta_ids)
+                context.delta_ids[key] = event.event_id
         else:
+            if entry.state_group is None:
+                entry.state_group = self.store.get_next_state_group()
+                entry.state_id = entry.state_group
+            context.state_group = entry.state_group
             context.current_state_ids = context.prev_state_ids
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids

         context.prev_state_events = []
         defer.returnValue(context)

@ -283,6 +296,8 @@ class StateHandler(object):

             defer.returnValue(_StateCacheEntry(
                 state=state_list,
                 state_group=name,
+                prev_group=name,
+                delta_ids={},
             ))

         with (yield self.resolve_linearizer.queue(group_names)):

@ -340,9 +355,24 @@ class StateHandler(object):

         if hasattr(self.store, "get_next_state_group"):
             state_group = self.store.get_next_state_group()

+        prev_group = None
+        delta_ids = None
+        for old_group, old_ids in state_groups_ids.items():
+            if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
+                n_delta_ids = {
+                    k: v
+                    for k, v in new_state.items()
+                    if old_ids.get(k) != v
+                }
+                if not delta_ids or len(n_delta_ids) < len(delta_ids):
+                    prev_group = old_group
+                    delta_ids = n_delta_ids
+
         cache = _StateCacheEntry(
             state=new_state,
             state_group=state_group,
+            prev_group=prev_group,
+            delta_ids=delta_ids,
         )

         if self._state_cache is not None:
View file

@ -111,7 +111,10 @@ class DataStore(RoomMemberStore, RoomStore,

             db_conn, "presence_stream", "stream_id"
         )
         self._device_inbox_id_gen = StreamIdGenerator(
-            db_conn, "device_inbox", "stream_id"
+            db_conn, "device_max_stream_id", "stream_id"
+        )
+        self._public_room_id_gen = StreamIdGenerator(
+            db_conn, "public_room_list_stream", "stream_id"
         )

         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")

@ -182,6 +185,30 @@ class DataStore(RoomMemberStore, RoomStore,

             prefilled_cache=push_rules_prefill,
         )

+        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+        device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
+            db_conn, "device_inbox",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id
+        )
+        self._device_inbox_stream_cache = StreamChangeCache(
+            "DeviceInboxStreamChangeCache", min_device_inbox_id,
+            prefilled_cache=device_inbox_prefill,
+        )
+
+        # The federation outbox and the local device inbox use the same
+        # stream_id generator.
+        device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
+            db_conn, "device_federation_outbox",
+            entity_column="destination",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+        )
+        self._device_federation_outbox_stream_cache = StreamChangeCache(
+            "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
+            prefilled_cache=device_outbox_prefill,
+        )
+
         cur = LoggingTransaction(
             db_conn.cursor(),
             name="_find_stream_orderings_for_times_txn",

@ -195,6 +222,8 @@ class DataStore(RoomMemberStore, RoomStore,

             self._find_stream_orderings_for_times, 60 * 60 * 1000
         )

+        self._stream_order_on_start = self.get_room_max_stream_ordering()
+
         super(DataStore, self).__init__(hs)

     def take_presence_startup_info(self):
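The two new StreamChangeCache instances let the device-message paths skip database hits entirely when nothing has changed for a user or destination since a given stream position. A minimal sketch of that contract (a stand-in, not synapse's StreamChangeCache):

    # Minimal sketch of the StreamChangeCache contract: track the last
    # stream position at which each entity changed, and answer "has X
    # changed since pos?" pessimistically (True) for positions older than
    # the cache itself.
    class MiniStreamChangeCache(object):
        def __init__(self, min_pos, prefilled=None):
            self._min_pos = min_pos
            self._last_change = dict(prefilled or {})

        def entity_has_changed(self, entity, pos):
            self._last_change[entity] = max(pos, self._last_change.get(entity, 0))

        def has_entity_changed(self, entity, since_pos):
            if since_pos < self._min_pos:
                return True  # cache doesn't reach back that far; assume yes
            return self._last_change.get(entity, self._min_pos) > since_pos

    cache = MiniStreamChangeCache(10, prefilled={"@alice:hs": 12})
    assert cache.has_entity_changed("@alice:hs", 11)
    assert not cache.has_entity_changed("@bob:hs", 11)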
def take_presence_startup_info(self): def take_presence_startup_info(self):

View file

@ -133,9 +133,11 @@ class BackgroundUpdateStore(SQLBaseStore):

         updates = yield self._simple_select_list(
             "background_updates",
             keyvalues=None,
-            retcols=("update_name",),
+            retcols=("update_name", "depends_on"),
         )
+        in_flight = set(update["update_name"] for update in updates)
         for update in updates:
+            if update["depends_on"] not in in_flight:
                 self._background_update_queue.append(update['update_name'])

         if not self._background_update_queue:

@ -217,7 +219,7 @@ class BackgroundUpdateStore(SQLBaseStore):

         self._background_update_handlers[update_name] = update_handler

     def register_background_index_update(self, update_name, index_name,
-                                         table, columns):
+                                         table, columns, where_clause=None):
         """Helper for store classes to do a background index addition

         To use:

@ -241,13 +243,19 @@ class BackgroundUpdateStore(SQLBaseStore):

             conc = True
         else:
             conc = False
+            # We don't use partial indices on SQLite as they weren't
+            # introduced until 3.8, and wheezy has 3.7
+            where_clause = None

-        sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
-            % {
+        sql = (
+            "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)"
+            " %(where_clause)s"
+        ) % {
             "conc": "CONCURRENTLY" if conc else "",
             "name": index_name,
             "table": table,
             "columns": ", ".join(columns),
+            "where_clause": "WHERE " + where_clause if where_clause else ""
         }

         def create_index_concurrently(conn):
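Plugging the new where_clause into the template, the event_contains_url_index registered later in this diff renders roughly as follows on Postgres (conc=True); on SQLite the where_clause is dropped, as the comment above explains:

    # The index-creation SQL the helper renders for the
    # event_contains_url_index update on Postgres.
    sql = (
        "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)"
        " %(where_clause)s"
    ) % {
        "conc": "CONCURRENTLY",
        "name": "event_contains_url_index",
        "table": "events",
        "columns": ", ".join(
            ["room_id", "topological_ordering", "stream_ordering"]
        ),
        "where_clause": "WHERE contains_url = true AND outlier = false",
    }
    print(sql)
    # CREATE INDEX CONCURRENTLY event_contains_url_index ON events
    # (room_id, topological_ordering, stream_ordering)
    # WHERE contains_url = true AND outlier = false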

View file

@ -27,37 +27,159 @@ logger = logging.getLogger(__name__)

 class DeviceInboxStore(SQLBaseStore):

     @defer.inlineCallbacks
-    def add_messages_to_device_inbox(self, messages_by_user_then_device):
-        """
+    def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
+                                     remote_messages_by_destination):
+        """Used to send messages from this server.
+
         Args:
-            messages_by_user_and_device(dict):
+            sender_user_id(str): The ID of the user sending these messages.
+            local_messages_by_user_and_device(dict):
                 Dictionary of user_id to device_id to message.
+            remote_messages_by_destination(dict):
+                Dictionary of destination server_name to the EDU JSON to send.
         Returns:
             A deferred stream_id that resolves when the messages have been
             inserted.
         """

-        def select_devices_txn(txn, user_id, devices):
-            if not devices:
-                return []
+        def add_messages_txn(txn, now_ms, stream_id):
+            # Add the local messages directly to the local inbox.
+            self._add_messages_to_local_device_inbox_txn(
+                txn, stream_id, local_messages_by_user_then_device
+            )
+
+            # Add the remote messages to the federation outbox.
+            # We'll send them to a remote server when we next send a
+            # federation transaction to that destination.
             sql = (
-                "SELECT user_id, device_id FROM devices"
-                " WHERE user_id = ? AND device_id IN ("
-                + ",".join("?" * len(devices))
-                + ")"
+                "INSERT INTO device_federation_outbox"
+                " (destination, stream_id, queued_ts, messages_json)"
+                " VALUES (?,?,?,?)"
             )
-            # TODO: Maybe this needs to be done in batches if there are
-            # too many local devices for a given user.
-            args = [user_id] + devices
-            txn.execute(sql, args)
-            return [tuple(row) for row in txn.fetchall()]
+            rows = []
+            for destination, edu in remote_messages_by_destination.items():
+                edu_json = ujson.dumps(edu)
+                rows.append((destination, stream_id, now_ms, edu_json))
+            txn.executemany(sql, rows)

-        def add_messages_to_device_inbox_txn(txn, stream_id):
-            local_users_and_devices = set()
-            for user_id, messages_by_device in messages_by_user_then_device.items():
-                local_users_and_devices.update(
-                    select_devices_txn(txn, user_id, messages_by_device.keys())
-                )
+        with self._device_inbox_id_gen.get_next() as stream_id:
+            now_ms = self.clock.time_msec()
+            yield self.runInteraction(
+                "add_messages_to_device_inbox",
+                add_messages_txn,
+                now_ms,
+                stream_id,
+            )
+            for user_id in local_messages_by_user_then_device.keys():
+                self._device_inbox_stream_cache.entity_has_changed(
+                    user_id, stream_id
+                )
+            for destination in remote_messages_by_destination.keys():
+                self._device_federation_outbox_stream_cache.entity_has_changed(
+                    destination, stream_id
+                )
+
+        defer.returnValue(self._device_inbox_id_gen.get_current_token())
+
+    @defer.inlineCallbacks
+    def add_messages_from_remote_to_device_inbox(
+        self, origin, message_id, local_messages_by_user_then_device
+    ):
+        def add_messages_txn(txn, now_ms, stream_id):
+            # Check if we've already inserted a matching message_id for that
+            # origin. This can happen if the origin doesn't receive our
+            # acknowledgement from the first time we received the message.
+            already_inserted = self._simple_select_one_txn(
+                txn, table="device_federation_inbox",
+                keyvalues={"origin": origin, "message_id": message_id},
+                retcols=("message_id",),
+                allow_none=True,
+            )
+            if already_inserted is not None:
+                return
+
+            # Add an entry for this message_id so that we know we've processed
+            # it.
+            self._simple_insert_txn(
+                txn, table="device_federation_inbox",
+                values={
+                    "origin": origin,
+                    "message_id": message_id,
+                    "received_ts": now_ms,
+                },
+            )
+
+            # Add the messages to the appropriate local device inboxes so that
+            # they'll be sent to the devices when they next sync.
+            self._add_messages_to_local_device_inbox_txn(
+                txn, stream_id, local_messages_by_user_then_device
+            )
+
+        with self._device_inbox_id_gen.get_next() as stream_id:
+            now_ms = self.clock.time_msec()
+            yield self.runInteraction(
+                "add_messages_from_remote_to_device_inbox",
+                add_messages_txn,
+                now_ms,
+                stream_id,
+            )
+            for user_id in local_messages_by_user_then_device.keys():
+                self._device_inbox_stream_cache.entity_has_changed(
+                    user_id, stream_id
+                )
+
+        defer.returnValue(stream_id)
+
+    def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
+                                                messages_by_user_then_device):
+        sql = (
+            "UPDATE device_max_stream_id"
+            " SET stream_id = ?"
+            " WHERE stream_id < ?"
+        )
+        txn.execute(sql, (stream_id, stream_id))
+
+        local_by_user_then_device = {}
+        for user_id, messages_by_device in messages_by_user_then_device.items():
+            messages_json_for_user = {}
+            devices = messages_by_device.keys()
+            if len(devices) == 1 and devices[0] == "*":
+                # Handle wildcard device_ids.
+                sql = (
+                    "SELECT device_id FROM devices"
+                    " WHERE user_id = ?"
+                )
+                txn.execute(sql, (user_id,))
+                message_json = ujson.dumps(messages_by_device["*"])
+                for row in txn.fetchall():
+                    # Add the message for all devices for this user on this
+                    # server.
+                    device = row[0]
+                    messages_json_for_user[device] = message_json
+            else:
+                if not devices:
+                    continue
+                sql = (
+                    "SELECT device_id FROM devices"
+                    " WHERE user_id = ? AND device_id IN ("
+                    + ",".join("?" * len(devices))
+                    + ")"
+                )
+                # TODO: Maybe this needs to be done in batches if there are
+                # too many local devices for a given user.
+                txn.execute(sql, [user_id] + devices)
+                for row in txn.fetchall():
+                    # Only insert into the local inbox if the device exists on
+                    # this server
+                    device = row[0]
+                    message_json = ujson.dumps(messages_by_device[device])
+                    messages_json_for_user[device] = message_json
+
+            if messages_json_for_user:
+                local_by_user_then_device[user_id] = messages_json_for_user
+
+        if not local_by_user_then_device:
+            return

         sql = (
             "INSERT INTO device_inbox"

@ -65,25 +187,12 @@ class DeviceInboxStore(SQLBaseStore):

             " VALUES (?,?,?,?)"
         )
         rows = []
-        for user_id, messages_by_device in messages_by_user_then_device.items():
-            for device_id, message in messages_by_device.items():
-                message_json = ujson.dumps(message)
-                # Only insert into the local inbox if the device exists on
-                # this server
-                if (user_id, device_id) in local_users_and_devices:
-                    rows.append((user_id, device_id, stream_id, message_json))
+        for user_id, messages_by_device in local_by_user_then_device.items():
+            for device_id, message_json in messages_by_device.items():
+                rows.append((user_id, device_id, stream_id, message_json))

         txn.executemany(sql, rows)

-        with self._device_inbox_id_gen.get_next() as stream_id:
-            yield self.runInteraction(
-                "add_messages_to_device_inbox",
-                add_messages_to_device_inbox_txn,
-                stream_id
-            )
-
-        defer.returnValue(self._device_inbox_id_gen.get_current_token())
-
     def get_new_messages_for_device(
         self, user_id, device_id, last_stream_id, current_stream_id, limit=100
     ):

@ -97,6 +206,12 @@ class DeviceInboxStore(SQLBaseStore):

             Deferred ([dict], int): List of messages for the device and where
             in the stream the messages got to.
         """
+        has_changed = self._device_inbox_stream_cache.has_entity_changed(
+            user_id, last_stream_id
+        )
+        if not has_changed:
+            return defer.succeed(([], current_stream_id))
+
         def get_new_messages_for_device_txn(txn):
             sql = (
                 "SELECT stream_id, message_json FROM device_inbox"

@ -182,3 +297,71 @@ class DeviceInboxStore(SQLBaseStore):

     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
+
+    def get_new_device_msgs_for_remote(
+        self, destination, last_stream_id, current_stream_id, limit=100
+    ):
+        """
+        Args:
+            destination(str): The name of the remote server.
+            last_stream_id(int): The last position of the device message stream
+                that the server sent up to.
+            current_stream_id(int): The current position of the device
+                message stream.
+        Returns:
+            Deferred ([dict], int): List of messages for the device and where
+            in the stream the messages got to.
+        """
+        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
+            destination, last_stream_id
+        )
+        if not has_changed or last_stream_id == current_stream_id:
+            return defer.succeed(([], current_stream_id))
+
+        def get_new_messages_for_remote_destination_txn(txn):
+            sql = (
+                "SELECT stream_id, messages_json FROM device_federation_outbox"
+                " WHERE destination = ?"
+                " AND ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (
+                destination, last_stream_id, current_stream_id, limit
+            ))
+            messages = []
+            for row in txn.fetchall():
+                stream_pos = row[0]
+                messages.append(ujson.loads(row[1]))
+            if len(messages) < limit:
+                stream_pos = current_stream_id
+            return (messages, stream_pos)
+
+        return self.runInteraction(
+            "get_new_device_msgs_for_remote",
+            get_new_messages_for_remote_destination_txn,
+        )
+
+    def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
+        """Used to delete messages when the remote destination acknowledges
+        their receipt.
+
+        Args:
+            destination(str): The destination server_name
+            up_to_stream_id(int): Where to delete messages up to.
+        Returns:
+            A deferred that resolves when the messages have been deleted.
+        """
+        def delete_messages_for_remote_destination_txn(txn):
+            sql = (
+                "DELETE FROM device_federation_outbox"
+                " WHERE destination = ?"
+                " AND stream_id <= ?"
+            )
+            txn.execute(sql, (destination, up_to_stream_id))
+
+        return self.runInteraction(
+            "delete_device_msgs_for_remote",
+            delete_messages_for_remote_destination_txn
+        )
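A detail worth calling out in _add_messages_to_local_device_inbox_txn is the wildcard handling: a messages dict keyed by the single device_id "*" fans out to every known device for that user. A sketch of that expansion in isolation, with `known_devices` standing in for the devices table:

    # Sketch of the wildcard expansion: "*" targets every device the
    # server knows about; otherwise unknown devices are silently dropped.
    def expand_wildcard(messages_by_device, known_devices):
        devices = list(messages_by_device.keys())
        if devices == ["*"]:
            content = messages_by_device["*"]
            return {device_id: content for device_id in known_devices}
        return {d: m for d, m in messages_by_device.items() if d in known_devices}

    out = expand_wildcard({"*": {"body": "hi"}}, ["D1", "D2"])
    assert set(out) == {"D1", "D2"}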

View file

@ -54,8 +54,12 @@ class DeviceStore(SQLBaseStore):

                 or_ignore=ignore_if_known,
             )
         except Exception as e:
-            logger.error("store_device with device_id=%s failed: %s",
-                         device_id, e)
+            logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
+                         " display_name=%s(%r) failed: %s",
+                         type(device_id).__name__, device_id,
+                         type(user_id).__name__, user_id,
+                         type(initial_device_display_name).__name__,
+                         initial_device_display_name, e)
             raise StoreError(500, "Problem storing device.")

     def get_device(self, user_id, device_id):

View file

@ -16,6 +16,7 @@

 from twisted.internet import defer

 from ._base import SQLBaseStore
+from synapse.api.errors import StoreError
 from synapse.util.caches.descriptors import cached
 from unpaddedbase64 import encode_base64

@ -36,6 +37,13 @@ class EventFederationStore(SQLBaseStore):

     and backfilling from another server respectively.
     """

+    def __init__(self, hs):
+        super(EventFederationStore, self).__init__(hs)
+
+        hs.get_clock().looping_call(
+            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+        )
+
     def get_auth_chain(self, event_ids):
         return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)

@ -270,6 +278,37 @@ class EventFederationStore(SQLBaseStore):

             ]
         )

+        # We now insert into stream_ordering_to_exterm a mapping from room_id,
+        # new stream_ordering to new forward extremities in the room.
+        # This allows us to later efficiently look up the forward extremities
+        # for a room before a given stream_ordering
+        max_stream_ord = max(
+            ev.internal_metadata.stream_ordering for ev in events
+        )
+        new_extrem = {}
+        for room_id in events_by_room:
+            event_ids = self._simple_select_onecol_txn(
+                txn,
+                table="event_forward_extremities",
+                keyvalues={"room_id": room_id},
+                retcol="event_id",
+            )
+            new_extrem[room_id] = event_ids
+
+        self._simple_insert_many_txn(
+            txn,
+            table="stream_ordering_to_exterm",
+            values=[
+                {
+                    "room_id": room_id,
+                    "event_id": event_id,
+                    "stream_ordering": max_stream_ord,
+                }
+                for room_id, extrem_evs in new_extrem.items()
+                for event_id in extrem_evs
+            ]
+        )
+
         query = (
             "INSERT INTO event_backward_extremities (event_id, room_id)"
             " SELECT ?, ? WHERE NOT EXISTS ("

@ -305,6 +344,76 @@ class EventFederationStore(SQLBaseStore):

             self.get_latest_event_ids_in_room.invalidate, (room_id,)
         )

+    def get_forward_extremeties_for_room(self, room_id, stream_ordering):
+        # We want to make the cache more effective, so we clamp to the last
+        # change before the given ordering.
+        last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+
+        # We don't always have a full stream_to_exterm_id table, e.g. after
+        # the upgrade that introduced it, so we make sure we never try to
+        # pin to a stream_ordering from before a restart
+        last_change = max(self._stream_order_on_start, last_change)
+
+        if last_change > self.stream_ordering_month_ago:
+            stream_ordering = min(last_change, stream_ordering)
+
+        return self._get_forward_extremeties_for_room(room_id, stream_ordering)
+
+    @cached(max_entries=5000, num_args=2)
+    def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+        """For a given room_id and stream_ordering, return the forward
+        extremities of the room at that point in "time".
+
+        Throws a StoreError if we have since purged the index for
+        stream_orderings from that point.
+        """
+
+        if stream_ordering <= self.stream_ordering_month_ago:
+            raise StoreError(400, "stream_ordering too old")
+
+        sql = ("""
+                SELECT event_id FROM stream_ordering_to_exterm
+                INNER JOIN (
+                    SELECT room_id, MAX(stream_ordering) AS stream_ordering
+                    FROM stream_ordering_to_exterm
+                    WHERE stream_ordering <= ? GROUP BY room_id
+                ) AS rms USING (room_id, stream_ordering)
+                WHERE room_id = ?
+        """)
+
+        def get_forward_extremeties_for_room_txn(txn):
+            txn.execute(sql, (stream_ordering, room_id))
+            rows = txn.fetchall()
+            return [event_id for event_id, in rows]
+
+        return self.runInteraction(
+            "get_forward_extremeties_for_room",
+            get_forward_extremeties_for_room_txn
+        )
+
+    def _delete_old_forward_extrem_cache(self):
+        def _delete_old_forward_extrem_cache_txn(txn):
+            # Delete entries older than a month, while making sure we don't
+            # delete the only entries for a room.
+            sql = ("""
+                DELETE FROM stream_ordering_to_exterm
+                WHERE
+                (
+                    SELECT max(stream_ordering) AS stream_ordering
+                    FROM stream_ordering_to_exterm
+                    WHERE room_id = stream_ordering_to_exterm.room_id
+                ) > ?
+                AND stream_ordering < ?
+            """)
+            txn.execute(
+                sql,
+                (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
+            )
+        return self.runInteraction(
+            "_delete_old_forward_extrem_cache",
+            _delete_old_forward_extrem_cache_txn
+        )
+
     def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`

View file

@ -26,10 +26,19 @@ logger = logging.getLogger(__name__)

 class EventPushActionsStore(SQLBaseStore):
+    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
     def __init__(self, hs):
         self.stream_ordering_month_ago = None
         super(EventPushActionsStore, self).__init__(hs)

+        self.register_background_index_update(
+            self.EPA_HIGHLIGHT_INDEX,
+            index_name="event_push_actions_u_highlight",
+            table="event_push_actions",
+            columns=["user_id", "stream_ordering"],
+        )
+
     def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
         """
         Args:

@ -338,20 +347,29 @@ class EventPushActionsStore(SQLBaseStore):

         defer.returnValue(notifs[:limit])

     @defer.inlineCallbacks
-    def get_push_actions_for_user(self, user_id, before=None, limit=50):
+    def get_push_actions_for_user(self, user_id, before=None, limit=50,
+                                  only_highlight=False):
         def f(txn):
             before_clause = ""
             if before:
-                before_clause = "AND stream_ordering < ?"
+                before_clause = "AND epa.stream_ordering < ?"
                 args = [user_id, before, limit]
             else:
                 args = [user_id, limit]
+
+            if only_highlight:
+                if len(before_clause) > 0:
+                    before_clause += " "
+                before_clause += "AND epa.highlight = 1"
+
+            # NB. This assumes event_ids are globally unique since
+            # it makes the query easier to index
             sql = (
                 "SELECT epa.event_id, epa.room_id,"
                 " epa.stream_ordering, epa.topological_ordering,"
                 " epa.actions, epa.profile_tag, e.received_ts"
                 " FROM event_push_actions epa, events e"
-                " WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id"
+                " WHERE epa.event_id = e.event_id"
                 " AND epa.user_id = ? %s"
                 " ORDER BY epa.stream_ordering DESC"
                 " LIMIT ?"

View file

@ -189,6 +189,14 @@ class EventsStore(SQLBaseStore):

             self._background_reindex_fields_sender,
         )

+        self.register_background_index_update(
+            "event_contains_url_index",
+            index_name="event_contains_url_index",
+            table="events",
+            columns=["room_id", "topological_ordering", "stream_ordering"],
+            where_clause="contains_url = true AND outlier = false",
+        )
+
         self._event_persist_queue = _EventPeristenceQueue()

     def persist_events(self, events_and_contexts, backfilled=False):

@ -497,7 +505,11 @@ class EventsStore(SQLBaseStore):

             # insert into the state_group, state_groups_state and
             # event_to_state_groups tables.
-            self._store_mult_state_groups_txn(txn, ((event, context),))
+            try:
+                self._store_mult_state_groups_txn(txn, ((event, context),))
+            except Exception:
+                logger.exception("")
+                raise

             metadata_json = encode_json(
                 event.internal_metadata.get_dict()

@ -1543,6 +1555,9 @@ class EventsStore(SQLBaseStore):

         )
         event_rows = txn.fetchall()

+        for event_id, state_key in event_rows:
+            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+
         # We calculate the new entries for the backward extremities by finding
         # all events that point to events that are to be purged
         txn.execute(

@ -1582,7 +1597,66 @@ class EventsStore(SQLBaseStore):

             " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
             (room_id, topological_ordering, topological_ordering)
         )
+
         state_rows = txn.fetchall()
+        state_groups_to_delete = [sg for sg, in state_rows]
+
+        # Now we get all the state groups that rely on these state groups
+        new_state_edges = []
+        chunks = [
+            state_groups_to_delete[i:i + 100]
+            for i in xrange(0, len(state_groups_to_delete), 100)
+        ]
+        for chunk in chunks:
+            rows = self._simple_select_many_txn(
+                txn,
+                table="state_group_edges",
+                column="prev_state_group",
+                iterable=chunk,
+                retcols=["state_group"],
+                keyvalues={},
+            )
+            new_state_edges.extend(row["state_group"] for row in rows)
+
+        # Now we turn the state groups that reference to-be-deleted state
+        # groups into non-delta versions.
+        for new_state_edge in new_state_edges:
+            curr_state = self._get_state_groups_from_groups_txn(
+                txn, [new_state_edge], types=None
+            )
+            curr_state = curr_state[new_state_edge]
+
+            self._simple_delete_txn(
+                txn,
+                table="state_groups_state",
+                keyvalues={
+                    "state_group": new_state_edge,
+                }
+            )
+
+            self._simple_delete_txn(
+                txn,
+                table="state_group_edges",
+                keyvalues={
+                    "state_group": new_state_edge,
+                }
+            )
+
+            self._simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                values=[
+                    {
+                        "state_group": new_state_edge,
+                        "room_id": room_id,
+                        "type": key[0],
+                        "state_key": key[1],
+                        "event_id": state_id,
+                    }
+                    for key, state_id in curr_state.items()
+                ],
+            )
+
         txn.executemany(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             state_rows

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)

 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 34
+SCHEMA_VERSION = 35

 dir_path = os.path.abspath(os.path.dirname(__file__))

@ -242,7 +242,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,

                 module = imp.load_source(
                     module_name, absolute_path, python_file
                 )
-                logger.debug("Running script %s", relative_path)
+                logger.info("Running script %s", relative_path)
                 module.run_create(cur, database_engine)
                 if not is_empty:
                     module.run_upgrade(cur, database_engine, config=config)

@ -253,7 +253,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,

                 pass
             elif ext == ".sql":
                 # A plain old .sql file, just read and execute it
-                logger.debug("Applying schema %s", relative_path)
+                logger.info("Applying schema %s", relative_path)
                 executescript(cur, absolute_path)
             else:
                 # Not a valid delta file.

View file

@ -48,14 +48,30 @@ class RoomStore(SQLBaseStore):

             StoreError if the room could not be stored.
         """
         try:
-            yield self._simple_insert(
-                "rooms",
-                {
-                    "room_id": room_id,
-                    "creator": room_creator_user_id,
-                    "is_public": is_public,
-                },
-                desc="store_room",
-            )
+            def store_room_txn(txn, next_id):
+                self._simple_insert_txn(
+                    txn,
+                    "rooms",
+                    {
+                        "room_id": room_id,
+                        "creator": room_creator_user_id,
+                        "is_public": is_public,
+                    },
+                )
+                if is_public:
+                    self._simple_insert_txn(
+                        txn,
+                        table="public_room_list_stream",
+                        values={
+                            "stream_id": next_id,
+                            "room_id": room_id,
+                            "visibility": is_public,
+                        }
+                    )
+            with self._public_room_id_gen.get_next() as next_id:
+                yield self.runInteraction(
+                    "store_room_txn",
+                    store_room_txn, next_id,
+                )
         except Exception as e:
             logger.error("store_room with room_id=%s failed: %s", room_id, e)

@ -77,12 +93,44 @@ class RoomStore(SQLBaseStore):

             allow_none=True,
         )

+    @defer.inlineCallbacks
     def set_room_is_public(self, room_id, is_public):
-        return self._simple_update_one(
-            table="rooms",
-            keyvalues={"room_id": room_id},
-            updatevalues={"is_public": is_public},
-            desc="set_room_is_public",
-        )
+        def set_room_is_public_txn(txn, next_id):
+            self._simple_update_one_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"is_public": is_public},
+            )
+
+            entries = self._simple_select_list_txn(
+                txn,
+                table="public_room_list_stream",
+                keyvalues={"room_id": room_id},
+                retcols=("stream_id", "visibility"),
+            )
+
+            entries.sort(key=lambda r: r["stream_id"])
+
+            add_to_stream = True
+            if entries:
+                add_to_stream = bool(entries[-1]["visibility"]) != is_public
+
+            if add_to_stream:
+                self._simple_insert_txn(
+                    txn,
+                    table="public_room_list_stream",
+                    values={
+                        "stream_id": next_id,
+                        "room_id": room_id,
+                        "visibility": is_public,
+                    }
+                )
+
+        with self._public_room_id_gen.get_next() as next_id:
+            yield self.runInteraction(
+                "set_room_is_public",
+                set_room_is_public_txn, next_id,
+            )

     def get_public_room_ids(self):

@ -207,3 +255,71 @@ class RoomStore(SQLBaseStore):

             },
             desc="add_event_report"
         )
+
+    def get_current_public_room_stream_id(self):
+        return self._public_room_id_gen.get_current_token()
+
+    def get_public_room_ids_at_stream_id(self, stream_id):
+        return self.runInteraction(
+            "get_public_room_ids_at_stream_id",
+            self.get_public_room_ids_at_stream_id_txn, stream_id
+        )
+
+    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id):
+        return {
+            rm
+            for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items()
+            if vis
+        }
+
+    def get_published_at_stream_id_txn(self, txn, stream_id):
+        sql = ("""
+            SELECT room_id, visibility FROM public_room_list_stream
+            INNER JOIN (
+                SELECT room_id, max(stream_id) AS stream_id
+                FROM public_room_list_stream
+                WHERE stream_id <= ?
+                GROUP BY room_id
+            ) grouped USING (room_id, stream_id)
+        """)
+
+        txn.execute(sql, (stream_id,))
+        return dict(txn.fetchall())
+
+    def get_public_room_changes(self, prev_stream_id, new_stream_id):
+        def get_public_room_changes_txn(txn):
+            then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id)
+
+            now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id)
+
+            now_rooms_visible = set(
+                rm for rm, vis in now_rooms_dict.items() if vis
+            )
+            now_rooms_not_visible = set(
+                rm for rm, vis in now_rooms_dict.items() if not vis
+            )
+
+            newly_visible = now_rooms_visible - then_rooms
+            newly_unpublished = now_rooms_not_visible & then_rooms
+
+            return newly_visible, newly_unpublished
+
+        return self.runInteraction(
+            "get_public_room_changes", get_public_room_changes_txn
+        )
+
+    def get_all_new_public_rooms(self, prev_id, current_id, limit):
+        def get_all_new_public_rooms(txn):
+            sql = ("""
+                SELECT stream_id, room_id, visibility FROM public_room_list_stream
+                WHERE stream_id > ? AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """)
+
+            txn.execute(sql, (prev_id, current_id, limit,))
+            return txn.fetchall()
+
+        return self.runInteraction(
+            "get_all_new_public_rooms", get_all_new_public_rooms
+        )
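The stream semantics are worth spelling out: the latest row per room wins, so visibility at any stream position is a point-in-time lookup, and changes between two positions reduce to set arithmetic, mirroring get_public_room_changes above. A small worked example:

    # Worked example of public_room_list_stream semantics with three rows.
    rows = [  # (stream_id, room_id, visibility)
        (1, "!a:hs", True),
        (2, "!b:hs", True),
        (3, "!a:hs", False),
    ]

    def published_at(stream_id):
        latest = {}
        for sid, room, vis in rows:
            if sid <= stream_id:
                latest[room] = vis  # later rows overwrite earlier ones
        return latest

    then = {rm for rm, vis in published_at(1).items() if vis}
    now = published_at(3)
    newly_visible = {rm for rm, vis in now.items() if vis} - then
    newly_unpublished = {rm for rm, vis in now.items() if not vis} & then
    assert newly_visible == {"!b:hs"} and newly_unpublished == {"!a:hs"}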

View file

@ -13,6 +13,10 @@

  * limitations under the License.
  */

+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX receipts_linearized_room_stream ON receipts_linearized(
     room_id, stream_id
 );

View file

@ -13,4 +13,8 @@

  * limitations under the License.
  */

+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX events_room_stream on events(room_id, stream_ordering);

View file

@ -13,4 +13,8 @@

  * limitations under the License.
  */

+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX public_room_index on rooms(is_public);

View file

@ -13,6 +13,10 @@

  * limitations under the License.
  */

+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX receipts_linearized_user ON receipts_linearized(
     user_id
 );

View file

@ -26,6 +26,10 @@ UPDATE event_push_actions SET stream_ordering = (

 UPDATE event_push_actions SET notif = 1, highlight = 0;

+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX event_push_actions_rm_tokens on event_push_actions(
     user_id, room_id, topological_ordering, stream_ordering
 );

View file

@ -13,6 +13,10 @@

  * limitations under the License.
  */

+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX event_push_actions_stream_ordering on event_push_actions(
     stream_ordering, user_id
 );

View file

@ -0,0 +1,20 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
INSERT into background_updates (update_name, progress_json, depends_on)
VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication');

View file

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT into background_updates (update_name, progress_json)
VALUES ('event_contains_url_index', '{}');

View file

@ -0,0 +1,39 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
DROP TABLE IF EXISTS device_federation_outbox;
CREATE TABLE device_federation_outbox (
destination TEXT NOT NULL,
stream_id BIGINT NOT NULL,
queued_ts BIGINT NOT NULL,
messages_json TEXT NOT NULL
);
DROP INDEX IF EXISTS device_federation_outbox_destination_id;
CREATE INDEX device_federation_outbox_destination_id
ON device_federation_outbox(destination, stream_id);
DROP TABLE IF EXISTS device_federation_inbox;
CREATE TABLE device_federation_inbox (
origin TEXT NOT NULL,
message_id TEXT NOT NULL,
received_ts BIGINT NOT NULL
);
DROP INDEX IF EXISTS device_federation_inbox_sender_id;
CREATE INDEX device_federation_inbox_sender_id
ON device_federation_inbox(origin, message_id);

View file

@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE device_max_stream_id (
stream_id BIGINT NOT NULL
);
INSERT INTO device_max_stream_id (stream_id)
SELECT COALESCE(MAX(stream_id), 0) FROM device_inbox;

View file

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT into background_updates (update_name, progress_json)
VALUES ('epa_highlight_index', '{}');

View file

@ -0,0 +1,33 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE public_room_list_stream (
stream_id BIGINT NOT NULL,
room_id TEXT NOT NULL,
visibility BOOLEAN NOT NULL
);
INSERT INTO public_room_list_stream (stream_id, room_id, visibility)
SELECT 1, room_id, is_public FROM rooms
WHERE is_public = CAST(1 AS BOOLEAN);
CREATE INDEX public_room_list_stream_idx on public_room_list_stream(
stream_id
);
CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream(
room_id, stream_id
);

View file

@ -0,0 +1,22 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE state_group_edges(
state_group BIGINT NOT NULL,
prev_state_group BIGINT NOT NULL
);
CREATE INDEX state_group_edges_idx ON state_group_edges(state_group);
CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group);

View file

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT into background_updates (update_name, progress_json)
VALUES ('state_group_state_deduplication', '{}');

View file

@ -0,0 +1,37 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE stream_ordering_to_exterm (
stream_ordering BIGINT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL
);
INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id)
SELECT stream_ordering, room_id, event_id FROM event_forward_extremities
INNER JOIN (
SELECT room_id, max(stream_ordering) as stream_ordering FROM events
INNER JOIN event_forward_extremities USING (room_id, event_id)
GROUP BY room_id
) AS rms USING (room_id);
CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm(
stream_ordering
);
CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm(
room_id, stream_ordering
);
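The seed INSERT records the current forward extremities against each room's latest stream ordering; thereafter rows are appended as extremities change, so the server can ask "what were the forward extremities of this room as of stream ordering X". A sketch of that lookup, with the query shape assumed rather than taken from this diff:

def get_forward_extremities_at(txn, room_id, stream_ordering):
    # Use the newest snapshot at or before the requested position.
    txn.execute(
        "SELECT event_id FROM stream_ordering_to_exterm"
        " WHERE room_id = ? AND stream_ordering = ("
        "   SELECT max(stream_ordering) FROM stream_ordering_to_exterm"
        "   WHERE room_id = ? AND stream_ordering <= ?"
        " )",
        (room_id, room_id, stream_ordering),
    )
    return [row[0] for row in txn.fetchall()]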
View file

@ -16,6 +16,7 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches import intern_string
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
@ -24,6 +25,9 @@ import logging
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
class StateStore(SQLBaseStore):
""" Keeps track of the state at a given event.
@ -43,6 +47,20 @@ class StateStore(SQLBaseStore):
* `state_groups_state`: Maps state group to state events.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
def __init__(self, hs):
super(StateStore, self).__init__(hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
if not event_ids:
@ -103,11 +121,8 @@ class StateStore(SQLBaseStore):
state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
continue
self._simple_insert_txn(
txn,
table="state_groups",
@ -118,6 +133,22 @@ class StateStore(SQLBaseStore):
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't too long, as otherwise read performance degrades.
if context.prev_group:
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": context.state_group,
"prev_state_group": context.prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
@ -129,7 +160,22 @@ class StateStore(SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in state_event_ids.items() for key, state_id in context.delta_ids.items()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.current_state_ids.items()
],
)
@ -145,6 +191,47 @@ class StateStore(SQLBaseStore):
],
)
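For intuition on what lands in state_groups_state in the delta case: only the (type, state_key) pairs that differ from the previous group are written. A toy illustration in plain Python (identifiers invented, not the Synapse API):

prev_state = {
    ("m.room.member", "@alice:test"): "$event1",
    ("m.room.topic", ""): "$event2",
}
# A new event changes only the topic.
curr_state = dict(prev_state)
curr_state[("m.room.topic", "")] = "$event3"

delta_ids = {
    key: event_id
    for key, event_id in curr_state.items()
    if prev_state.get(key) != event_id
}
assert delta_ids == {("m.room.topic", ""): "$event3"}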
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = ("""
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
""")
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
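To make the hop count concrete, a toy trace of the sqlite3 branch under invented data:

# Hypothetical rows in state_group_edges: (103, 102) and (102, 101),
# i.e. the chain 103 -> 102 -> 101, where 101 is a full snapshot.
# Starting from state_group=103 the loop follows two edges and returns
# 2, comfortably below MAX_STATE_DELTA_HOPS, so a new group may still
# be persisted as a delta on top of 103.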
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None:
@ -206,7 +293,78 @@ class StateStore(SQLBaseStore):
def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
"""
results = {}
chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn, chunk, types,
)
results.update(res)
defer.returnValue(results)
def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
results = {group: {} for group in groups}
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in the background.
txn.execute("SET LOCAL enable_seqscan=off")
# The below query walks the state_group tree so that the "state"
# table includes all state_groups in the tree. It then joins
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# The PARTITION is used to get the event_id in the greatest state
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = ("""
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
%s
""")
# Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific
# type separately.
if types:
clause_to_args = [
(
"AND type = ? AND state_key = ?",
(etype, state_key)
)
for etype, state_key in types
]
else:
# If types is None we fetch all the state, and so just use an
# empty where clause with no extra args.
clause_to_args = [("", [])]
for where_clause, where_args in clause_to_args:
for group in groups:
args = [group]
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
rows = self.cursor_to_dict(txn)
for row in rows:
key = (row["type"], row["state_key"])
results[group][key] = row["event_id"]
else:
if types is not None:
where_clause = "AND (%s)" % (
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
@ -214,40 +372,45 @@ class StateStore(SQLBaseStore):
else:
where_clause = ""
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
group_tree = [group]
next_group = group
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
group_tree.append(next_group)
sql = ("""
SELECT type, state_key, event_id FROM state_groups_state
INNER JOIN (
SELECT type, state_key, max(state_group) as state_group
FROM state_groups_state
WHERE state_group IN (%s) %s
GROUP BY type, state_key
) USING (type, state_key, state_group);
""") % (",".join("?" for _ in group_tree), where_clause,)
args = list(group_tree)
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
for row in rows:
key = (row["type"], row["state_key"])
results[group][key] = row["event_id"]
return results
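Both database branches produce the same result shape. A hypothetical call from inside another storage method, restricted to the room name (values invented):

# Inside a @defer.inlineCallbacks method:
result = yield self._get_state_groups_from_groups(
    [12, 15], types=[("m.room.name", "")]
)
# result == {
#     12: {("m.room.name", ""): "$older_name_event:hs"},
#     15: {("m.room.name", ""): "$newer_name_event:hs"},
# }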
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, types):
"""Given a list of event_ids and type tuples, return a list of state
@ -504,32 +667,184 @@ class StateStore(SQLBaseStore):
defer.returnValue(results)
def get_next_state_group(self):
return self._state_groups_id_gen.get_next()
@defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
last_state_group = progress.get("last_state_group", 0)
rows_inserted = progress.get("rows_inserted", 0)
max_group = progress.get("max_group", None)
BATCH_SIZE_SCALE_FACTOR = 100
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
if max_group is None:
rows = yield self._execute(
"_background_deduplicate_state", None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]
def reindex_txn(txn):
new_last_state_group = last_state_group
for count in xrange(batch_size):
txn.execute(
"SELECT id, room_id FROM state_groups"
" WHERE ? < id AND id <= ?"
" ORDER BY id ASC"
" LIMIT 1",
(new_last_state_group, max_group,)
)
row = txn.fetchone()
if row:
state_group, room_id = row
if not row or not state_group:
return True, count
txn.execute(
"SELECT state_group FROM state_group_edges"
" WHERE state_group = ?",
(state_group,)
)
# If we reach a point where we've already started inserting
# edges we should stop.
if txn.fetchall():
return True, count
txn.execute(
"SELECT coalesce(max(id), 0) FROM state_groups"
" WHERE id < ? AND room_id = ?",
(state_group, room_id,)
)
prev_group, = txn.fetchone()
new_last_state_group = state_group
if prev_group:
potential_hops = self._count_state_group_hops_txn(
txn, prev_group
)
if potential_hops >= MAX_STATE_DELTA_HOPS:
# We want to ensure chains are at most this long, as
# otherwise read performance degrades.
continue
prev_state = self._get_state_groups_from_groups_txn(
txn, [prev_group], types=None
)
prev_state = prev_state[prev_group]
curr_state = self._get_state_groups_from_groups_txn(
txn, [state_group], types=None
)
curr_state = curr_state[state_group]
if not set(prev_state.keys()) - set(curr_state.keys()):
# We can only do a delta if the current state's keys are a
# superset of the previous state group's keys
delta_state = {
key: value for key, value in curr_state.items()
if prev_state.get(key, None) != value
}
self._simple_delete_txn(
txn,
table="state_group_edges",
keyvalues={
"state_group": state_group,
}
)
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": state_group,
"prev_state_group": prev_group,
}
)
self._simple_delete_txn(
txn,
table="state_groups_state",
keyvalues={
"state_group": state_group,
}
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": state_group,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in delta_state.items()
],
)
progress = {
"last_state_group": state_group,
"rows_inserted": rows_inserted + batch_size,
"max_group": max_group,
}
self._background_update_progress_txn(
txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
)
return False, batch_size
finished, result = yield self.runInteraction(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
)
if finished:
yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME)
defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
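The BATCH_SIZE_SCALE_FACTOR bookkeeping keeps the background-update scheduler's adaptive batch sizing meaningful even though each unit of work here is a whole state group rather than a single row. Illustrative arithmetic (numbers invented):

# If the scheduler asks for batch_size=500, the handler deduplicates
# max(1, 500 // 100) = 5 state groups, then reports 5 * 100 = 500 rows
# back, keeping its rows-per-duration estimate comparable with other
# background updates.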
@defer.inlineCallbacks
def _background_index_state(self, progress, batch_size):
def reindex_txn(conn):
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
# postgres insists on autocommit for the index
conn.set_session(autocommit=True)
try:
txn = conn.cursor()
txn.execute(
"CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
txn.execute(
"DROP INDEX IF EXISTS state_groups_state_id"
)
finally:
conn.set_session(autocommit=False)
else:
txn = conn.cursor()
txn.execute(
"CREATE INDEX state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
txn.execute(
"DROP INDEX IF EXISTS state_groups_state_id"
)
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
defer.returnValue(1)
View file

@ -531,6 +531,9 @@ class StreamStore(SQLBaseStore):
)
defer.returnValue("t%d-%d" % (topo, token))
def get_room_max_stream_ordering(self):
return self._stream_id_gen.get_current_token()
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
Args:
View file

@ -121,3 +121,9 @@ class StreamChangeCache(object):
k, r = self._cache.popitem()
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
self._entity_to_key.pop(r, None)
def get_max_pos_of_last_change(self, entity):
"""Returns an upper bound of the stream id of the last change to an
entity.
"""
return self._entity_to_key.get(entity, self._earliest_known_stream_pos)
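Because unknown entities fall back to the earliest known stream position, the returned value is a safe over-estimate: callers can get a false "may have changed" but never a false "unchanged". A hypothetical use:

# Skip a DB hit when the entity certainly hasn't changed since the
# client's token; otherwise fall through and query the database.
pos = cache.get_max_pos_of_last_change("@user:example.com")
if pos <= since_stream_id:
    changed = False  # definitely unchanged
else:
    changed = None   # unknown; consult the database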
View file

@ -121,6 +121,14 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.auth.check_joined_room = check_joined_room
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = (
lambda *args, **kargs: ([], 0)
)
self.datastore.delete_device_msgs_for_remote = (
lambda *args, **kargs: None
)
# Some local users to test with
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")