diff --git a/CHANGES.rst b/CHANGES.rst index 2ec10516fd..f1d2c7a765 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,17 @@ +Changes in synapse v0.10.0-r2 (2015-09-16) +========================================== + +* Fix bug where we always fetched remote server signing keys instead of using + ones in our cache. +* Fix adding threepids to an existing account. +* Fix bug with invinting over federation where remote server was already in + the room. (PR #281, SYN-392) + +Changes in synapse v0.10.0-r1 (2015-09-08) +========================================== + +* Fix bug with python packaging + Changes in synapse v0.10.0 (2015-09-03) ======================================= diff --git a/demo/start.sh b/demo/start.sh index 572dbfab0b..a90561488d 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -25,6 +25,7 @@ for port in 8080 8081 8082; do --generate-config \ -H "localhost:$https_port" \ --config-path "$DIR/etc/$port.config" \ + --report-stats no # Check script parameters if [ $# -eq 1 ]; then diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py new file mode 100755 index 0000000000..f0d0cd8a3f --- /dev/null +++ b/scripts-dev/definitions.py @@ -0,0 +1,142 @@ +#! /usr/bin/python + +import ast +import yaml + +class DefinitionVisitor(ast.NodeVisitor): + def __init__(self): + super(DefinitionVisitor, self).__init__() + self.functions = {} + self.classes = {} + self.names = {} + self.attrs = set() + self.definitions = { + 'def': self.functions, + 'class': self.classes, + 'names': self.names, + 'attrs': self.attrs, + } + + def visit_Name(self, node): + self.names.setdefault(type(node.ctx).__name__, set()).add(node.id) + + def visit_Attribute(self, node): + self.attrs.add(node.attr) + for child in ast.iter_child_nodes(node): + self.visit(child) + + def visit_ClassDef(self, node): + visitor = DefinitionVisitor() + self.classes[node.name] = visitor.definitions + for child in ast.iter_child_nodes(node): + visitor.visit(child) + + def visit_FunctionDef(self, node): + visitor = DefinitionVisitor() + self.functions[node.name] = visitor.definitions + for child in ast.iter_child_nodes(node): + visitor.visit(child) + + +def non_empty(defs): + functions = {name: non_empty(f) for name, f in defs['def'].items()} + classes = {name: non_empty(f) for name, f in defs['class'].items()} + result = {} + if functions: result['def'] = functions + if classes: result['class'] = classes + names = defs['names'] + uses = [] + for name in names.get('Load', ()): + if name not in names.get('Param', ()) and name not in names.get('Store', ()): + uses.append(name) + uses.extend(defs['attrs']) + if uses: result['uses'] = uses + result['names'] = names + result['attrs'] = defs['attrs'] + return result + + +def definitions_in_code(input_code): + input_ast = ast.parse(input_code) + visitor = DefinitionVisitor() + visitor.visit(input_ast) + definitions = non_empty(visitor.definitions) + return definitions + + +def definitions_in_file(filepath): + with open(filepath) as f: + return definitions_in_code(f.read()) + + +def defined_names(prefix, defs, names): + for name, funcs in defs.get('def', {}).items(): + names.setdefault(name, {'defined': []})['defined'].append(prefix + name) + defined_names(prefix + name + ".", funcs, names) + + for name, funcs in defs.get('class', {}).items(): + names.setdefault(name, {'defined': []})['defined'].append(prefix + name) + defined_names(prefix + name + ".", funcs, names) + + +def used_names(prefix, defs, names): + for name, funcs in defs.get('def', {}).items(): + used_names(prefix + name + ".", funcs, names) + + for name, funcs in defs.get('class', {}).items(): + used_names(prefix + name + ".", funcs, names) + + for used in defs.get('uses', ()): + if used in names: + names[used].setdefault('used', []).append(prefix.rstrip('.')) + + +if __name__ == '__main__': + import sys, os, argparse, re + + parser = argparse.ArgumentParser(description='Find definitions.') + parser.add_argument( + "--unused", action="store_true", help="Only list unused definitions" + ) + parser.add_argument( + "--ignore", action="append", metavar="REGEXP", help="Ignore a pattern" + ) + parser.add_argument( + "--pattern", action="append", metavar="REGEXP", + help="Search for a pattern" + ) + parser.add_argument( + "directories", nargs='+', metavar="DIR", + help="Directories to search for definitions" + ) + args = parser.parse_args() + + definitions = {} + for directory in args.directories: + for root, dirs, files in os.walk(directory): + for filename in files: + if filename.endswith(".py"): + filepath = os.path.join(root, filename) + definitions[filepath] = definitions_in_file(filepath) + + names = {} + for filepath, defs in definitions.items(): + defined_names(filepath + ":", defs, names) + + for filepath, defs in definitions.items(): + used_names(filepath + ":", defs, names) + + patterns = [re.compile(pattern) for pattern in args.pattern or ()] + ignore = [re.compile(pattern) for pattern in args.ignore or ()] + + result = {} + for name, definition in names.items(): + if patterns and not any(pattern.match(name) for pattern in patterns): + continue + if ignore and any(pattern.match(name) for pattern in ignore): + continue + if args.unused and definition.get('used'): + continue + result[name] = definition + + yaml.dump(result, sys.stdout, default_flow_style=False) diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 6aba72e459..62515997b1 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -95,8 +95,6 @@ class Store(object): _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] - _execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"] - def runInteraction(self, desc, func, *args, **kwargs): def r(conn): try: diff --git a/synapse/__init__.py b/synapse/__init__.py index d85bb3dce0..d62294e6bb 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.10.0" +__version__ = "0.10.0-r2" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index a2fa171e3a..e3b8c3099a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -23,6 +23,7 @@ from synapse.util.logutils import log_function from synapse.types import RoomID, UserID, EventID import logging +import pymacaroons logger = logging.getLogger(__name__) @@ -40,6 +41,12 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 + self._KNOWN_CAVEAT_PREFIXES = set([ + "gen = ", + "type = ", + "time < ", + "user_id = ", + ]) def check(self, event, auth_events): """ Checks if this event is correctly authed. @@ -121,6 +128,20 @@ class Auth(object): @defer.inlineCallbacks def check_joined_room(self, room_id, user_id, current_state=None): + """Check if the user is currently joined in the room + Args: + room_id(str): The room to check. + user_id(str): The user to check. + current_state(dict): Optional map of the current state of the room. + If provided then that map is used to check whether they are a + member of the room. Otherwise the current membership is + loaded from the database. + Raises: + AuthError if the user is not in the room. + Returns: + A deferred membership event for the user if the user is in + the room. + """ if current_state: member = current_state.get( (EventTypes.Member, user_id), @@ -136,6 +157,43 @@ class Auth(object): self._check_joined_room(member, user_id, room_id) defer.returnValue(member) + @defer.inlineCallbacks + def check_user_was_in_room(self, room_id, user_id, current_state=None): + """Check if the user was in the room at some point. + Args: + room_id(str): The room to check. + user_id(str): The user to check. + current_state(dict): Optional map of the current state of the room. + If provided then that map is used to check whether they are a + member of the room. Otherwise the current membership is + loaded from the database. + Raises: + AuthError if the user was never in the room. + Returns: + A deferred membership event for the user if the user was in the + room. This will be the join event if they are currently joined to + the room. This will be the leave event if they have left the room. + """ + if current_state: + member = current_state.get( + (EventTypes.Member, user_id), + None + ) + else: + member = yield self.state.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership not in (Membership.JOIN, Membership.LEAVE): + raise AuthError(403, "User %s not in room %s" % ( + user_id, room_id + )) + + defer.returnValue(member) + @defer.inlineCallbacks def check_host_in_room(self, room_id, host): curr_state = yield self.state.get_current_state(room_id) @@ -390,7 +448,7 @@ class Auth(object): except KeyError: pass # normal users won't have the user_id query parameter set. - user_info = yield self.get_user_by_access_token(access_token) + user_info = yield self._get_user_by_access_token(access_token) user = user_info["user"] token_id = user_info["token_id"] @@ -417,7 +475,7 @@ class Auth(object): ) @defer.inlineCallbacks - def get_user_by_access_token(self, token): + def _get_user_by_access_token(self, token): """ Get a registered user's ID. Args: @@ -427,6 +485,86 @@ class Auth(object): Raises: AuthError if no user by that token exists or the token is invalid. """ + try: + ret = yield self._get_user_from_macaroon(token) + except AuthError: + # TODO(daniel): Remove this fallback when all existing access tokens + # have been re-issued as macaroons. + ret = yield self._look_up_user_by_access_token(token) + defer.returnValue(ret) + + @defer.inlineCallbacks + def _get_user_from_macaroon(self, macaroon_str): + try: + macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) + self._validate_macaroon(macaroon) + + user_prefix = "user_id = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) + # This codepath exists so that we can actually return a + # token ID, because we use token IDs in place of device + # identifiers throughout the codebase. + # TODO(daniel): Remove this fallback when device IDs are + # properly implemented. + ret = yield self._look_up_user_by_access_token(macaroon_str) + if ret["user"] != user: + logger.error( + "Macaroon user (%s) != DB user (%s)", + user, + ret["user"] + ) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "User mismatch in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + defer.returnValue(ret) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN + ) + + def _validate_macaroon(self, macaroon): + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = access") + v.satisfy_general(lambda c: c.startswith("user_id = ")) + v.satisfy_general(self._verify_expiry) + v.verify(macaroon, self.hs.config.macaroon_secret_key) + + v = pymacaroons.Verifier() + v.satisfy_general(self._verify_recognizes_caveats) + v.verify(macaroon, self.hs.config.macaroon_secret_key) + + def _verify_expiry(self, caveat): + prefix = "time < " + if not caveat.startswith(prefix): + return False + # TODO(daniel): Enable expiry check when clients actually know how to + # refresh tokens. (And remember to enable the tests) + return True + expiry = int(caveat[len(prefix):]) + now = self.hs.get_clock().time_msec() + return now < expiry + + def _verify_recognizes_caveats(self, caveat): + first_space = caveat.find(" ") + if first_space < 0: + return False + second_space = caveat.find(" ", first_space + 1) + if second_space < 0: + return False + return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES + + @defer.inlineCallbacks + def _look_up_user_by_access_token(self, token): ret = yield self.store.get_user_by_access_token(token) if not ret: raise AuthError( @@ -437,7 +575,6 @@ class Auth(object): "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } - defer.returnValue(user_info) @defer.inlineCallbacks diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 1423986c1e..3385664394 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -27,16 +27,6 @@ class Membership(object): LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) -class Feedback(object): - - """Represents the types of feedback a user can send in response to a - message.""" - - DELIVERED = u"delivered" - READ = u"read" - LIST = (DELIVERED, READ) - - class PresenceState(object): """Represents the presence state of a user.""" OFFLINE = u"offline" @@ -73,7 +63,6 @@ class EventTypes(object): PowerLevels = "m.room.power_levels" Aliases = "m.room.aliases" Redaction = "m.room.redaction" - Feedback = "m.room.message.feedback" RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c3b4d971a8..ee3045268f 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -77,11 +77,6 @@ class SynapseError(CodeMessageException): ) -class RoomError(SynapseError): - """An error raised when a room event fails.""" - pass - - class RegistrationError(SynapseError): """An error raised when a registration event fails.""" pass diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index c23f853230..190b03e2f7 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -16,10 +16,23 @@ import sys sys.dont_write_bytecode = True -from synapse.python_dependencies import check_requirements, DEPENDENCY_LINKS +from synapse.python_dependencies import ( + check_requirements, DEPENDENCY_LINKS, MissingRequirementError +) if __name__ == '__main__': - check_requirements() + try: + check_requirements() + except MissingRequirementError as e: + message = "\n".join([ + "Missing Requirement: %s" % (e.message,), + "To install run:", + " pip install --upgrade --force \"%s\"" % (e.dependency,), + "", + ]) + sys.stderr.writelines(message) + sys.exit(1) + from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import ( @@ -29,7 +42,7 @@ from synapse.storage import ( from synapse.server import HomeServer -from twisted.internet import reactor +from twisted.internet import reactor, task, defer from twisted.application import service from twisted.enterprise import adbapi from twisted.web.resource import Resource, EncodingResourceWrapper @@ -72,12 +85,6 @@ import time logger = logging.getLogger("synapse.app.homeserver") -class GzipFile(File): - def getChild(self, path, request): - child = File.getChild(self, path, request) - return EncodingResourceWrapper(child, [GzipEncoderFactory()]) - - def gz_wrap(r): return EncodingResourceWrapper(r, [GzipEncoderFactory()]) @@ -121,6 +128,7 @@ class SynapseHomeServer(HomeServer): # (It can stay enabled for the API resources: they call # write() with the whole body and then finish() straight # after and so do not trigger the bug. + # GzipFile was removed in commit 184ba09 # return GzipFile(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable? @@ -221,7 +229,7 @@ class SynapseHomeServer(HomeServer): listener_config, root_resource, ), - self.tls_context_factory, + self.tls_server_context_factory, interface=bind_address ) else: @@ -365,7 +373,6 @@ def setup(config_options): Args: config_options_options: The options passed to Synapse. Usually `sys.argv[1:]`. - should_run (bool): Whether to start the reactor. Returns: HomeServer @@ -388,7 +395,7 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - tls_context_factory = context_factory.ServerContextFactory(config) + tls_server_context_factory = context_factory.ServerContextFactory(config) database_engine = create_engine(config.database_config["name"]) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection @@ -396,7 +403,7 @@ def setup(config_options): hs = SynapseHomeServer( config.server_name, db_config=config.database_config, - tls_context_factory=tls_context_factory, + tls_server_context_factory=tls_server_context_factory, config=config, content_addr=config.content_addr, version_string=version_string, @@ -665,6 +672,42 @@ def run(hs): ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) + start_time = hs.get_clock().time() + + @defer.inlineCallbacks + def phone_stats_home(): + now = int(hs.get_clock().time()) + uptime = int(now - start_time) + if uptime < 0: + uptime = 0 + + stats = {} + stats["homeserver"] = hs.config.server_name + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + stats["total_users"] = yield hs.get_datastore().count_all_users() + + all_rooms = yield hs.get_datastore().get_rooms(False) + stats["total_room_count"] = len(all_rooms) + + stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() + daily_messages = yield hs.get_datastore().count_daily_messages() + if daily_messages is not None: + stats["daily_messages"] = daily_messages + + logger.info("Reporting stats to matrix.org: %s" % (stats,)) + try: + yield hs.get_simple_http_client().put_json( + "https://matrix.org/report-usage-stats/push", + stats + ) + except Exception as e: + logger.warn("Error reporting stats: %s", e) + + if hs.config.report_stats: + phone_home_task = task.LoopingCall(phone_stats_home) + phone_home_task.start(60 * 60 * 24, now=False) + def in_thread(): with LoggingContext("run"): change_resource_limit(hs.config.soft_file_limit) diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 1f7d543c31..5d82beed0e 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -16,57 +16,67 @@ import sys import os +import os.path import subprocess import signal import yaml SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] -CONFIGFILE = "homeserver.yaml" - GREEN = "\x1b[1;32m" +RED = "\x1b[1;31m" NORMAL = "\x1b[m" -if not os.path.exists(CONFIGFILE): - sys.stderr.write( - "No config file found\n" - "To generate a config file, run '%s -c %s --generate-config" - " --server-name='\n" % ( - " ".join(SYNAPSE), CONFIGFILE - ) - ) - sys.exit(1) -CONFIG = yaml.load(open(CONFIGFILE)) -PIDFILE = CONFIG["pid_file"] - - -def start(): +def start(configfile): print "Starting ...", args = SYNAPSE - args.extend(["--daemonize", "-c", CONFIGFILE]) - subprocess.check_call(args) - print GREEN + "started" + NORMAL + args.extend(["--daemonize", "-c", configfile]) + + try: + subprocess.check_call(args) + print GREEN + "started" + NORMAL + except subprocess.CalledProcessError as e: + print ( + RED + + "error starting (exit code: %d); see above for logs" % e.returncode + + NORMAL + ) -def stop(): - if os.path.exists(PIDFILE): - pid = int(open(PIDFILE).read()) +def stop(pidfile): + if os.path.exists(pidfile): + pid = int(open(pidfile).read()) os.kill(pid, signal.SIGTERM) print GREEN + "stopped" + NORMAL def main(): + configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml" + + if not os.path.exists(configfile): + sys.stderr.write( + "No config file found\n" + "To generate a config file, run '%s -c %s --generate-config" + " --server-name='\n" % ( + " ".join(SYNAPSE), configfile + ) + ) + sys.exit(1) + + config = yaml.load(open(configfile)) + pidfile = config["pid_file"] + action = sys.argv[1] if sys.argv[1:] else "usage" if action == "start": - start() + start(configfile) elif action == "stop": - stop() + stop(pidfile) elif action == "restart": - stop() - start() + stop(pidfile) + start(configfile) else: - sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],)) + sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],)) sys.exit(1) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 8a75c48733..ceef309afc 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -26,6 +26,16 @@ class ConfigError(Exception): class Config(object): + stats_reporting_begging_spiel = ( + "We would really appreciate it if you could help our project out by" + " reporting anonymized usage statistics from your homeserver. Only very" + " basic aggregate data (e.g. number of users) will be reported, but it" + " helps us to track the growth of the Matrix community, and helps us to" + " make Matrix a success, as well as to convince other networks that they" + " should peer with us." + "\nThank you." + ) + @staticmethod def parse_size(value): if isinstance(value, int) or isinstance(value, long): @@ -111,11 +121,14 @@ class Config(object): results.append(getattr(cls, name)(self, *args, **kargs)) return results - def generate_config(self, config_dir_path, server_name): + def generate_config(self, config_dir_path, server_name, report_stats=None): default_config = "# vim:ft=yaml\n" default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( - "default_config", config_dir_path, server_name + "default_config", + config_dir_path=config_dir_path, + server_name=server_name, + report_stats=report_stats, )) config = yaml.load(default_config) @@ -139,6 +152,12 @@ class Config(object): action="store_true", help="Generate a config file for the server name" ) + config_parser.add_argument( + "--report-stats", + action="store", + help="Stuff", + choices=["yes", "no"] + ) config_parser.add_argument( "--generate-keys", action="store_true", @@ -189,6 +208,11 @@ class Config(object): config_files.append(config_path) if config_args.generate_config: + if config_args.report_stats is None: + config_parser.error( + "Please specify either --report-stats=yes or --report-stats=no\n\n" + + cls.stats_reporting_begging_spiel + ) if not config_files: config_parser.error( "Must supply a config file.\nA config file can be automatically" @@ -211,7 +235,9 @@ class Config(object): os.makedirs(config_dir_path) with open(config_path, "wb") as config_file: config_bytes, config = obj.generate_config( - config_dir_path, server_name + config_dir_path=config_dir_path, + server_name=server_name, + report_stats=(config_args.report_stats == "yes"), ) obj.invoke_all("generate_files", config) config_file.write(config_bytes) @@ -261,9 +287,20 @@ class Config(object): specified_config.update(yaml_config) server_name = specified_config["server_name"] - _, config = obj.generate_config(config_dir_path, server_name) + _, config = obj.generate_config( + config_dir_path=config_dir_path, + server_name=server_name + ) config.pop("log_config") config.update(specified_config) + if "report_stats" not in config: + sys.stderr.write( + "Please opt in or out of reporting anonymized homeserver usage " + "statistics, by setting the report_stats key in your config file " + " ( " + config_path + " ) " + + "to either True or False.\n\n" + + Config.stats_reporting_begging_spiel + "\n") + sys.exit(1) if generate_keys: obj.invoke_all("generate_files", config) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 38f41933b7..b8d301995e 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -20,7 +20,7 @@ class AppServiceConfig(Config): def read_config(self, config): self.app_service_config_files = config.get("app_service_config_files", []) - def default_config(cls, config_dir_path, server_name): + def default_config(cls, **kwargs): return """\ # A list of application service config file to use app_service_config_files: [] diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index 15a132b4e3..dd92fcd0dc 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -24,7 +24,7 @@ class CaptchaConfig(Config): self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"] - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Captcha ## diff --git a/synapse/config/database.py b/synapse/config/database.py index f0611e8884..baeda8f300 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -45,7 +45,7 @@ class DatabaseConfig(Config): self.set_databasepath(config.get("database_path")) - def default_config(self, config, config_dir_path): + def default_config(self, **kwargs): database_path = self.abspath("homeserver.db") return """\ # Database configuration diff --git a/synapse/config/key.py b/synapse/config/key.py index 23ac8a3fca..2c187065e5 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -40,7 +40,7 @@ class KeyConfig(Config): config["perspectives"] ) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) return """\ ## Signing Keys ## diff --git a/synapse/config/logger.py b/synapse/config/logger.py index fa542623b7..bd0c17c861 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -21,6 +21,7 @@ import logging.config import yaml from string import Template import os +import signal DEFAULT_LOG_CONFIG = Template(""" @@ -69,7 +70,7 @@ class LoggingConfig(Config): self.log_config = self.abspath(config.get("log_config")) self.log_file = self.abspath(config.get("log_file")) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): log_file = self.abspath("homeserver.log") log_config = self.abspath( os.path.join(config_dir_path, server_name + ".log.config") @@ -142,6 +143,19 @@ class LoggingConfig(Config): handler = logging.handlers.RotatingFileHandler( self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3 ) + + def sighup(signum, stack): + logger.info("Closing log file due to SIGHUP") + handler.doRollover() + logger.info("Opened new log file due to SIGHUP") + + # TODO(paul): obviously this is a terrible mechanism for + # stealing SIGHUP, because it means no other part of synapse + # can use it instead. If we want to catch SIGHUP anywhere + # else as well, I'd suggest we find a nicer way to broadcast + # it around. + if getattr(signal, "SIGHUP"): + signal.signal(signal.SIGHUP, sighup) else: handler = logging.StreamHandler() handler.setFormatter(formatter) diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index ae5a691527..825fec9a38 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -19,13 +19,15 @@ from ._base import Config class MetricsConfig(Config): def read_config(self, config): self.enable_metrics = config["enable_metrics"] + self.report_stats = config.get("report_stats", None) self.metrics_port = config.get("metrics_port") self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1") - def default_config(self, config_dir_path, server_name): - return """\ + def default_config(self, report_stats=None, **kwargs): + suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n" + return ("""\ ## Metrics ### # Enable collection and rendering of performance metrics enable_metrics: False - """ + """ + suffix) % locals() diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 76d9970e5b..611b598ec7 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -27,7 +27,7 @@ class RatelimitConfig(Config): self.federation_rc_reject_limit = config["federation_rc_reject_limit"] self.federation_rc_concurrent = config["federation_rc_concurrent"] - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Ratelimiting ## diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 62de4b399f..fa98eced34 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -34,7 +34,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.macaroon_secret_key = config.get("macaroon_secret_key") - def default_config(self, config_dir, server_name): + def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) macaroon_secret_key = random_string_with_symbols(50) return """\ diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 64644b9a7a..2fcf872449 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -60,7 +60,7 @@ class ContentRepositoryConfig(Config): config["thumbnail_sizes"] ) - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): media_store = self.default_path("media_store") uploads_path = self.default_path("uploads") return """ diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 1532036876..4c6133cf22 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -41,7 +41,7 @@ class SAML2Config(Config): self.saml2_config_path = None self.saml2_idp_redirect_url = None - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable SAML2 for registration and login. Uses pysaml2 # config_path: Path to the sp_conf.py configuration file diff --git a/synapse/config/server.py b/synapse/config/server.py index a03e55c223..4d12d49857 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -117,7 +117,7 @@ class ServerConfig(Config): self.content_addr = content_addr - def default_config(self, config_dir_path, server_name): + def default_config(self, server_name, **kwargs): if ":" in server_name: bind_port = int(server_name.split(":")[1]) unsecure_port = bind_port - 400 diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 4751d39bc9..0ac2698293 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -42,7 +42,15 @@ class TlsConfig(Config): config.get("tls_dh_params_path"), "tls_dh_params" ) - def default_config(self, config_dir_path, server_name): + # This config option applies to non-federation HTTP clients + # (e.g. for talking to recaptcha, identity servers, and such) + # It should never be used in production, and is intended for + # use only when running tests. + self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get( + "use_insecure_ssl_client_just_for_testing_do_not_use" + ) + + def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) tls_certificate_path = base_key_name + ".tls.crt" diff --git a/synapse/config/voip.py b/synapse/config/voip.py index a1707223d3..a093354ccd 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -22,7 +22,7 @@ class VoipConfig(Config): self.turn_shared_secret = config["turn_shared_secret"] self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Turn ## diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index e251ab6af3..8b6a59866f 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -228,10 +228,9 @@ class Keyring(object): def do_iterations(): merged_results = {} - missing_keys = { - group.server_name: set(group.key_ids) - for group in group_id_to_group.values() - } + missing_keys = {} + for group in group_id_to_group.values(): + missing_keys.setdefault(group.server_name, set()).union(group.key_ids) for fn in key_fetch_fns: results = yield fn(missing_keys.items()) @@ -470,7 +469,7 @@ class Keyring(object): continue (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_context_factory, + server_name, self.hs.tls_server_context_factory, path=(b"/_matrix/key/v2/server/%s" % ( urllib.quote(requested_key_id), )).encode("ascii"), @@ -604,7 +603,7 @@ class Keyring(object): # Try to fetch the key from the remote server. (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_context_factory + server_name, self.hs.tls_server_context_factory ) # Check the response. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 59f687e0f1..793b3fcd8b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -19,7 +19,6 @@ from ._base import BaseHandler from synapse.api.constants import LoginType from synapse.types import UserID from synapse.api.errors import LoginError, Codes -from synapse.http.client import SimpleHttpClient from synapse.util.async import run_on_reactor from twisted.web.client import PartialDownloadError @@ -187,7 +186,7 @@ class AuthHandler(BaseHandler): # TODO: get this from the homeserver rather than creating a new one for # each request try: - client = SimpleHttpClient(self.hs) + client = self.hs.get_simple_http_client() resp_body = yield client.post_urlencoded_get_json( self.hs.config.recaptcha_siteverify_api, args={ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4ff20599d6..3aff80bf59 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -125,60 +125,72 @@ class FederationHandler(BaseHandler): ) if not is_in_room and not event.internal_metadata.is_outlier(): logger.debug("Got event for room we're not in.") - current_state = state - event_ids = set() - if state: - event_ids |= {e.event_id for e in state} - if auth_chain: - event_ids |= {e.event_id for e in auth_chain} + try: + event_stream_id, max_stream_id = yield self._persist_auth_tree( + auth_chain, state, event + ) + except AuthError as e: + raise FederationError( + "ERROR", + e.code, + e.msg, + affected=event.event_id, + ) - seen_ids = set( - (yield self.store.have_events(event_ids)).keys() - ) + else: + event_ids = set() + if state: + event_ids |= {e.event_id for e in state} + if auth_chain: + event_ids |= {e.event_id for e in auth_chain} - if state and auth_chain is not None: - # If we have any state or auth_chain given to us by the replication - # layer, then we should handle them (if we haven't before.) - - event_infos = [] - - for e in itertools.chain(auth_chain, state): - if e.event_id in seen_ids: - continue - e.internal_metadata.outlier = True - auth_ids = [e_id for e_id, _ in e.auth_events] - auth = { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - event_infos.append({ - "event": e, - "auth_events": auth, - }) - seen_ids.add(e.event_id) - - yield self._handle_new_events( - origin, - event_infos, - outliers=True + seen_ids = set( + (yield self.store.have_events(event_ids)).keys() ) - try: - _, event_stream_id, max_stream_id = yield self._handle_new_event( - origin, - event, - state=state, - backfilled=backfilled, - current_state=current_state, - ) - except AuthError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) + if state and auth_chain is not None: + # If we have any state or auth_chain given to us by the replication + # layer, then we should handle them (if we haven't before.) + + event_infos = [] + + for e in itertools.chain(auth_chain, state): + if e.event_id in seen_ids: + continue + e.internal_metadata.outlier = True + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + event_infos.append({ + "event": e, + "auth_events": auth, + }) + seen_ids.add(e.event_id) + + yield self._handle_new_events( + origin, + event_infos, + outliers=True + ) + + try: + _, event_stream_id, max_stream_id = yield self._handle_new_event( + origin, + event, + state=state, + backfilled=backfilled, + current_state=current_state, + ) + except AuthError as e: + raise FederationError( + "ERROR", + e.code, + e.msg, + affected=event.event_id, + ) # if we're receiving valid events from an origin, # it's probably a good idea to mark it as not in retry-state @@ -649,35 +661,8 @@ class FederationHandler(BaseHandler): # FIXME pass - ev_infos = [] - for e in itertools.chain(state, auth_chain): - if e.event_id == event.event_id: - continue - - e.internal_metadata.outlier = True - auth_ids = [e_id for e_id, _ in e.auth_events] - ev_infos.append({ - "event": e, - "auth_events": { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - }) - - yield self._handle_new_events(origin, ev_infos, outliers=True) - - auth_ids = [e_id for e_id, _ in event.auth_events] - auth_events = { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - - _, event_stream_id, max_stream_id = yield self._handle_new_event( - origin, - new_event, - state=state, - current_state=state, - auth_events=auth_events, + event_stream_id, max_stream_id = yield self._persist_auth_tree( + auth_chain, state, event ) with PreserveLoggingContext(): @@ -1026,6 +1011,76 @@ class FederationHandler(BaseHandler): is_new_state=(not outliers and not backfilled), ) + @defer.inlineCallbacks + def _persist_auth_tree(self, auth_events, state, event): + """Checks the auth chain is valid (and passes auth checks) for the + state and event. Then persists the auth chain and state atomically. + Persists the event seperately. + + Returns: + 2-tuple of (event_stream_id, max_stream_id) from the persist_event + call for `event` + """ + events_to_context = {} + for e in itertools.chain(auth_events, state): + ctx = yield self.state_handler.compute_event_context( + e, outlier=True, + ) + events_to_context[e.event_id] = ctx + e.internal_metadata.outlier = True + + event_map = { + e.event_id: e + for e in auth_events + } + + create_event = None + for e in auth_events: + if (e.type, e.state_key) == (EventTypes.Create, ""): + create_event = e + break + + for e in itertools.chain(auth_events, state, [event]): + auth_for_e = { + (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] + for e_id, _ in e.auth_events + } + if create_event: + auth_for_e[(EventTypes.Create, "")] = create_event + + try: + self.auth.check(e, auth_events=auth_for_e) + except AuthError as err: + logger.warn( + "Rejecting %s because %s", + e.event_id, err.msg + ) + + if e == event: + raise + events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR + + yield self.store.persist_events( + [ + (e, events_to_context[e.event_id]) + for e in itertools.chain(auth_events, state) + ], + is_new_state=False, + ) + + new_event_context = yield self.state_handler.compute_event_context( + event, old_state=state, outlier=False, + ) + + event_stream_id, max_stream_id = yield self.store.persist_event( + event, new_event_context, + backfilled=False, + is_new_state=True, + current_state=state, + ) + + defer.returnValue((event_stream_id, max_stream_id)) + @defer.inlineCallbacks def _prep_event(self, origin, event, state=None, backfilled=False, current_state=None, auth_events=None): @@ -1456,52 +1511,3 @@ class FederationHandler(BaseHandler): }, "missing": [e.event_id for e in missing_locals], }) - - @defer.inlineCallbacks - def _handle_auth_events(self, origin, auth_events): - auth_ids_to_deferred = {} - - def process_auth_ev(ev): - auth_ids = [e_id for e_id, _ in ev.auth_events] - - prev_ds = [ - auth_ids_to_deferred[i] - for i in auth_ids - if i in auth_ids_to_deferred - ] - - d = defer.Deferred() - - auth_ids_to_deferred[ev.event_id] = d - - @defer.inlineCallbacks - def f(*_): - ev.internal_metadata.outlier = True - - try: - auth = { - (e.type, e.state_key): e for e in auth_events - if e.event_id in auth_ids - } - - yield self._handle_new_event( - origin, ev, auth_events=auth - ) - except: - logger.exception( - "Failed to handle auth event %s", - ev.event_id, - ) - - d.callback(None) - - if prev_ds: - dx = defer.DeferredList(prev_ds) - dx.addBoth(f) - else: - f() - - for e in auth_events: - process_auth_ev(e) - - yield defer.DeferredList(auth_ids_to_deferred.values()) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 23b779ad7c..bda8eb5f3f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,13 +16,13 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import RoomError, SynapseError +from synapse.api.errors import SynapseError from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError from synapse.util.logcontext import PreserveLoggingContext -from synapse.types import UserID, RoomStreamToken +from synapse.types import UserID, RoomStreamToken, StreamToken from ._base import BaseHandler @@ -71,7 +71,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, user_id=None, room_id=None, pagin_config=None, - feedback=False, as_client_event=True): + as_client_event=True): """Get messages in a room. Args: @@ -79,26 +79,52 @@ class MessageHandler(BaseHandler): room_id (str): The room they want messages from. pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. - feedback (bool): True to get compressed feedback with the messages as_client_event (bool): True to get events in client-server format. Returns: dict: Pagination API results """ - yield self.auth.check_joined_room(room_id, user_id) + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) data_source = self.hs.get_event_sources().sources["room"] - if not pagin_config.from_token: + if pagin_config.from_token: + room_token = pagin_config.from_token.room_key + else: pagin_config.from_token = ( yield self.hs.get_event_sources().get_current_token( direction='b' ) ) + room_token = pagin_config.from_token.room_key - room_token = RoomStreamToken.parse(pagin_config.from_token.room_key) + room_token = RoomStreamToken.parse(room_token) if room_token.topological is None: raise SynapseError(400, "Invalid token") + pagin_config.from_token = pagin_config.from_token.copy_and_replace( + "room_key", str(room_token) + ) + + source_config = pagin_config.get_source_config("room") + + if member_event.membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room + leave_token = yield self.store.get_topological_token_for_event( + member_event.event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < room_token.topological: + source_config.from_key = str(leave_token) + + if source_config.direction == "f": + if source_config.to_key is None: + source_config.to_key = str(leave_token) + else: + to_token = RoomStreamToken.parse(source_config.to_key) + if leave_token.topological < to_token.topological: + source_config.to_key = str(leave_token) + yield self.hs.get_handlers().federation_handler.maybe_backfill( room_id, room_token.topological ) @@ -106,7 +132,7 @@ class MessageHandler(BaseHandler): user = UserID.from_string(user_id) events, next_key = yield data_source.get_pagination_rows( - user, pagin_config.get_source_config("room"), room_id + user, source_config, room_id ) next_token = pagin_config.from_token.copy_and_replace( @@ -255,29 +281,26 @@ class MessageHandler(BaseHandler): Raises: SynapseError if something went wrong. """ - have_joined = yield self.auth.check_joined_room(room_id, user_id) - if not have_joined: - raise RoomError(403, "User not in room.") + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + + if member_event.membership == Membership.JOIN: + data = yield self.state_handler.get_current_state( + room_id, event_type, state_key + ) + elif member_event.membership == Membership.LEAVE: + key = (event_type, state_key) + room_state = yield self.store.get_state_for_events( + room_id, [member_event.event_id], [key] + ) + data = room_state[member_event.event_id].get(key) - data = yield self.state_handler.get_current_state( - room_id, event_type, state_key - ) defer.returnValue(data) - @defer.inlineCallbacks - def get_feedback(self, event_id): - # yield self.auth.check_joined_room(room_id, user_id) - - # Pull out the feedback from the db - fb = yield self.store.get_feedback(event_id) - - if fb: - defer.returnValue(fb) - defer.returnValue(None) - @defer.inlineCallbacks def get_state_events(self, user_id, room_id): - """Retrieve all state events for a given room. + """Retrieve all state events for a given room. If the user is + joined to the room then return the current state. If the user has + left the room return the state events from when they left. Args: user_id(str): The user requesting state events. @@ -285,18 +308,23 @@ class MessageHandler(BaseHandler): Returns: A list of dicts representing state events. [{}, {}, {}] """ - yield self.auth.check_joined_room(room_id, user_id) + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + + if member_event.membership == Membership.JOIN: + room_state = yield self.state_handler.get_current_state(room_id) + elif member_event.membership == Membership.LEAVE: + room_state = yield self.store.get_state_for_events( + room_id, [member_event.event_id], None + ) + room_state = room_state[member_event.event_id] - # TODO: This is duplicating logic from snapshot_all_rooms - current_state = yield self.state_handler.get_current_state(room_id) now = self.clock.time_msec() defer.returnValue( - [serialize_event(c, now) for c in current_state.values()] + [serialize_event(c, now) for c in room_state.values()] ) @defer.inlineCallbacks - def snapshot_all_rooms(self, user_id=None, pagin_config=None, - feedback=False, as_client_event=True): + def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True): """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is @@ -306,7 +334,6 @@ class MessageHandler(BaseHandler): user_id (str): The ID of the user making the request. pagin_config (synapse.api.streams.PaginationConfig): The pagination config used to determine how many messages *PER ROOM* to return. - feedback (bool): True to get feedback along with these messages. as_client_event (bool): True to get events in client-server format. Returns: A list of dicts with "room_id" and "membership" keys for all rooms @@ -316,7 +343,9 @@ class MessageHandler(BaseHandler): """ room_list = yield self.store.get_rooms_for_user_where_membership_is( user_id=user_id, - membership_list=[Membership.INVITE, Membership.JOIN] + membership_list=[ + Membership.INVITE, Membership.JOIN, Membership.LEAVE + ] ) user = UserID.from_string(user_id) @@ -358,19 +387,32 @@ class MessageHandler(BaseHandler): rooms_ret.append(d) - if event.membership != Membership.JOIN: + if event.membership not in (Membership.JOIN, Membership.LEAVE): return + try: + if event.membership == Membership.JOIN: + room_end_token = now_token.room_key + deferred_room_state = self.state_handler.get_current_state( + event.room_id + ) + elif event.membership == Membership.LEAVE: + room_end_token = "s%d" % (event.stream_ordering,) + deferred_room_state = self.store.get_state_for_events( + event.room_id, [event.event_id], None + ) + deferred_room_state.addCallback( + lambda states: states[event.event_id] + ) + (messages, token), current_state = yield defer.gatherResults( [ self.store.get_recent_events_for_room( event.room_id, limit=limit, - end_token=now_token.room_key, - ), - self.state_handler.get_current_state( - event.room_id + end_token=room_end_token, ), + deferred_room_state, ] ).addErrback(unwrapFirstError) @@ -417,15 +459,85 @@ class MessageHandler(BaseHandler): defer.returnValue(ret) @defer.inlineCallbacks - def room_initial_sync(self, user_id, room_id, pagin_config=None, - feedback=False): - current_state = yield self.state.get_current_state( - room_id=room_id, + def room_initial_sync(self, user_id, room_id, pagin_config=None): + """Capture the a snapshot of a room. If user is currently a member of + the room this will be what is currently in the room. If the user left + the room this will be what was in the room when they left. + + Args: + user_id(str): The user to get a snapshot for. + room_id(str): The room to get a snapshot of. + pagin_config(synapse.streams.config.PaginationConfig): + The pagination config used to determine how many messages to + return. + Raises: + AuthError if the user wasn't in the room. + Returns: + A JSON serialisable dict with the snapshot of the room. + """ + + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + + if member_event.membership == Membership.JOIN: + result = yield self._room_initial_sync_joined( + user_id, room_id, pagin_config, member_event + ) + elif member_event.membership == Membership.LEAVE: + result = yield self._room_initial_sync_parted( + user_id, room_id, pagin_config, member_event + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def _room_initial_sync_parted(self, user_id, room_id, pagin_config, + member_event): + room_state = yield self.store.get_state_for_events( + member_event.room_id, [member_event.event_id], None ) - yield self.auth.check_joined_room( - room_id, user_id, - current_state=current_state + room_state = room_state[member_event.event_id] + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + stream_token = yield self.store.get_stream_token_for_event( + member_event.event_id + ) + + messages, token = yield self.store.get_recent_events_for_room( + room_id, + limit=limit, + end_token=stream_token + ) + + messages = yield self._filter_events_for_client( + user_id, room_id, messages + ) + + start_token = StreamToken(token[0], 0, 0, 0) + end_token = StreamToken(token[1], 0, 0, 0) + + time_now = self.clock.time_msec() + + defer.returnValue({ + "membership": member_event.membership, + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": [serialize_event(s, time_now) for s in room_state.values()], + "presence": [], + "receipts": [], + }) + + @defer.inlineCallbacks + def _room_initial_sync_joined(self, user_id, room_id, pagin_config, + member_event): + current_state = yield self.state.get_current_state( + room_id=room_id, ) # TODO(paul): I wish I was called with user objects not user_id @@ -439,8 +551,6 @@ class MessageHandler(BaseHandler): for x in current_state.values() ] - member_event = current_state.get((EventTypes.Member, user_id,)) - now_token = yield self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 4f8ad824b5..ac636255c2 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,7 +25,6 @@ from synapse.api.constants import ( from synapse.api.errors import StoreError, SynapseError from synapse.util import stringutils, unwrapFirstError from synapse.util.async import run_on_reactor -from synapse.events.utils import serialize_event from collections import OrderedDict import logging @@ -39,7 +38,7 @@ class RoomCreationHandler(BaseHandler): PRESETS_DICT = { RoomCreationPreset.PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, - "history_visibility": "invited", + "history_visibility": "shared", "original_invitees_have_ops": False, }, RoomCreationPreset.PUBLIC_CHAT: { @@ -159,6 +158,7 @@ class RoomCreationHandler(BaseHandler): invite_list=invite_list, initial_state=initial_state, creation_content=creation_content, + room_alias=room_alias, ) msg_handler = self.hs.get_handlers().message_handler @@ -206,7 +206,8 @@ class RoomCreationHandler(BaseHandler): defer.returnValue(result) def _create_events_for_new_room(self, creator, room_id, preset_config, - invite_list, initial_state, creation_content): + invite_list, initial_state, creation_content, + room_alias): config = RoomCreationHandler.PRESETS_DICT[preset_config] creator_id = creator.to_string() @@ -276,6 +277,14 @@ class RoomCreationHandler(BaseHandler): returned_events.append(power_levels_event) + if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state: + room_alias_event = create( + etype=EventTypes.CanonicalAlias, + content={"alias": room_alias.to_string()}, + ) + + returned_events.append(room_alias_event) + if (EventTypes.JoinRules, '') not in initial_state: join_rules_event = create( etype=EventTypes.JoinRules, @@ -346,41 +355,6 @@ class RoomMemberHandler(BaseHandler): if remotedomains is not None: remotedomains.add(member.domain) - @defer.inlineCallbacks - def get_room_members_as_pagination_chunk(self, room_id=None, user_id=None, - limit=0, start_tok=None, - end_tok=None): - """Retrieve a list of room members in the room. - - Args: - room_id (str): The room to get the member list for. - user_id (str): The ID of the user making the request. - limit (int): The max number of members to return. - start_tok (str): Optional. The start token if known. - end_tok (str): Optional. The end token if known. - Returns: - dict: A Pagination streamable dict. - Raises: - SynapseError if something goes wrong. - """ - yield self.auth.check_joined_room(room_id, user_id) - - member_list = yield self.store.get_room_members(room_id=room_id) - time_now = self.clock.time_msec() - event_list = [ - serialize_event(entry, time_now) - for entry in member_list - ] - chunk_data = { - "start": "START", # FIXME (erikj): START is no longer valid - "end": "END", - "chunk": event_list - } - # TODO honor Pagination stream params - # TODO snapshot this list to return on subsequent requests when - # paginating - defer.returnValue(chunk_data) - @defer.inlineCallbacks def change_membership(self, event, context, do_auth=True): """ Change the membership status of a user in a room. @@ -532,32 +506,6 @@ class RoomMemberHandler(BaseHandler): "user_joined_room", user=user, room_id=room_id ) - @defer.inlineCallbacks - def _should_invite_join(self, room_id, prev_state, do_auth): - logger.debug("_should_invite_join: room_id: %s", room_id) - - # XXX: We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - - # Only do an invite join dance if a) we were invited, - # b) the person inviting was from a differnt HS and c) we are - # not currently in the room - room_host = None - if prev_state and prev_state.membership == Membership.INVITE: - room = yield self.store.get_room(room_id) - inviter = UserID.from_string( - prev_state.sender - ) - - is_remote_invite_join = not self.hs.is_mine(inviter) and not room - room_host = inviter.domain - else: - is_remote_invite_join = False - - defer.returnValue((is_remote_invite_join, room_host)) - @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): """Returns a list of roomids that the user has any of the given @@ -650,7 +598,6 @@ class RoomEventSource(object): to_key=config.to_key, direction=config.direction, limit=config.limit, - with_feedback=True ) defer.returnValue((events, next_key)) diff --git a/synapse/http/client.py b/synapse/http/client.py index 4b8fd3d3a3..0933388c04 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from OpenSSL import SSL +from OpenSSL.SSL import VERIFY_NONE from synapse.api.errors import CodeMessageException from synapse.util.logcontext import preserve_context_over_fn @@ -19,7 +21,7 @@ import synapse.metrics from canonicaljson import encode_canonical_json -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor, ssl from twisted.web.client import ( Agent, readBody, FileBodyProducer, PartialDownloadError, HTTPConnectionPool, @@ -59,7 +61,12 @@ class SimpleHttpClient(object): # 'like a browser' pool = HTTPConnectionPool(reactor) pool.maxPersistentPerHost = 10 - self.agent = Agent(reactor, pool=pool) + self.agent = Agent( + reactor, + pool=pool, + connectTimeout=15, + contextFactory=hs.get_http_client_context_factory() + ) self.version_string = hs.version_string def request(self, method, uri, *args, **kwargs): @@ -252,3 +259,18 @@ def _print_ex(e): _print_ex(ex) else: logger.exception(e) + + +class InsecureInterceptableContextFactory(ssl.ContextFactory): + """ + Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain. + + Do not use this since it allows an attacker to intercept your communications. + """ + + def __init__(self): + self._context = SSL.Context(SSL.SSLv23_METHOD) + self._context.set_verify(VERIFY_NONE, lambda *_: None) + + def getContext(self, hostname, port): + return self._context diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 1c9e552788..b50a0c445c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -57,14 +57,14 @@ incoming_responses_counter = metrics.register_counter( class MatrixFederationEndpointFactory(object): def __init__(self, hs): - self.tls_context_factory = hs.tls_context_factory + self.tls_server_context_factory = hs.tls_server_context_factory def endpointForURI(self, uri): destination = uri.netloc return matrix_federation_endpoint( reactor, destination, timeout=10, - ssl_context_factory=self.tls_context_factory + ssl_context_factory=self.tls_server_context_factory ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 795ef27182..e95316720e 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -18,18 +18,18 @@ from distutils.version import LooseVersion logger = logging.getLogger(__name__) REQUIREMENTS = { + "frozendict>=0.4": ["frozendict"], "unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], - "Twisted>=15.1.0": ["twisted>=15.1.0"], + "pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], + "Twisted>=15.1.0": ["twisted>=15.1.0"], "pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyyaml": ["yaml"], "pyasn1": ["pyasn1"], - "pynacl>=0.3.0": ["nacl>=0.3.0"], "daemonize": ["daemonize"], "py-bcrypt": ["bcrypt"], - "frozendict>=0.4": ["frozendict"], "pillow": ["PIL"], "pydenticon": ["pydenticon"], "ujson": ["ujson"], @@ -60,7 +60,10 @@ DEPENDENCY_LINKS = { class MissingRequirementError(Exception): - pass + def __init__(self, message, module_name, dependency): + super(MissingRequirementError, self).__init__(message) + self.module_name = module_name + self.dependency = dependency def check_requirements(config=None): @@ -88,7 +91,7 @@ def check_requirements(config=None): ) raise MissingRequirementError( "Can't import %r which is part of %r" - % (module_name, dependency) + % (module_name, dependency), module_name, dependency ) version = getattr(module, "__version__", None) file_path = getattr(module, "__file__", None) @@ -101,23 +104,25 @@ def check_requirements(config=None): if version is None: raise MissingRequirementError( "Version of %r isn't set as __version__ of module %r" - % (dependency, module_name) + % (dependency, module_name), module_name, dependency ) if LooseVersion(version) < LooseVersion(required_version): raise MissingRequirementError( "Version of %r in %r is too old. %r < %r" - % (dependency, file_path, version, required_version) + % (dependency, file_path, version, required_version), + module_name, dependency ) elif version_test == "==": if version is None: raise MissingRequirementError( "Version of %r isn't set as __version__ of module %r" - % (dependency, module_name) + % (dependency, module_name), module_name, dependency ) if LooseVersion(version) != LooseVersion(required_version): raise MissingRequirementError( "Unexpected version of %r in %r. %r != %r" - % (dependency, file_path, version, required_version) + % (dependency, file_path, version, required_version), + module_name, dependency ) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 4ea4da653c..bac68cc29f 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -26,14 +26,12 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): user, _ = yield self.auth.get_user_by_req(request) - with_feedback = "feedback" in request.args as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) handler = self.handlers.message_handler content = yield handler.snapshot_all_rooms( user_id=user.to_string(), pagin_config=pagination_config, - feedback=with_feedback, as_client_event=as_client_event ) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index c9c27dd5a0..23871f161e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -290,12 +290,18 @@ class RoomMemberListRestServlet(ClientV1RestServlet): def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) user, _ = yield self.auth.get_user_by_req(request) - handler = self.handlers.room_member_handler - members = yield handler.get_room_members_as_pagination_chunk( + handler = self.handlers.message_handler + events = yield handler.get_state_events( room_id=room_id, - user_id=user.to_string()) + user_id=user.to_string(), + ) - for event in members["chunk"]: + chunk = [] + + for event in events: + if event["type"] != EventTypes.Member: + continue + chunk.append(event) # FIXME: should probably be state_key here, not user_id target_user = UserID.from_string(event["user_id"]) # Presence is an optional cache; don't fail if we can't fetch it @@ -308,7 +314,9 @@ class RoomMemberListRestServlet(ClientV1RestServlet): except: pass - defer.returnValue((200, members)) + defer.returnValue((200, { + "chunk": chunk + })) # TODO: Needs unit testing @@ -321,14 +329,12 @@ class RoomMessageListRestServlet(ClientV1RestServlet): pagination_config = PaginationConfig.from_request( request, default_limit=10, ) - with_feedback = "feedback" in request.args as_client_event = "raw" not in request.args handler = self.handlers.message_handler msgs = yield handler.get_messages( room_id=room_id, user_id=user.to_string(), pagin_config=pagination_config, - feedback=with_feedback, as_client_event=as_client_event ) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 52e99f54d5..b107b7ce17 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -15,6 +15,7 @@ from twisted.internet import defer +from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet from ._base import client_v2_pattern @@ -41,6 +42,9 @@ class ReceiptRestServlet(RestServlet): def on_POST(self, request, room_id, receipt_type, event_id): user, _ = yield self.auth.get_user_by_req(request) + if receipt_type != "m.read": + raise SynapseError(400, "Receipt type must be 'm.read'") + yield self.receipts_handler.received_client_receipt( room_id, receipt_type, diff --git a/synapse/server.py b/synapse/server.py index 4d1fb1cbf6..8424798b1b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -19,7 +19,9 @@ # partial one for unit test mocking. # Imports required for the default HomeServer() implementation +from twisted.web.client import BrowserLikePolicyForHTTPS from synapse.federation import initialize_http_replication +from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier from synapse.api.auth import Auth from synapse.handlers import Handlers @@ -87,6 +89,8 @@ class BaseHomeServer(object): 'pusherpool', 'event_builder_factory', 'filtering', + 'http_client_context_factory', + 'simple_http_client', ] def __init__(self, hostname, **kwargs): @@ -174,6 +178,17 @@ class HomeServer(BaseHomeServer): def build_auth(self): return Auth(self) + def build_http_client_context_factory(self): + config = self.get_config() + return ( + InsecureInterceptableContextFactory() + if config.use_insecure_ssl_client_just_for_testing_do_not_use + else BrowserLikePolicyForHTTPS() + ) + + def build_simple_http_client(self): + return SimpleHttpClient(self) + def build_v1auth(self): orf = Auth(self) # Matrix spec makes no reference to what HTTP status code is returned, diff --git a/synapse/state.py b/synapse/state.py index 1fe4d066bd..bb225c39cf 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.util.logutils import log_function -from synapse.util.async import run_on_reactor from synapse.util.caches.expiringcache import ExpiringCache from synapse.api.constants import EventTypes from synapse.api.errors import AuthError @@ -32,10 +31,6 @@ import hashlib logger = logging.getLogger(__name__) -def _get_state_key_from_event(event): - return event.state_key - - KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) @@ -119,8 +114,6 @@ class StateHandler(object): Returns: an EventContext """ - yield run_on_reactor() - context = EventContext() if outlier: diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 77cb1dbd81..340e59afcb 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 23 +SCHEMA_VERSION = 24 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -126,6 +126,27 @@ class DataStore(RoomMemberStore, RoomStore, lock=False, ) + @defer.inlineCallbacks + def count_daily_users(self): + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + def _count_users(txn): + txn.execute( + "SELECT COUNT(DISTINCT user_id) AS users" + " FROM user_ips" + " WHERE last_seen > ?", + # This is close enough to a day for our purposes. + (int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),) + ) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_users", _count_users) + defer.returnValue(ret) + def get_user_ip_and_agents(self, user): return self._simple_select_list( table="user_ips", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 495ef087c9..693784ad38 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -25,8 +25,6 @@ from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer -from collections import namedtuple - import sys import time import threading @@ -376,9 +374,6 @@ class SQLBaseStore(object): return self.runInteraction(desc, interaction) - def _execute_and_decode(self, desc, query, *args): - return self._execute(desc, self.cursor_to_dict, query, *args) - # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. @@ -691,37 +686,6 @@ class SQLBaseStore(object): return dict(zip(retcols, row)) - def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None, - retcols=None, allow_none=False, - desc="_simple_selectupdate_one"): - """ Combined SELECT then UPDATE.""" - def func(txn): - ret = None - if retcols: - ret = self._simple_select_one_txn( - txn, - table=table, - keyvalues=keyvalues, - retcols=retcols, - allow_none=allow_none, - ) - - if updatevalues: - self._simple_update_one_txn( - txn, - table=table, - keyvalues=keyvalues, - updatevalues=updatevalues, - ) - - # if txn.rowcount == 0: - # raise StoreError(404, "No row found") - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched") - - return ret - return self.runInteraction(desc, func) - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -743,16 +707,6 @@ class SQLBaseStore(object): raise StoreError(500, "more than one row matched") return self.runInteraction(desc, func) - def _simple_delete(self, table, keyvalues, desc="_simple_delete"): - """Executes a DELETE query on the named table. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - - return self.runInteraction(desc, self._simple_delete_txn) - def _simple_delete_txn(self, txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, @@ -761,24 +715,6 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) - def _simple_max_id(self, table): - """Executes a SELECT query on the named table, expecting to return the - max value for the column "id". - - Args: - table : string giving the table name - """ - sql = "SELECT MAX(id) AS id FROM %s" % table - - def func(txn): - txn.execute(sql) - max_id = self.cursor_to_dict(txn)[0]["id"] - if max_id is None: - return 0 - return max_id - - return self.runInteraction("_simple_max_id", func) - def get_next_stream_id(self): with self._next_stream_id_lock: i = self._next_stream_id @@ -791,129 +727,3 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass - - -class Table(object): - """ A base class used to store information about a particular table. - """ - - table_name = None - """ str: The name of the table """ - - fields = None - """ list: The field names """ - - EntryType = None - """ Type: A tuple type used to decode the results """ - - _select_where_clause = "SELECT %s FROM %s WHERE %s" - _select_clause = "SELECT %s FROM %s" - _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)" - - @classmethod - def select_statement(cls, where_clause=None): - """ - Args: - where_clause (str): The WHERE clause to use. - - Returns: - str: An SQL statement to select rows from the table with the given - WHERE clause. - """ - if where_clause: - return cls._select_where_clause % ( - ", ".join(cls.fields), - cls.table_name, - where_clause - ) - else: - return cls._select_clause % ( - ", ".join(cls.fields), - cls.table_name, - ) - - @classmethod - def insert_statement(cls): - return cls._insert_clause % ( - cls.table_name, - ", ".join(cls.fields), - ", ".join(["?"] * len(cls.fields)), - ) - - @classmethod - def decode_single_result(cls, results): - """ Given an iterable of tuples, return a single instance of - `EntryType` or None if the iterable is empty - Args: - results (list): The results list to convert to `EntryType` - Returns: - EntryType: An instance of `EntryType` - """ - results = list(results) - if results: - return cls.EntryType(*results[0]) - else: - return None - - @classmethod - def decode_results(cls, results): - """ Given an iterable of tuples, return a list of `EntryType` - Args: - results (list): The results list to convert to `EntryType` - - Returns: - list: A list of `EntryType` - """ - return [cls.EntryType(*row) for row in results] - - @classmethod - def get_fields_string(cls, prefix=None): - if prefix: - to_join = ("%s.%s" % (prefix, f) for f in cls.fields) - else: - to_join = cls.fields - - return ", ".join(to_join) - - -class JoinHelper(object): - """ Used to help do joins on tables by looking at the tables' fields and - creating a list of unique fields to use with SELECTs and a namedtuple - to dump the results into. - - Attributes: - tables (list): List of `Table` classes - EntryType (type) - """ - - def __init__(self, *tables): - self.tables = tables - - res = [] - for table in self.tables: - res += [f for f in table.fields if f not in res] - - self.EntryType = namedtuple("JoinHelperEntry", res) - - def get_fields(self, **prefixes): - """Get a string representing a list of fields for use in SELECT - statements with the given prefixes applied to each. - - For example:: - - JoinHelper(PdusTable, StateTable).get_fields( - PdusTable="pdus", - StateTable="state" - ) - """ - res = [] - for field in self.EntryType._fields: - for table in self.tables: - if field in table.fields: - res.append("%s.%s" % (prefixes[table.__name__], field)) - break - - return ", ".join(res) - - def decode_results(self, rows): - return [self.EntryType(*row) for row in rows] diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 989ad340b0..6d4421dd8f 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -154,98 +154,6 @@ class EventFederationStore(SQLBaseStore): return results - def _get_latest_state_in_room(self, txn, room_id, type, state_key): - event_ids = self._simple_select_onecol_txn( - txn, - table="state_forward_extremities", - keyvalues={ - "room_id": room_id, - "type": type, - "state_key": state_key, - }, - retcol="event_id", - ) - - results = [] - for event_id in event_ids: - hashes = self._get_event_reference_hashes_txn(txn, event_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" - } - results.append((event_id, prev_hashes)) - - return results - - def _get_prev_events(self, txn, event_id): - results = self._get_prev_events_and_state( - txn, - event_id, - is_state=0, - ) - - return [(e_id, h, ) for e_id, h, _ in results] - - def _get_prev_state(self, txn, event_id): - results = self._get_prev_events_and_state( - txn, - event_id, - is_state=True, - ) - - return [(e_id, h, ) for e_id, h, _ in results] - - def _get_prev_events_and_state(self, txn, event_id, is_state=None): - keyvalues = { - "event_id": event_id, - } - - if is_state is not None: - keyvalues["is_state"] = bool(is_state) - - res = self._simple_select_list_txn( - txn, - table="event_edges", - keyvalues=keyvalues, - retcols=["prev_event_id", "is_state"], - ) - - hashes = self._get_prev_event_hashes_txn(txn, event_id) - - results = [] - for d in res: - edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"]) - edge_hash.update(hashes.get(d["prev_event_id"], {})) - prev_hashes = { - k: encode_base64(v) - for k, v in edge_hash.items() - if k == "sha256" - } - results.append((d["prev_event_id"], prev_hashes, d["is_state"])) - - return results - - def _get_auth_events(self, txn, event_id): - auth_ids = self._simple_select_onecol_txn( - txn, - table="event_auth", - keyvalues={ - "event_id": event_id, - }, - retcol="auth_id", - ) - - results = [] - for auth_id in auth_ids: - hashes = self._get_event_reference_hashes_txn(txn, auth_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() - if k == "sha256" - } - results.append((auth_id, prev_hashes)) - - return results - def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. """ @@ -303,6 +211,15 @@ class EventFederationStore(SQLBaseStore): ], ) + self._update_extremeties(txn, events) + + def _update_extremeties(self, txn, events): + """Updates the event_*_extremities tables based on the new/updated + events being persisted. + + This is called for new events *and* for events that were outliers, but + are are now being persisted as non-outliers. + """ events_by_room = {} for ev in events: events_by_room.setdefault(ev.room_id, []).append(ev) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index fba837f461..416ef6af93 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from _base import SQLBaseStore, _RollbackButIsFineException from twisted.internet import defer, reactor @@ -28,6 +27,7 @@ from canonicaljson import encode_canonical_json from contextlib import contextmanager import logging +import math import ujson as json logger = logging.getLogger(__name__) @@ -281,6 +281,8 @@ class EventsStore(SQLBaseStore): (False, event.event_id,) ) + self._update_extremeties(txn, [event]) + events_and_contexts = filter( lambda ec: ec[0] not in to_remove, events_and_contexts @@ -888,18 +890,69 @@ class EventsStore(SQLBaseStore): return ev - def _parse_events(self, rows): - return self.runInteraction( - "_parse_events", self._parse_events_txn, rows - ) - def _parse_events_txn(self, txn, rows): event_ids = [r["event_id"] for r in rows] return self._get_events_txn(txn, event_ids) - def _has_been_redacted_txn(self, txn, event): - sql = "SELECT event_id FROM redactions WHERE redacts = ?" - txn.execute(sql, (event.event_id,)) - result = txn.fetchone() - return result[0] if result else None + @defer.inlineCallbacks + def count_daily_messages(self): + """ + Returns an estimate of the number of messages sent in the last day. + + If it has been significantly less or more than one day since the last + call to this function, it will return None. + """ + def _count_messages(txn): + now = self.hs.get_clock().time() + + txn.execute( + "SELECT reported_stream_token, reported_time FROM stats_reporting" + ) + last_reported = self.cursor_to_dict(txn) + + txn.execute( + "SELECT stream_ordering" + " FROM events" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + now_reporting = self.cursor_to_dict(txn) + if not now_reporting: + return None + now_reporting = now_reporting[0]["stream_ordering"] + + txn.execute("DELETE FROM stats_reporting") + txn.execute( + "INSERT INTO stats_reporting" + " (reported_stream_token, reported_time)" + " VALUES (?, ?)", + (now_reporting, now,) + ) + + if not last_reported: + return None + + # Close enough to correct for our purposes. + yesterday = (now - 24 * 60 * 60) + if math.fabs(yesterday - last_reported[0]["reported_time"]) > 60 * 60: + return None + + txn.execute( + "SELECT COUNT(*) as messages" + " FROM events NATURAL JOIN event_json" + " WHERE json like '%m.room.message%'" + " AND stream_ordering > ?" + " AND stream_ordering <= ?", + ( + last_reported[0]["reported_stream_token"], + now_reporting, + ) + ) + rows = self.cursor_to_dict(txn) + if not rows: + return None + return rows[0]["messages"] + + ret = yield self.runInteraction("count_messages", _count_messages) + defer.returnValue(ret) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 00b748f131..345c4e1104 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, Table +from ._base import SQLBaseStore from twisted.internet import defer from synapse.api.errors import StoreError @@ -149,5 +149,5 @@ class PusherStore(SQLBaseStore): ) -class PushersTable(Table): +class PushersTable(object): table_name = "pushers" diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index c9ceb132ae..b454dd5b3a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -289,3 +289,16 @@ class RegistrationStore(SQLBaseStore): if ret: defer.returnValue(ret['user_id']) defer.returnValue(None) + + @defer.inlineCallbacks + def count_all_users(self): + """Counts all users registered on the homeserver.""" + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users") + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_users", _count_users) + defer.returnValue(ret) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8eee2dfbcc..8c40d9a8a6 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) RoomsForUser = namedtuple( "RoomsForUser", - ("room_id", "sender", "membership") + ("room_id", "sender", "membership", "event_id", "stream_ordering") ) @@ -141,11 +141,13 @@ class RoomMemberStore(SQLBaseStore): args.extend(membership_list) sql = ( - "SELECT m.room_id, m.sender, m.membership" - " FROM room_memberships as m" - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id " - " AND m.room_id = c.room_id " + "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" + " FROM current_state_events as c" + " INNER JOIN room_memberships as m" + " ON m.event_id = c.event_id" + " INNER JOIN events as e" + " ON e.event_id = c.event_id" + " AND m.room_id = c.room_id" " AND m.user_id = c.state_key" " WHERE %s" ) % (where_clause,) @@ -176,12 +178,6 @@ class RoomMemberStore(SQLBaseStore): return joined_domains - def _get_members_query(self, where_clause, where_values): - return self.runInteraction( - "get_members_query", self._get_members_events_txn, - where_clause, where_values - ).addCallbacks(self._get_events) - def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): rows = self._get_members_rows_txn( txn, diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/schema/delta/24/stats_reporting.sql new file mode 100644 index 0000000000..e9165d2917 --- /dev/null +++ b/synapse/storage/schema/delta/24/stats_reporting.sql @@ -0,0 +1,22 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Should only ever contain one row +CREATE TABLE IF NOT EXISTS stats_reporting( + -- The stream ordering token which was most recently reported as stats + reported_stream_token INTEGER, + -- The time (seconds since epoch) stats were most recently reported + reported_time BIGINT +); diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index ab57b92174..b070be504d 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -24,41 +24,6 @@ from synapse.crypto.event_signing import compute_event_reference_hash class SignatureStore(SQLBaseStore): """Persistence for event signatures and hashes""" - def _get_event_content_hashes_txn(self, txn, event_id): - """Get all the hashes for a given Event. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict of algorithm -> hash. - """ - query = ( - "SELECT algorithm, hash" - " FROM event_content_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id, )) - return dict(txn.fetchall()) - - def _store_event_content_hash_txn(self, txn, event_id, algorithm, - hash_bytes): - """Store a hash for a Event - Args: - txn (cursor): - event_id (str): Id for the Event. - algorithm (str): Hashing algorithm. - hash_bytes (bytes): Hash function output bytes. - """ - self._simple_insert_txn( - txn, - "event_content_hashes", - { - "event_id": event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, - ) - def get_event_reference_hashes(self, event_ids): def f(txn): return [ @@ -123,80 +88,3 @@ class SignatureStore(SQLBaseStore): table="event_reference_hashes", values=vals, ) - - def _get_event_signatures_txn(self, txn, event_id): - """Get all the signatures for a given PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - A dict of sig name -> dict(key_id -> signature_bytes) - """ - query = ( - "SELECT signature_name, key_id, signature" - " FROM event_signatures" - " WHERE event_id = ? " - ) - txn.execute(query, (event_id, )) - rows = txn.fetchall() - - res = {} - - for name, key, sig in rows: - res.setdefault(name, {})[key] = sig - - return res - - def _store_event_signature_txn(self, txn, event_id, signature_name, key_id, - signature_bytes): - """Store a signature from the origin server for a PDU. - Args: - txn (cursor): - event_id (str): Id for the Event. - origin (str): origin of the Event. - key_id (str): Id for the signing key. - signature (bytes): The signature. - """ - self._simple_insert_txn( - txn, - "event_signatures", - { - "event_id": event_id, - "signature_name": signature_name, - "key_id": key_id, - "signature": buffer(signature_bytes), - }, - ) - - def _get_prev_event_hashes_txn(self, txn, event_id): - """Get all the hashes for previous PDUs of a PDU - Args: - txn (cursor): - event_id (str): Id for the Event. - Returns: - dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. - """ - query = ( - "SELECT prev_event_id, algorithm, hash" - " FROM event_edge_hashes" - " WHERE event_id = ?" - ) - txn.execute(query, (event_id, )) - results = {} - for prev_event_id, algorithm, hash_bytes in txn.fetchall(): - hashes = results.setdefault(prev_event_id, {}) - hashes[algorithm] = hash_bytes - return results - - def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, - algorithm, hash_bytes): - self._simple_insert_txn( - txn, - "event_edge_hashes", - { - "event_id": event_id, - "prev_event_id": prev_event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, - ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 9630efcfcc..e935b9443b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,8 +20,6 @@ from synapse.util.caches.descriptors import ( from twisted.internet import defer -from synapse.util.stringutils import random_string - import logging logger = logging.getLogger(__name__) @@ -428,7 +426,3 @@ class StateStore(SQLBaseStore): } defer.returnValue(results) - - -def _make_group_id(clock): - return str(int(clock.time_msec())) + random_string(5) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d7fe423f5a..3cab06fdef 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -159,9 +159,7 @@ class StreamStore(SQLBaseStore): @log_function def get_room_events_stream(self, user_id, from_key, to_key, room_id, - limit=0, with_feedback=False): - # TODO (erikj): Handle compressed feedback - + limit=0): current_room_membership_sql = ( "SELECT m.room_id FROM room_memberships as m " " INNER JOIN current_state_events as c" @@ -227,10 +225,7 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def paginate_room_events(self, room_id, from_key, to_key=None, - direction='b', limit=-1, - with_feedback=False): - # TODO (erikj): Handle compressed feedback - + direction='b', limit=-1): # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. @@ -302,7 +297,6 @@ class StreamStore(SQLBaseStore): @cachedInlineCallbacks(num_args=4) def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): - # TODO (erikj): Handle compressed feedback end_token = RoomStreamToken.parse_stream_token(end_token) @@ -379,6 +373,38 @@ class StreamStore(SQLBaseStore): ) defer.returnValue("t%d-%d" % (topo, token)) + def get_stream_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "s%d" stream token. + """ + return self._simple_select_one_onecol( + table="events", + keyvalues={"event_id": event_id}, + retcol="stream_ordering", + ).addCallback(lambda row: "s%d" % (row,)) + + def get_topological_token_for_event(self, event_id): + """The stream token for an event + Args: + event_id(str): The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A deferred "t%d-%d" topological token. + """ + return self._simple_select_one( + table="events", + keyvalues={"event_id": event_id}, + retcols=("stream_ordering", "topological_ordering"), + ).addCallback(lambda row: "t%d-%d" % ( + row["topological_ordering"], row["stream_ordering"],) + ) + def _get_max_topological_txn(self, txn): txn.execute( "SELECT MAX(topological_ordering) FROM events" diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 2ec7c5403b..167bfe0de3 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -34,6 +34,11 @@ class SourcePaginationConfig(object): self.direction = 'f' if direction == 'f' else 'b' self.limit = int(limit) if limit is not None else None + def __repr__(self): + return ( + "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)" + ) % (self.from_key, self.to_key, self.direction, self.limit) + class PaginationConfig(object): @@ -94,10 +99,10 @@ class PaginationConfig(object): logger.exception("Failed to create pagination config") raise SynapseError(400, "Invalid request.") - def __str__(self): + def __repr__(self): return ( - "" + "PaginationConfig(from_tok=%r, to_tok=%r," + " direction=%r, limit=%r)" ) % (self.from_token, self.to_token, self.direction, self.limit) def get_source_config(self, source_name): diff --git a/synapse/streams/events.py b/synapse/streams/events.py index aaa3609aa5..699083ae12 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -23,22 +23,6 @@ from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.receipts import ReceiptEventSource -class NullSource(object): - """This event source never yields any events and its token remains at - zero. It may be useful for unit-testing.""" - def __init__(self, hs): - pass - - def get_new_events_for_user(self, user, from_key, limit): - return defer.succeed(([], from_key)) - - def get_current_key(self, direction='f'): - return defer.succeed(0) - - def get_pagination_rows(self, user, pagination_config, key): - return defer.succeed(([], pagination_config.from_key)) - - class EventSources(object): SOURCE_TYPES = { "room": RoomEventSource, @@ -70,15 +54,3 @@ class EventSources(object): ), ) defer.returnValue(token) - - -class StreamSource(object): - def get_new_events_for_user(self, user, from_key, limit): - """from_key is the key within this event source.""" - raise NotImplementedError("get_new_events_for_user") - - def get_current_key(self): - raise NotImplementedError("get_current_key") - - def get_pagination_rows(self, user, pagination_config, key): - raise NotImplementedError("get_rows") diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 07ff25cef3..1d123ccefc 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -29,34 +29,6 @@ def unwrapFirstError(failure): return failure.value.subFailure -def unwrap_deferred(d): - """Given a deferred that we know has completed, return its value or raise - the failure as an exception - """ - if not d.called: - raise RuntimeError("deferred has not finished") - - res = [] - - def f(r): - res.append(r) - return r - d.addCallback(f) - - if res: - return res[0] - - def f(r): - res.append(r) - return r - d.addErrback(f) - - if res: - res[0].raiseException() - else: - raise RuntimeError("deferred did not call callbacks") - - class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests. diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 22fc804331..c96273480d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -19,17 +19,21 @@ from mock import Mock from synapse.api.auth import Auth from synapse.api.errors import AuthError +from synapse.types import UserID +from tests.utils import setup_test_homeserver + +import pymacaroons class AuthTestCase(unittest.TestCase): + @defer.inlineCallbacks def setUp(self): self.state_handler = Mock() self.store = Mock() - self.hs = Mock() + self.hs = yield setup_test_homeserver(handlers=None) self.hs.get_datastore = Mock(return_value=self.store) - self.hs.get_state_handler = Mock(return_value=self.state_handler) self.auth = Auth(self.hs) self.test_user = "@foo:bar" @@ -133,3 +137,140 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = Mock(return_value=[""]) d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) + + @defer.inlineCallbacks + def test_get_user_from_macaroon(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize()) + user = user_info["user"] + self.assertEqual(UserID.from_string(user_id), user) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_user_db_mismatch(self): + self.store.get_user_by_access_token = Mock( + return_value={"name": "@percy:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("User mismatch", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_missing_caveat(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("No user caveat", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_wrong_key(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key + "wrong") + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_unknown_caveat(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + macaroon.add_first_party_caveat("cunning > fox") + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_expired(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + macaroon.add_first_party_caveat("time < 1") # ms + + self.hs.clock.now = 5000 # seconds + + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + # TODO(daniel): Turn on the check that we validate expiration, when we + # validate expiration (and remove the above line, which will start + # throwing). + # with self.assertRaises(AuthError) as cm: + # yield self.auth._get_user_from_macaroon(macaroon.serialize()) + # self.assertEqual(401, cm.exception.code) + # self.assertIn("Invalid macaroon", cm.exception.msg) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 91547bdd06..29d9bbaad4 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -41,6 +41,22 @@ myid = "@apple:test" PATH_PREFIX = "/_matrix/client/api/v1" +class NullSource(object): + """This event source never yields any events and its token remains at + zero. It may be useful for unit-testing.""" + def __init__(self, hs): + pass + + def get_new_events_for_user(self, user, from_key, limit): + return defer.succeed(([], from_key)) + + def get_current_key(self, direction='f'): + return defer.succeed(0) + + def get_pagination_rows(self, user, pagination_config, key): + return defer.succeed(([], pagination_config.from_key)) + + class JustPresenceHandlers(object): def __init__(self, hs): self.presence_handler = PresenceHandler(hs) @@ -76,7 +92,7 @@ class PresenceStateTestCase(unittest.TestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token room_member_handler = hs.handlers.room_member_handler = Mock( spec=[ @@ -169,7 +185,7 @@ class PresenceListTestCase(unittest.TestCase): ] ) - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token presence.register_servlets(hs, self.mock_resource) @@ -243,7 +259,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): # HIDEOUS HACKERY # TODO(paul): This should be injected in via the HomeServer DI system from synapse.streams.events import ( - PresenceEventSource, NullSource, EventSources + PresenceEventSource, EventSources ) old_SOURCE_TYPES = EventSources.SOURCE_TYPES diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 34ab47d02e..a2123be81b 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -59,7 +59,7 @@ class RoomPermissionsTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -239,7 +239,7 @@ class RoomPermissionsTestCase(RestTestCase): "PUT", topic_path, topic_content) self.assertEquals(403, code, msg=str(response)) (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(403, code, msg=str(response)) + self.assertEquals(200, code, msg=str(response)) # get topic in PUBLIC room, not joined, expect 403 (code, response) = yield self.mock_resource.trigger_get( @@ -301,11 +301,11 @@ class RoomPermissionsTestCase(RestTestCase): room=room, expect_code=200) # get membership of self, get membership of other, private room + left - # expect all 403s + # expect all 200s yield self.leave(room=room, user=self.user_id) yield self._test_get_membership( members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + room=room, expect_code=200) @defer.inlineCallbacks def test_membership_public_room_perms(self): @@ -326,11 +326,11 @@ class RoomPermissionsTestCase(RestTestCase): room=room, expect_code=200) # get membership of self, get membership of other, public room + left - # expect 403. + # expect 200. yield self.leave(room=room, user=self.user_id) yield self._test_get_membership( members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + room=room, expect_code=200) @defer.inlineCallbacks def test_invited_permissions(self): @@ -444,7 +444,7 @@ class RoomsMemberListTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -492,9 +492,9 @@ class RoomsMemberListTestCase(RestTestCase): self.assertEquals(200, code, msg=str(response)) yield self.leave(room=room_id, user=self.user_id) - # can no longer see list, you've left. + # can see old list once left (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(403, code, msg=str(response)) + self.assertEquals(200, code, msg=str(response)) class RoomsCreateTestCase(RestTestCase): @@ -522,7 +522,7 @@ class RoomsCreateTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -718,7 +718,7 @@ class RoomMemberStateTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -938,7 +938,7 @@ class RoomInitialSyncTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 1c4519406d..6395ce79db 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -67,7 +67,7 @@ class RoomTypingTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index c472d53043..85096a0326 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase): self.mock_resource = None self.auth_user_id = None - def mock_get_user_by_access_token(self, token=None): - return self.auth_user_id - @defer.inlineCallbacks def create_room_as(self, room_creator, is_public=True, tok=None): temp_id = self.auth_user_id diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index ef972a53aa..f45570a1c0 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -48,7 +48,7 @@ class V2AlphaRestTestCase(unittest.TestCase): "user": UserID.from_string(self.USER_ID), "token_id": 1, } - hs.get_auth().get_user_by_access_token = _get_user_by_access_token + hs.get_auth()._get_user_by_access_token = _get_user_by_access_token for r in self.TO_REGISTER: r.register_servlets(hs, self.mock_resource) diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py new file mode 100644 index 0000000000..42bd8928bd --- /dev/null +++ b/tests/storage/event_injector.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.types import UserID, RoomID + +from tests.utils import setup_test_homeserver + +from mock import Mock + + +class EventInjector: + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.message_handler = hs.get_handlers().message_handler + self.event_builder_factory = hs.get_event_builder_factory() + + @defer.inlineCallbacks + def create_room(self, room): + builder = self.event_builder_factory.new({ + "type": EventTypes.Create, + "room_id": room.to_string(), + "content": {}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + @defer.inlineCallbacks + def inject_room_member(self, room, user, membership): + builder = self.event_builder_factory.new({ + "type": EventTypes.Member, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"membership": membership}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + defer.returnValue(event) + + @defer.inlineCallbacks + def inject_message(self, room, user, body): + builder = self.event_builder_factory.new({ + "type": EventTypes.Message, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"body": body, "msgtype": u"message"}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 8573f18b55..1ddca1da4c 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -185,26 +185,6 @@ class SQLBaseStoreTestCase(unittest.TestCase): [3, 4, 1, 2] ) - @defer.inlineCallbacks - def test_update_one_with_return(self): - self.mock_txn.rowcount = 1 - self.mock_txn.fetchone.return_value = ("Old Value",) - - ret = yield self.datastore._simple_selectupdate_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columname": "New Value"}, - retcols=["columname"] - ) - - self.assertEquals({"columname": "Old Value"}, ret) - self.mock_txn.execute.assert_has_calls([ - call('SELECT columname FROM tablename WHERE keycol = ?', - ['TheKey']), - call("UPDATE tablename SET columname = ? WHERE keycol = ?", - ["New Value", "TheKey"]) - ]) - @defer.inlineCallbacks def test_delete_one(self): self.mock_txn.rowcount = 1 diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py new file mode 100644 index 0000000000..313013009e --- /dev/null +++ b/tests/storage/test_events.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import uuid +from mock.mock import Mock +from synapse.types import RoomID, UserID + +from tests import unittest +from twisted.internet import defer +from tests.storage.event_injector import EventInjector + +from tests.utils import setup_test_homeserver + + +class EventsStoreTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + resource_for_federation=Mock(), + http_client=None, + ) + self.store = self.hs.get_datastore() + self.db_pool = self.hs.get_db_pool() + self.message_handler = self.hs.get_handlers().message_handler + self.event_injector = EventInjector(self.hs) + + @defer.inlineCallbacks + def test_count_daily_messages(self): + self.db_pool.runQuery("DELETE FROM stats_reporting") + + self.hs.clock.now = 100 + + # Never reported before, and nothing which could be reported + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting") + self.assertEqual([(0,)], count) + + # Create something to report + room = RoomID.from_string("!abc123:test") + user = UserID.from_string("@raccoonlover:test") + yield self.event_injector.create_room(room) + + self.base_event = yield self._get_last_stream_token() + + yield self.event_injector.inject_message(room, user, "Raccoons are really cute") + + # Never reported before, something could be reported, but isn't because + # it isn't old enough. + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(1, self.hs.clock.now) + + # Already reported yesterday, two new events from today. + yield self.event_injector.inject_message(room, user, "Yeah they are!") + yield self.event_injector.inject_message(room, user, "Incredibly!") + self.hs.clock.now += 60 * 60 * 24 + count = yield self.store.count_daily_messages() + self.assertEqual(2, count) # 2 since yesterday + self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever + + # Last reported too recently. + yield self.event_injector.inject_message(room, user, "Who could disagree?") + self.hs.clock.now += 60 * 60 * 22 + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(4, self.hs.clock.now) + + # Last reported too long ago + yield self.event_injector.inject_message(room, user, "No one.") + self.hs.clock.now += 60 * 60 * 26 + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(5, self.hs.clock.now) + + # And now let's actually report something + yield self.event_injector.inject_message(room, user, "Indeed.") + yield self.event_injector.inject_message(room, user, "Indeed.") + yield self.event_injector.inject_message(room, user, "Indeed.") + # A little over 24 hours is fine :) + self.hs.clock.now += (60 * 60 * 24) + 50 + count = yield self.store.count_daily_messages() + self.assertEqual(3, count) + self._assert_stats_reporting(8, self.hs.clock.now) + + @defer.inlineCallbacks + def _get_last_stream_token(self): + rows = yield self.db_pool.runQuery( + "SELECT stream_ordering" + " FROM events" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + if not rows: + defer.returnValue(0) + else: + defer.returnValue(rows[0][0]) + + @defer.inlineCallbacks + def _assert_stats_reporting(self, messages, time): + rows = yield self.db_pool.runQuery( + "SELECT reported_stream_token, reported_time FROM stats_reporting" + ) + self.assertEqual([(self.base_event + messages, time,)], rows) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ab7625a3ca..caffce64e3 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -85,7 +85,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() - self.event_factory = hs.get_event_factory(); + self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 0c9b89d765..a658a789aa 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, RoomID +from tests.storage.event_injector import EventInjector from tests.utils import setup_test_homeserver @@ -36,6 +37,7 @@ class StreamStoreTestCase(unittest.TestCase): self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() + self.event_injector = EventInjector(hs) self.handlers = hs.get_handlers() self.message_handler = self.handlers.message_handler @@ -45,60 +47,20 @@ class StreamStoreTestCase(unittest.TestCase): self.room1 = RoomID.from_string("!abc123:test") self.room2 = RoomID.from_string("!xyx987:test") - self.depth = 1 - - @defer.inlineCallbacks - def inject_room_member(self, room, user, membership): - self.depth += 1 - - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"membership": membership}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - defer.returnValue(event) - - @defer.inlineCallbacks - def inject_message(self, room, user, body): - self.depth += 1 - - builder = self.event_builder_factory.new({ - "type": EventTypes.Message, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - @defer.inlineCallbacks def test_event_stream_get_other(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -125,17 +87,17 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_get_own(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -162,22 +124,22 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_join_leave(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Then bob leaves again. - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.LEAVE ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -193,17 +155,17 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_prev_content(self): - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) - event1 = yield self.inject_room_member( + event1 = yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) start = yield self.store.get_room_events_max_id() - event2 = yield self.inject_room_member( + event2 = yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN, )