Merge branch 'develop' into HEAD

Erik Johnston 2016-04-11 10:40:49 +01:00
commit 58de897c77
84 changed files with 3118 additions and 1876 deletions

AUTHORS.rst

@@ -57,3 +57,6 @@ Florent Violleau <floviolleau at gmail dot com>
 Niklas Riekenbrauck <nikriek at gmail dot.com>
  * Add JWT support for registration and login
+Christoph Witzany <christoph at web.crofting.com>
+ * Add LDAP support for authentication

README.rst

@@ -118,7 +118,6 @@ Installing prerequisites on CentOS 7::
                      python-virtualenv libffi-devel openssl-devel
     sudo yum groupinstall "Development Tools"

 Installing prerequisites on Mac OS X::

     xcode-select --install
@@ -150,12 +149,7 @@ In case of problems, please see the _Troubleshooting section below.
 Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
 above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.

-Another alternative is to install via apt from http://matrix.org/packages/debian/.
-Note that these packages do not include a client - choose one from
-https://matrix.org/blog/try-matrix-now/ (or build your own with
-https://github.com/matrix-org/matrix-js-sdk/).
-
-Finally, Martin Giess has created an auto-deployment process with vagrant/ansible,
+Also, Martin Giess has created an auto-deployment process with vagrant/ansible,
 tested with VirtualBox/AWS/DigitalOcean - see https://github.com/EMnify/matrix-synapse-auto-deploy
 for details.
@@ -229,6 +223,19 @@ For information on how to install and use PostgreSQL, please see
 Platform Specific Instructions
 ==============================

+Debian
+------
+
+Matrix provides official Debian packages via apt from http://matrix.org/packages/debian/.
+Note that these packages do not include a client - choose one from
+https://matrix.org/blog/try-matrix-now/ (or build your own with one of our SDKs :)
+
+Fedora
+------
+
+Oleg Girko provides Fedora RPMs at
+https://obs.infoserver.lv/project/monitor/matrix-synapse
+
 ArchLinux
 ---------
@@ -270,11 +277,17 @@ During setup of Synapse you need to call python2.7 directly again::
 FreeBSD
 -------

-Synapse can be installed via FreeBSD Ports or Packages:
+Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:

  - Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
  - Packages: ``pkg install py27-matrix-synapse``

+NixOS
+-----
+
+Robin Lambertz has packaged Synapse for NixOS at:
+https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix
+
 Windows Install
 ---------------
 Synapse can be installed on Cygwin. It requires the following Cygwin packages:

scripts/synapse_port_db

@@ -19,6 +19,7 @@ from twisted.enterprise import adbapi
 from synapse.storage._base import LoggingTransaction, SQLBaseStore
 from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database

 import argparse
 import curses
@@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = {
     "rooms": ["is_public"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
+    "presence_stream": ["currently_active"],
 }
@@ -292,7 +294,7 @@ class Porter(object):
             }
         )

-        database_engine.prepare_database(db_conn)
+        prepare_database(db_conn, database_engine, config=None)

         db_conn.commit()
@@ -309,8 +311,8 @@
             **self.postgres_config["args"]
         )

-        sqlite_engine = create_engine(FakeConfig(sqlite_config))
-        postgres_engine = create_engine(FakeConfig(postgres_config))
+        sqlite_engine = create_engine(sqlite_config)
+        postgres_engine = create_engine(postgres_config)

         self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
         self.postgres_store = Store(postgres_db_pool, postgres_engine)
@@ -792,8 +794,3 @@ if __name__ == "__main__":
     if end_error_exec_info:
         exc_type, exc_value, exc_traceback = end_error_exec_info
         traceback.print_exception(exc_type, exc_value, exc_traceback)
-
-
-class FakeConfig:
-    def __init__(self, database_config):
-        self.database_config = database_config

setup.cfg

@@ -17,3 +17,6 @@ ignore =
 [flake8]
 max-line-length = 90
 ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
+
+[pep8]
+max-line-length = 90

synapse/app/homeserver.py

@@ -33,7 +33,7 @@ from synapse.python_dependencies import (
 from synapse.rest import ClientRestResource
 from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
 from synapse.storage import are_all_users_on_domain
-from synapse.storage.prepare_database import UpgradeDatabaseException
+from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database

 from synapse.server import HomeServer
@@ -245,7 +245,7 @@ class SynapseHomeServer(HomeServer):
         except IncorrectDatabaseSetup as e:
             quit_with_error(e.message)

-    def get_db_conn(self):
+    def get_db_conn(self, run_new_connection=True):
         # Any param beginning with cp_ is a parameter for adbapi, and should
         # not be passed to the database engine.
         db_params = {
@@ -254,7 +254,8 @@
         }
         db_conn = self.database_engine.module.connect(**db_params)

-        self.database_engine.on_new_connection(db_conn)
+        if run_new_connection:
+            self.database_engine.on_new_connection(db_conn)
         return db_conn
@@ -386,7 +387,7 @@ def setup(config_options):
     tls_server_context_factory = context_factory.ServerContextFactory(config)

-    database_engine = create_engine(config)
+    database_engine = create_engine(config.database_config)
     config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection

     hs = SynapseHomeServer(
@@ -402,8 +403,10 @@ def setup(config_options):
     logger.info("Preparing database: %s...", config.database_config['name'])

     try:
-        db_conn = hs.get_db_conn()
-        database_engine.prepare_database(db_conn)
+        db_conn = hs.get_db_conn(run_new_connection=False)
+        prepare_database(db_conn, database_engine, config=config)
+        database_engine.on_new_connection(db_conn)

         hs.run_startup_checks(db_conn, database_engine)

         db_conn.commit()

synapse/appservice/api.py

@@ -100,11 +100,6 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("push_bulk to %s threw exception %s", uri, ex)
         defer.returnValue(False)

-    @defer.inlineCallbacks
-    def push(self, service, event, txn_id=None):
-        response = yield self.push_bulk(service, [event], txn_id)
-        defer.returnValue(response)
-
     def _serialize(self, events):
         time_now = self.clock.time_msec()
         return [

synapse/config/homeserver.py

@@ -29,13 +29,15 @@ from .key import KeyConfig
 from .saml2 import SAML2Config
 from .cas import CasConfig
 from .password import PasswordConfig
+from .jwt import JWTConfig
+from .ldap import LDAPConfig


 class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
                        VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
                        AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
-                       PasswordConfig,):
+                       JWTConfig, LDAPConfig, PasswordConfig,):
     pass
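For orientation, HomeServerConfig is a pure aggregation: each feature contributes a Config mixin, so adding JWTConfig and LDAPConfig to the base list is the whole wiring step. A toy sketch of that pattern, with hypothetical AConfig/BConfig classes and hand-rolled dispatch instead of Synapse's real Config base:

class AConfig(object):
    def read_config(self, config):
        self.a = config.get("a")

class BConfig(object):
    def read_config(self, config):
        self.b = config.get("b")

class AllConfig(AConfig, BConfig):
    def read_config(self, config):
        # Synapse's Config base walks the MRO; done by hand in this toy.
        AConfig.read_config(self, config)
        BConfig.read_config(self, config)

c = AllConfig()
c.read_config({"a": 1, "b": 2})
assert (c.a, c.b) == (1, 2)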

synapse/config/jwt.py (new file, 37 lines)

@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 Niklas Riekenbrauck
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class JWTConfig(Config):
+    def read_config(self, config):
+        jwt_config = config.get("jwt_config", None)
+        if jwt_config:
+            self.jwt_enabled = jwt_config.get("enabled", False)
+            self.jwt_secret = jwt_config["secret"]
+            self.jwt_algorithm = jwt_config["algorithm"]
+        else:
+            self.jwt_enabled = False
+            self.jwt_secret = None
+            self.jwt_algorithm = None
+
+    def default_config(self, **kwargs):
+        return """\
+        # jwt_config:
+        #    enabled: true
+        #    secret: "a secret"
+        #    algorithm: "HS256"
+        """

synapse/config/ldap.py (new file, 52 lines)

@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 Niklas Riekenbrauck
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class LDAPConfig(Config):
+    def read_config(self, config):
+        ldap_config = config.get("ldap_config", None)
+        if ldap_config:
+            self.ldap_enabled = ldap_config.get("enabled", False)
+            self.ldap_server = ldap_config["server"]
+            self.ldap_port = ldap_config["port"]
+            self.ldap_tls = ldap_config.get("tls", False)
+            self.ldap_search_base = ldap_config["search_base"]
+            self.ldap_search_property = ldap_config["search_property"]
+            self.ldap_email_property = ldap_config["email_property"]
+            self.ldap_full_name_property = ldap_config["full_name_property"]
+        else:
+            self.ldap_enabled = False
+            self.ldap_server = None
+            self.ldap_port = None
+            self.ldap_tls = False
+            self.ldap_search_base = None
+            self.ldap_search_property = None
+            self.ldap_email_property = None
+            self.ldap_full_name_property = None
+
+    def default_config(self, **kwargs):
+        return """\
+        # ldap_config:
+        #   enabled: true
+        #   server: "ldap://localhost"
+        #   port: 389
+        #   tls: false
+        #   search_base: "ou=Users,dc=example,dc=com"
+        #   search_property: "cn"
+        #   email_property: "email"
+        #   full_name_property: "givenName"
+        """

synapse/events/__init__.py

@@ -31,7 +31,10 @@ class _EventInternalMetadata(object):
         return dict(self.__dict__)

     def is_outlier(self):
-        return hasattr(self, "outlier") and self.outlier
+        return getattr(self, "outlier", False)
+
+    def is_invite_from_remote(self):
+        return getattr(self, "invite_from_remote", False)


 def _event_dict_property(key):
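The swap from hasattr-and-read to getattr with a default is behaviour-preserving for absent attributes; a quick illustration with a toy class:

class Meta(object):
    pass

m = Meta()
assert getattr(m, "outlier", False) is False   # attribute absent -> default
m.outlier = True
assert getattr(m, "outlier", False) is True    # attribute present -> its value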

synapse/handlers/__init__.py

@@ -17,8 +17,9 @@ from synapse.appservice.scheduler import AppServiceScheduler
 from synapse.appservice.api import ApplicationServiceApi
 from .register import RegistrationHandler
 from .room import (
-    RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler,
+    RoomCreationHandler, RoomListHandler, RoomContextHandler,
 )
+from .room_member import RoomMemberHandler
 from .message import MessageHandler
 from .events import EventStreamHandler, EventHandler
 from .federation import FederationHandler

synapse/handlers/_base.py

@@ -37,12 +37,22 @@ VISIBILITY_PRIORITY = (
 )

+MEMBERSHIP_PRIORITY = (
+    Membership.JOIN,
+    Membership.INVITE,
+    Membership.KNOCK,
+    Membership.LEAVE,
+    Membership.BAN,
+)
+
+
 class BaseHandler(object):
     """
     Common base class for the event handlers.

-    :type store: synapse.storage.events.StateStore
-    :type state_handler: synapse.state.StateHandler
+    Attributes:
+        store (synapse.storage.events.StateStore):
+        state_handler (synapse.state.StateHandler):
     """

     def __init__(self, hs):
@@ -65,11 +75,13 @@
         """ Returns dict of user_id -> list of events that user is allowed to
         see.

-        :param (str, bool) user_tuples: (user id, is_peeking) for each
-            user to be checked. is_peeking should be true if:
-            * the user is not currently a member of the room, and:
-            * the user has not been a member of the room since the given
-              events
+        Args:
+            user_tuples (str, bool): (user id, is_peeking) for each user to be
+                checked. is_peeking should be true if:
+                * the user is not currently a member of the room, and:
+                * the user has not been a member of the room since the
+                  given events
+            events ([synapse.events.EventBase]): list of events to filter
         """
         forgotten = yield defer.gatherResults([
             self.store.who_forgot_in_room(
@@ -84,6 +96,12 @@
         )

         def allowed(event, user_id, is_peeking):
+            """
+            Args:
+                event (synapse.events.EventBase): event to check
+                user_id (str)
+                is_peeking (bool)
+            """
             state = event_id_to_state[event.event_id]

             # get the room_visibility at the time of the event.
@@ -115,17 +133,30 @@
                     if old_priority < new_priority:
                         visibility = prev_visibility

-            # get the user's membership at the time of the event. (or rather,
-            # just *after* the event. Which means that people can see their
-            # own join events, but not (currently) their own leave events.)
-            membership_event = state.get((EventTypes.Member, user_id), None)
-            if membership_event:
-                if membership_event.event_id in event_id_forgotten:
-                    membership = None
-                else:
-                    membership = membership_event.membership
-            else:
-                membership = None
+            # likewise, if the event is the user's own membership event, use
+            # the 'most joined' membership
+            membership = None
+            if event.type == EventTypes.Member and event.state_key == user_id:
+                membership = event.content.get("membership", None)
+                if membership not in MEMBERSHIP_PRIORITY:
+                    membership = "leave"
+
+                prev_content = event.unsigned.get("prev_content", {})
+                prev_membership = prev_content.get("membership", None)
+                if prev_membership not in MEMBERSHIP_PRIORITY:
+                    prev_membership = "leave"
+
+                new_priority = MEMBERSHIP_PRIORITY.index(membership)
+                old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
+                if old_priority < new_priority:
+                    membership = prev_membership
+
+            # otherwise, get the user's membership at the time of the event.
+            if membership is None:
+                membership_event = state.get((EventTypes.Member, user_id), None)
+                if membership_event:
+                    if membership_event.event_id not in event_id_forgotten:
+                        membership = membership_event.membership

             # if the user was a member of the room at the time of the event,
             # they can see it.
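The new branch prefers the "most joined" of the event's own membership and its prev_content membership, by comparing positions in MEMBERSHIP_PRIORITY. A standalone sketch of that rule, with plain strings standing in for the Membership constants:

MEMBERSHIP_PRIORITY = ("join", "invite", "knock", "leave", "ban")

def most_joined(membership, prev_membership):
    new_priority = MEMBERSHIP_PRIORITY.index(membership)
    old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
    return prev_membership if old_priority < new_priority else membership

# A user who has just left still sees their own leave event: the prior
# "join" outranks "leave", so the event is filtered as if they were joined.
assert most_joined("leave", "join") == "join"
assert most_joined("invite", "ban") == "invite"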
@@ -165,13 +196,16 @@
         """
         Check which events a user is allowed to see

-        :param str user_id: user id to be checked
-        :param [synapse.events.EventBase] events: list of events to be checked
-        :param bool is_peeking should be True if:
+        Args:
+            user_id(str): user id to be checked
+            events([synapse.events.EventBase]): list of events to be checked
+            is_peeking(bool): should be True if:
               * the user is not currently a member of the room, and:
               * the user has not been a member of the room since the given
                 events
-        :rtype [synapse.events.EventBase]
+
+        Returns:
+            [synapse.events.EventBase]
         """
         types = (
             (EventTypes.RoomHistoryVisibility, ""),
@@ -199,20 +233,25 @@
         )

     @defer.inlineCallbacks
-    def _create_new_client_event(self, builder):
-        latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
-            builder.room_id,
-        )
-
-        if latest_ret:
-            depth = max([d for _, _, d in latest_ret]) + 1
+    def _create_new_client_event(self, builder, prev_event_ids=None):
+        if prev_event_ids:
+            prev_events = yield self.store.add_event_hashes(prev_event_ids)
+            prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
+            depth = prev_max_depth + 1
         else:
-            depth = 1
+            latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
+                builder.room_id,
+            )

-        prev_events = [
-            (event_id, prev_hashes)
-            for event_id, prev_hashes, _ in latest_ret
-        ]
+            if latest_ret:
+                depth = max([d for _, _, d in latest_ret]) + 1
+            else:
+                depth = 1
+
+            prev_events = [
+                (event_id, prev_hashes)
+                for event_id, prev_hashes, _ in latest_ret
+            ]

         builder.prev_events = prev_events
         builder.depth = depth
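With explicit prev_event_ids, the event's depth becomes one more than the deepest supplied prev event, mirroring the forward-extremities branch; a toy check with hypothetical depths:

prev_event_depths = {"$a:hs": 7, "$b:hs": 9}   # hypothetical prev events
depth = max(prev_event_depths.values()) + 1
assert depth == 10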
@@ -221,50 +260,6 @@
         context = yield state_handler.compute_event_context(builder)

-        # If we've received an invite over federation, there are no latest
-        # events in the room, because we don't know enough about the graph
-        # fragment we received to treat it like a graph, so the above returned
-        # no relevant events. It may have returned some events (if we have
-        # joined and left the room), but not useful ones, like the invite.
-        if (
-            not self.is_host_in_room(context.current_state) and
-            builder.type == EventTypes.Member
-        ):
-            prev_member_event = yield self.store.get_room_member(
-                builder.sender, builder.room_id
-            )
-
-            # The prev_member_event may already be in context.current_state,
-            # despite us not being present in the room; in particular, if
-            # inviting user, and all other local users, have already left.
-            #
-            # In that case, we have all the information we need, and we don't
-            # want to drop "context" - not least because we may need to handle
-            # the invite locally, which will require us to have the whole
-            # context (not just prev_member_event) to auth it.
-            #
-            context_event_ids = (
-                e.event_id for e in context.current_state.values()
-            )
-
-            if (
-                prev_member_event and
-                prev_member_event.event_id not in context_event_ids
-            ):
-                # The prev_member_event is missing from context, so it must
-                # have arrived over federation and is an outlier. We forcibly
-                # set our context to the invite we received over federation
-                builder.prev_events = (
-                    prev_member_event.event_id,
-                    prev_member_event.prev_events
-                )
-
-                context = yield state_handler.compute_event_context(
-                    builder,
-                    old_state=(prev_member_event,),
-                    outlier=True
-                )
-
         if builder.is_state():
             builder.prev_state = yield self.store.add_event_hashes(
                 context.prev_state_events

synapse/handlers/auth.py

@@ -49,6 +49,21 @@ class AuthHandler(BaseHandler):
         self.sessions = {}
         self.INVALID_TOKEN_HTTP_STATUS = 401

+        self.ldap_enabled = hs.config.ldap_enabled
+        self.ldap_server = hs.config.ldap_server
+        self.ldap_port = hs.config.ldap_port
+        self.ldap_tls = hs.config.ldap_tls
+        self.ldap_search_base = hs.config.ldap_search_base
+        self.ldap_search_property = hs.config.ldap_search_property
+        self.ldap_email_property = hs.config.ldap_email_property
+        self.ldap_full_name_property = hs.config.ldap_full_name_property
+
+        if self.ldap_enabled is True:
+            import ldap
+            logger.info("Import ldap version: %s", ldap.__version__)
+
+        self.hs = hs  # FIXME better possibility to access registrationHandler later?
+
     @defer.inlineCallbacks
     def check_auth(self, flows, clientdict, clientip):
         """
@@ -163,9 +178,13 @@
     def get_session_id(self, clientdict):
         """
         Gets the session ID for a client given the client dictionary
-        :param clientdict: The dictionary sent by the client in the request
-        :return: The string session ID the client sent. If the client did not
-            send a session ID, returns None.
+
+        Args:
+            clientdict: The dictionary sent by the client in the request
+
+        Returns:
+            str|None: The string session ID the client sent. If the client did
+                not send a session ID, returns None.
         """
         sid = None
         if clientdict and 'auth' in clientdict:
@@ -179,9 +198,11 @@
         Store a key-value pair into the sessions data associated with this
         request. This data is stored server-side and cannot be modified by
         the client.
-        :param session_id: (string) The ID of this session as returned from check_auth
-        :param key: (string) The key to store the data under
-        :param value: (any) The data to store
+
+        Args:
+            session_id (string): The ID of this session as returned from check_auth
+            key (string): The key to store the data under
+            value (any): The data to store
         """
         sess = self._get_session_info(session_id)
         sess.setdefault('serverdict', {})[key] = value
@@ -190,9 +211,11 @@
     def get_session_data(self, session_id, key, default=None):
         """
         Retrieve data stored with set_session_data
-        :param session_id: (string) The ID of this session as returned from check_auth
-        :param key: (string) The key to store the data under
-        :param default: (any) Value to return if the key has not been set
+
+        Args:
+            session_id (string): The ID of this session as returned from check_auth
+            key (string): The key to store the data under
+            default (any): Value to return if the key has not been set
         """
         sess = self._get_session_info(session_id)
         return sess.setdefault('serverdict', {}).get(key, default)
@@ -207,8 +230,10 @@
         if not user_id.startswith('@'):
             user_id = UserID.create(user_id, self.hs.hostname).to_string()

-        user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
-        self._check_password(user_id, password, password_hash)
+        if not (yield self._check_password(user_id, password)):
+            logger.warn("Failed password login for user %s", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
         defer.returnValue(user_id)

     @defer.inlineCallbacks
@@ -332,8 +357,10 @@
             StoreError if there was a problem storing the token.
             LoginError if there was an authentication problem.
         """
-        user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
-        self._check_password(user_id, password, password_hash)
+
+        if not (yield self._check_password(user_id, password)):
+            logger.warn("Failed password login for user %s", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)

         logger.info("Logging in user %s", user_id)
         access_token = yield self.issue_access_token(user_id)
@@ -399,11 +426,60 @@
         else:
             defer.returnValue(user_infos.popitem())

-    def _check_password(self, user_id, password, stored_hash):
-        """Checks that user_id has passed password, raises LoginError if not."""
-        if not self.validate_hash(password, stored_hash):
-            logger.warn("Failed password login for user %s", user_id)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+    @defer.inlineCallbacks
+    def _check_password(self, user_id, password):
+        defer.returnValue(
+            not (
+                (yield self._check_ldap_password(user_id, password))
+                or
+                (yield self._check_local_password(user_id, password))
+            ))
+
+    @defer.inlineCallbacks
+    def _check_local_password(self, user_id, password):
+        try:
+            user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+            defer.returnValue(not self.validate_hash(password, password_hash))
+        except LoginError:
+            defer.returnValue(False)
+
+    @defer.inlineCallbacks
+    def _check_ldap_password(self, user_id, password):
+        if self.ldap_enabled is not True:
+            logger.debug("LDAP not configured")
+            defer.returnValue(False)
+
+        import ldap
+
+        logger.info("Authenticating %s with LDAP" % user_id)
+        try:
+            ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
+            logger.debug("Connecting LDAP server at %s" % ldap_url)
+            l = ldap.initialize(ldap_url)
+            if self.ldap_tls:
+                logger.debug("Initiating TLS")
+                self._connection.start_tls_s()
+
+            local_name = UserID.from_string(user_id).localpart
+
+            dn = "%s=%s, %s" % (
+                self.ldap_search_property,
+                local_name,
+                self.ldap_search_base)
+            logger.debug("DN for LDAP authentication: %s" % dn)
+
+            l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
+
+            if not (yield self.does_user_exist(user_id)):
+                handler = self.hs.get_handlers().registration_handler
+                user_id, access_token = (
+                    yield handler.register(localpart=local_name)
+                )
+
+            defer.returnValue(True)
+        except ldap.LDAPError, e:
+            logger.warn("LDAP error: %s", e)
+            defer.returnValue(False)

     @defer.inlineCallbacks
     def issue_access_token(self, user_id):
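For reference, the whole LDAP check above reduces to one successful simple bind against a DN built from search_property, the user's localpart, and search_base. A minimal sketch assuming the python-ldap package and the example ldap_config values (server, DN, and password here are hypothetical):

import ldap  # python-ldap, assumed

conn = ldap.initialize("ldap://localhost:389")
# search_property=localpart plus search_base, as built by _check_ldap_password
dn = "cn=alice, ou=Users,dc=example,dc=com"
try:
    conn.simple_bind_s(dn, "not-the-real-password")   # raises on bad credentials
    authenticated = True
except ldap.LDAPError:
    authenticated = False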

synapse/handlers/federation.py

@@ -40,6 +40,7 @@ from synapse.events.utils import prune_event
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.push.action_generator import ActionGenerator
+from synapse.util.distributor import user_joined_room

 from twisted.internet import defer
@@ -49,10 +50,6 @@ import logging
 logger = logging.getLogger(__name__)

-def user_joined_room(distributor, user, room_id):
-    return distributor.fire("user_joined_room", user, room_id)
-
-
 class FederationHandler(BaseHandler):
     """Handles events that originated from federation.
         Responsible for:
@@ -102,8 +99,7 @@
     @log_function
     @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, state=None,
-                       auth_chain=None):
+    def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
         """ Called by the ReplicationLayer when we have a new pdu. We need to
         do auth checks and put it through the StateHandler.
         """
@@ -174,11 +170,7 @@
             })
             seen_ids.add(e.event_id)

-        yield self._handle_new_events(
-            origin,
-            event_infos,
-            outliers=True
-        )
+        yield self._handle_new_events(origin, event_infos)

         try:
             context, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -289,6 +281,9 @@
     def backfill(self, dest, room_id, limit, extremities=[]):
         """ Trigger a backfill request to `dest` for the given `room_id`
         """
+        if dest == self.server_name:
+            raise SynapseError(400, "Can't backfill from self.")
+
         if not extremities:
             extremities = yield self.store.get_oldest_events_in_room(room_id)
@@ -455,7 +450,7 @@
         likely_domains = [
             domain for domain, depth in curr_domains
-            if domain is not self.server_name
+            if domain != self.server_name
         ]

     @defer.inlineCallbacks
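The `is not` to `!=` change above matters because identity comparison between equal strings is not reliable in Python, so the old check could fail to filter out the local server name:

a = "example.org"
b = "".join(["example", ".org"])   # equal value, distinct object (CPython)
assert a == b                      # equality: what the domain filter needs
assert a is not b                  # identity: the old check, wrongly True here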
@@ -761,6 +756,7 @@
             event = pdu

         event.internal_metadata.outlier = True
+        event.internal_metadata.invite_from_remote = True

         event.signatures.update(
             compute_event_signature(
@@ -788,13 +784,19 @@
     @defer.inlineCallbacks
     def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
-        origin, event = yield self._make_and_verify_event(
-            target_hosts,
-            room_id,
-            user_id,
-            "leave"
-        )
-        signed_event = self._sign_event(event)
+        try:
+            origin, event = yield self._make_and_verify_event(
+                target_hosts,
+                room_id,
+                user_id,
+                "leave"
+            )
+            signed_event = self._sign_event(event)
+        except SynapseError:
+            raise
+        except CodeMessageException as e:
+            logger.warn("Failed to reject invite: %s", e)
+            raise SynapseError(500, "Failed to reject invite")

         # Try the host we successfully got a response to /make_join/
         # request first.
@@ -804,10 +806,16 @@
         except ValueError:
             pass

-        yield self.replication_layer.send_leave(
-            target_hosts,
-            signed_event
-        )
+        try:
+            yield self.replication_layer.send_leave(
+                target_hosts,
+                signed_event
+            )
+        except SynapseError:
+            raise
+        except CodeMessageException as e:
+            logger.warn("Failed to reject invite: %s", e)
+            raise SynapseError(500, "Failed to reject invite")

         context = yield self.state_handler.compute_event_context(event)
@@ -1069,9 +1077,6 @@
     @defer.inlineCallbacks
     @log_function
     def _handle_new_event(self, origin, event, state=None, auth_events=None):
-        outlier = event.internal_metadata.is_outlier()
-
         context = yield self._prep_event(
             origin, event,
             state=state,
@@ -1087,14 +1092,12 @@
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event,
             context=context,
-            is_new_state=not outlier,
         )

         defer.returnValue((context, event_stream_id, max_stream_id))

     @defer.inlineCallbacks
-    def _handle_new_events(self, origin, event_infos, backfilled=False,
-                           outliers=False):
+    def _handle_new_events(self, origin, event_infos, backfilled=False):
         contexts = yield defer.gatherResults(
             [
                 self._prep_event(
@@ -1113,7 +1116,6 @@
                 for ev_info, context in itertools.izip(event_infos, contexts)
             ],
             backfilled=backfilled,
-            is_new_state=(not outliers and not backfilled),
         )

     @defer.inlineCallbacks
@@ -1128,11 +1130,9 @@
         """
         events_to_context = {}
         for e in itertools.chain(auth_events, state):
-            ctx = yield self.state_handler.compute_event_context(
-                e, outlier=True,
-            )
-            events_to_context[e.event_id] = ctx
             e.internal_metadata.outlier = True
+            ctx = yield self.state_handler.compute_event_context(e)
+            events_to_context[e.event_id] = ctx

         event_map = {
             e.event_id: e
@@ -1176,16 +1176,14 @@
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
             ],
-            is_new_state=False,
         )

         new_event_context = yield self.state_handler.compute_event_context(
-            event, old_state=state, outlier=False,
+            event, old_state=state
         )

         event_stream_id, max_stream_id = yield self.store.persist_event(
             event, new_event_context,
-            is_new_state=True,
             current_state=state,
         )
@@ -1193,10 +1191,9 @@
     @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, auth_events=None):
-        outlier = event.internal_metadata.is_outlier()
-
         context = yield self.state_handler.compute_event_context(
-            event, old_state=state, outlier=outlier,
+            event, old_state=state,
         )

         if not auth_events:
@@ -1718,13 +1715,15 @@
     def _check_signature(self, event, auth_events):
         """
         Checks that the signature in the event is consistent with its invite.
-        :param event (Event): The m.room.member event to check
-        :param auth_events (dict<(event type, state_key), event>)
-        :raises
-            AuthError if signature didn't match any keys, or key has been
+
+        Args:
+            event (Event): The m.room.member event to check
+            auth_events (dict<(event type, state_key), event>):
+
+        Raises:
+            AuthError: if signature didn't match any keys, or key has been
                 revoked,
-            SynapseError if a transient error meant a key couldn't be checked
+            SynapseError: if a transient error meant a key couldn't be checked
                 for revocation.
         """
         signed = event.content["third_party_invite"]["signed"]
@@ -1766,12 +1765,13 @@
         """
         Checks whether public_key has been revoked.
-        :param public_key (str): base-64 encoded public key.
-        :param url (str): Key revocation URL.
-        :raises
-            AuthError if they key has been revoked.
-            SynapseError if a transient error meant a key couldn't be checked
+
+        Args:
+            public_key (str): base-64 encoded public key.
+            url (str): Key revocation URL.
+
+        Raises:
+            AuthError: if they key has been revoked.
+            SynapseError: if a transient error meant a key couldn't be checked
                 for revocation.
         """
         try:

synapse/handlers/message.py

@@ -21,6 +21,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.util import unwrapFirstError
+from synapse.util.async import concurrently_execute
 from synapse.util.caches.snapshot_cache import SnapshotCache
 from synapse.types import UserID, RoomStreamToken, StreamToken
@@ -33,10 +34,6 @@ import logging
 logger = logging.getLogger(__name__)

-def collect_presencelike_data(distributor, user, content):
-    return distributor.fire("collect_presencelike_data", user, content)
-
-
 class MessageHandler(BaseHandler):

     def __init__(self, hs):
@@ -47,35 +44,6 @@
         self.validator = EventValidator()
         self.snapshot_cache = SnapshotCache()

-    @defer.inlineCallbacks
-    def get_message(self, msg_id=None, room_id=None, sender_id=None,
-                    user_id=None):
-        """ Retrieve a message.
-
-        Args:
-            msg_id (str): The message ID to obtain.
-            room_id (str): The room where the message resides.
-            sender_id (str): The user ID of the user who sent the message.
-            user_id (str): The user ID of the user making this request.
-        Returns:
-            The message, or None if no message exists.
-        Raises:
-            SynapseError if something went wrong.
-        """
-        yield self.auth.check_joined_room(room_id, user_id)
-
-        # Pull out the message from the db
-#        msg = yield self.store.get_message(
-#            room_id=room_id,
-#            msg_id=msg_id,
-#            user_id=sender_id
-#        )
-
-        # TODO (erikj): Once we work out the correct c-s api we need to think
-        # on how to do this.
-
-        defer.returnValue(None)
-
     @defer.inlineCallbacks
     def get_messages(self, requester, room_id=None, pagin_config=None,
                      as_client_event=True):
@@ -175,7 +143,7 @@
         defer.returnValue(chunk)

     @defer.inlineCallbacks
-    def create_event(self, event_dict, token_id=None, txn_id=None):
+    def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
         """
         Given a dict from a client, create a new event.
@@ -186,6 +154,9 @@
         Args:
             event_dict (dict): An entire event
+            token_id (str)
+            txn_id (str)
+            prev_event_ids (list): The prev event ids to use when creating the event

         Returns:
             Tuple of created event (FrozenEvent), Context
@@ -198,12 +169,8 @@
             membership = builder.content.get("membership", None)
             target = UserID.from_string(builder.state_key)

-            if membership == Membership.JOIN:
+            if membership in {Membership.JOIN, Membership.INVITE}:
                 # If event doesn't include a display name, add one.
-                yield collect_presencelike_data(
-                    self.distributor, target, builder.content
-                )
-            elif membership == Membership.INVITE:
                 profile = self.hs.get_handlers().profile_handler
                 content = builder.content
@@ -224,6 +191,7 @@
         event, context = yield self._create_new_client_event(
             builder=builder,
+            prev_event_ids=prev_event_ids,
         )

         defer.returnValue((event, context))
@@ -556,14 +524,7 @@
             except:
                 logger.exception("Failed to get snapshot")

-        # Only do N rooms at once
-        n = 5
-        d_list = [handle_room(e) for e in room_list]
-        for i in range(0, len(d_list), n):
-            yield defer.gatherResults(
-                d_list[i:i + n],
-                consumeErrors=True
-            ).addErrback(unwrapFirstError)
+        yield concurrently_execute(handle_room, room_list, 10)

         account_data_events = []
         for account_data_type, content in account_data.items():
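concurrently_execute replaces the hand-rolled batching above: it runs handle_room over room_list with at most ten calls in flight. A reader's sketch of those semantics, assuming Twisted (the real helper lives in synapse.util.async and differs in detail):

from twisted.internet import defer

def concurrently_execute_sketch(func, args, limit):
    """Run func over args with at most `limit` calls in flight."""
    it = iter(args)

    @defer.inlineCallbacks
    def _worker():
        for arg in it:            # workers pull from one shared iterator
            yield func(arg)

    return defer.gatherResults(
        [_worker() for _ in range(limit)],
        consumeErrors=True,
    )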

synapse/handlers/profile.py

@@ -17,7 +17,6 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
 from synapse.types import UserID, Requester
-from synapse.util import unwrapFirstError

 from ._base import BaseHandler
@@ -27,14 +26,6 @@ import logging
 logger = logging.getLogger(__name__)

-def changed_presencelike_data(distributor, user, state):
-    return distributor.fire("changed_presencelike_data", user, state)
-
-
-def collect_presencelike_data(distributor, user, content):
-    return distributor.fire("collect_presencelike_data", user, content)
-
-
 class ProfileHandler(BaseHandler):

     def __init__(self, hs):
@@ -46,17 +37,9 @@
         )

         distributor = hs.get_distributor()
-        self.distributor = distributor
-
-        distributor.declare("collect_presencelike_data")
-        distributor.declare("changed_presencelike_data")
-
         distributor.observe("registered_user", self.registered_user)

-        distributor.observe(
-            "collect_presencelike_data", self.collect_presencelike_data
-        )
-
     def registered_user(self, user):
         return self.store.create_profile(user.localpart)
@@ -105,10 +88,6 @@
             target_user.localpart, new_displayname
         )

-        yield changed_presencelike_data(self.distributor, target_user, {
-            "displayname": new_displayname,
-        })
-
         yield self._update_join_states(requester)

     @defer.inlineCallbacks
@@ -152,30 +131,8 @@
             target_user.localpart, new_avatar_url
         )

-        yield changed_presencelike_data(self.distributor, target_user, {
-            "avatar_url": new_avatar_url,
-        })
-
         yield self._update_join_states(requester)

-    @defer.inlineCallbacks
-    def collect_presencelike_data(self, user, state):
-        if not self.hs.is_mine(user):
-            defer.returnValue(None)
-
-        (displayname, avatar_url) = yield defer.gatherResults(
-            [
-                self.store.get_profile_displayname(user.localpart),
-                self.store.get_profile_avatar_url(user.localpart),
-            ],
-            consumeErrors=True
-        ).addErrback(unwrapFirstError)
-
-        state["displayname"] = displayname
-        state["avatar_url"] = avatar_url
-
-        defer.returnValue(None)
-
     @defer.inlineCallbacks
     def on_profile_query(self, args):
         user = UserID.from_string(args["user_id"])

synapse/handlers/register.py

@@ -23,6 +23,7 @@ from synapse.api.errors import (
 from ._base import BaseHandler
 from synapse.util.async import run_on_reactor
 from synapse.http.client import CaptchaServerHttpClient
+from synapse.util.distributor import registered_user

 import logging
 import urllib
@@ -30,10 +31,6 @@ import urllib
 logger = logging.getLogger(__name__)

-def registered_user(distributor, user):
-    return distributor.fire("registered_user", user)
-
-
 class RegistrationHandler(BaseHandler):

     def __init__(self, hs):

synapse/handlers/room.py

@ -18,19 +18,16 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset, EventTypes, JoinRules, RoomCreationPreset,
) )
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from collections import OrderedDict from collections import OrderedDict
from unpaddedbase64 import decode_base64
import logging import logging
import math import math
@ -41,20 +38,6 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://" id_server_scheme = "https://"
def user_left_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):
PRESETS_DICT = { PRESETS_DICT = {
@ -356,594 +339,23 @@ class RoomCreationHandler(BaseHandler):
) )
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
elif action == "forget":
effective_membership_state = "leave"
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
third_party_signed,
)
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": effective_membership_state}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
# For backwards compatibility:
"membership": effective_membership_state,
},
token_id=requester.access_token_id,
txn_id=txn_id,
)
old_state = context.current_state.get((EventTypes.Member, event.state_key))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(
requester,
event,
context,
ratelimit=ratelimit,
remote_room_hosts=remote_room_hosts,
)
if action == "forget":
yield self.forget(requester.user, room_id)
@defer.inlineCallbacks
def send_membership_event(
self,
requester,
event,
context,
remote_room_hosts=None,
ratelimit=True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
is_guest (bool): Whether the sender is a guest.
room_hosts ([str]): Homeservers which are likely to already be in
the room, and could be danced with in order to join this
homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
action = "send"
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
do_remote_join_dance, remote_room_hosts = self._should_do_dance(
context,
(self.get_inviter(event.state_key, context.current_state)),
remote_room_hosts,
)
if do_remote_join_dance:
action = "remote_join"
elif event.membership == Membership.LEAVE:
is_host_in_room = self.is_host_in_room(context.current_state)
if not is_host_in_room:
# perhaps we've been invited
inviter = self.get_inviter(target_user.to_string(), context.current_state)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
action = "remote_reject"
federation_handler = self.hs.get_handlers().federation_handler
if action == "remote_join":
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield federation_handler.do_invite_join(
remote_room_hosts,
event.room_id,
event.user_id,
event.content,
)
elif action == "remote_reject":
yield federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
event.user_id
)
else:
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
def _should_do_dance(self, context, inviter, room_hosts=None):
# TODO: Shouldn't this be remote_room_host?
room_hosts = room_hosts or []
is_host_in_room = self.is_host_in_room(context.current_state)
if is_host_in_room:
return False, room_hosts
if inviter and not self.hs.is_mine(inviter):
room_hosts.append(inviter.domain)
return True, room_hosts
@defer.inlineCallbacks
def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"]
servers = mapping["servers"]
defer.returnValue((RoomID.from_string(room_id), servers))
def get_inviter(self, user_id, current_state):
prev_state = current_state.get((EventTypes.Member, user_id))
if prev_state and prev_state.membership == Membership.INVITE:
return UserID.from_string(prev_state.user_id)
return None
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
requester,
id_server,
medium,
address,
room_id,
inviter,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
(str) the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id
):
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
member_event = room_state.get((EventTypes.Member, user.to_string()))
if member_event:
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:
canonical_room_alias = canonical_alias_event.content.get("alias", "")
room_name = ""
room_name_event = room_state.get((EventTypes.Name, ""))
if room_name_event:
room_name = room_name_event.content.get("name", "")
room_join_rules = ""
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
room_join_rules = join_rules_event.content.get("join_rule", "")
room_avatar_url = ""
room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server=id_server,
medium=medium,
address=address,
room_id=room_id,
inviter_user_id=user.to_string(),
room_alias=canonical_room_alias,
room_avatar_url=room_avatar_url,
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"public_keys": public_keys,
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url
):
"""
Asks an identity server for a third party invite.
:param id_server (str): hostname + optional port for the identity server.
:param medium (str): The literal string "email".
:param address (str): The third party address being invited.
:param room_id (str): The ID of the room to which the user is invited.
:param inviter_user_id (str): The user ID of the inviter.
:param room_alias (str): An alias for the room, for cosmetic
notifications.
:param room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
        :param room_join_rules (str): The join rules of the room
(e.g. "public").
:param room_name (str): The m.room.name of the room.
:param inviter_display_name (str): The current display name of the
inviter.
:param inviter_avatar_url (str): The URL of the inviter's avatar.
:return: A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
if self.hs.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
})
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
invite_config
)
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
),
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name))
def forget(self, user, room_id):
return self.store.forget(user.to_string(), room_id)
 class RoomListHandler(BaseHandler):
+    def __init__(self, hs):
+        super(RoomListHandler, self).__init__(hs)
+        self.response_cache = ResponseCache()
+
+    def get_public_room_list(self):
+        result = self.response_cache.get(())
+        if not result:
+            result = self.response_cache.set((), self._get_public_room_list())
+        return result
+
     @defer.inlineCallbacks
-    def get_public_room_list(self):
+    def _get_public_room_list(self):
         room_ids = yield self.store.get_public_room_ids()
+
+        results = []

         @defer.inlineCallbacks
         def handle_room(room_id):
             aliases = yield self.store.get_aliases_for_room(room_id)
@@ -1004,18 +416,12 @@ class RoomListHandler(BaseHandler):
             joined_users = yield self.store.get_users_in_room(room_id)
             result["num_joined_members"] = len(joined_users)
-            defer.returnValue(result)
+            results.append(result)

-        result = []
-        for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
-            chunk_result = yield defer.gatherResults([
-                handle_room(room_id)
-                for room_id in chunk
-            ], consumeErrors=True).addErrback(unwrapFirstError)
-            result.extend(v for v in chunk_result if v)
+        yield concurrently_execute(handle_room, room_ids, 10)

         # FIXME (erikj): START is no longer a valid value
-        defer.returnValue({"start": "START", "end": "END", "chunk": result})
+        defer.returnValue({"start": "START", "end": "END", "chunk": results})

 class RoomContextHandler(BaseHandler):
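
Editor's note: the ResponseCache used above (new in this commit) deduplicates
concurrent requests for the same resource: the first caller stores its pending
deferred under a key, and callers arriving before it resolves share the same
result. A minimal usage sketch mirroring the handler above (illustrative only,
not the ResponseCache source; _do_expensive_work is a hypothetical name):

    result = self.response_cache.get(key)    # a pending deferred, or None
    if not result:
        result = self.response_cache.set(key, self._do_expensive_work())
    return result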


@@ -0,0 +1,722 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomID, Requester
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
import logging
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.member_linearizer = Linearizer()
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
        This function returns nothing; it works entirely through side-effects
        on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
prev_event_ids,
txn_id=None,
ratelimit=True,
):
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": membership}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
# For backwards compatibility:
"membership": membership,
},
token_id=requester.access_token_id,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
)
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
                # Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield self.hs.get_handlers().federation_handler.do_invite_join(
remote_room_hosts,
room_id,
user.to_string(),
content,
)
yield user_joined_room(self.distributor, user, room_id)
def reject_remote_invite(self, user_id, room_id, remote_room_hosts):
return self.hs.get_handlers().federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
user_id
)
@defer.inlineCallbacks
def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
key = (target, room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
requester,
target,
room_id,
action,
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
)
defer.returnValue(result)
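    # Editor's note (illustrative, not part of this commit): Linearizer.queue
    # serialises the block above per key, so concurrent membership updates for
    # the same (target, room_id) pair run one at a time while other keys
    # proceed in parallel:
    #
    #     with (yield self.member_linearizer.queue(key)):
    #         ...  # at most one caller per key executes here at a time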
@defer.inlineCallbacks
def _update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
third_party_signed,
)
if not remote_room_hosts:
remote_room_hosts = []
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state = yield self.state_handler.get_current_state(
room_id, latest_event_ids=latest_event_ids,
)
old_state = current_state.get((EventTypes.Member, target.to_string()))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
is_host_in_room = self.is_host_in_room(current_state)
if effective_membership_state == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
content = {"membership": Membership.JOIN}
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
ret = yield self.remote_join(
remote_room_hosts, room_id, target, content
)
defer.returnValue(ret)
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
inviter = yield self.get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
try:
ret = yield self.reject_remote_invite(
target.to_string(), room_id, remote_room_hosts
)
defer.returnValue(ret)
except SynapseError as e:
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
yield self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
membership=effective_membership_state,
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
)
@defer.inlineCallbacks
def send_membership_event(
self,
requester,
event,
context,
remote_room_hosts=None,
ratelimit=True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
            remote_room_hosts ([str]): Homeservers which are likely to already
                be in the room, and could be danced with in order to join this
                homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
                # Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
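    # Editor's illustration (not part of this commit): the check above passes
    # only when the room's m.room.guest_access state event has content
    #
    #     {"guest_access": "can_join"}
    #
    # Any other value (e.g. "forbidden"), or no such event at all, denies
    # guest joins.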
@defer.inlineCallbacks
def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"]
servers = mapping["servers"]
defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks
def get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if invite:
defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
requester,
id_server,
medium,
address,
room_id,
inviter,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
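    # Editor's illustration (not part of this commit): a successful lookup is
    # expected to return signed JSON shaped roughly like
    #
    #     {
    #         "medium": "email",
    #         "address": "foo@example.com",
    #         "mxid": "@foo:example.com",
    #         "signatures": {
    #             "example.com": {"ed25519:0": "<unpadded base64 signature>"}
    #         }
    #     }
    #
    # which verify_any_signature() below checks before the mxid is trusted.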
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
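    # Editor's note: key_name above is a full signing key ID such as
    # "ed25519:0", and the pubkey endpoint is assumed to reply with JSON of
    # the form {"public_key": "<unpadded base64 ed25519 key>"}.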
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id
):
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
member_event = room_state.get((EventTypes.Member, user.to_string()))
if member_event:
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:
canonical_room_alias = canonical_alias_event.content.get("alias", "")
room_name = ""
room_name_event = room_state.get((EventTypes.Name, ""))
if room_name_event:
room_name = room_name_event.content.get("name", "")
room_join_rules = ""
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
room_join_rules = join_rules_event.content.get("join_rule", "")
room_avatar_url = ""
room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server=id_server,
medium=medium,
address=address,
room_id=room_id,
inviter_user_id=user.to_string(),
room_alias=canonical_room_alias,
room_avatar_url=room_avatar_url,
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"public_keys": public_keys,
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url
):
"""
Asks an identity server for a third party invite.
Args:
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
room_id (str): The ID of the room to which the user is invited.
inviter_user_id (str): The user ID of the inviter.
room_alias (str): An alias for the room, for cosmetic notifications.
room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
            room_join_rules (str): The join rules of the room (e.g. "public").
room_name (str): The m.room.name of the room.
inviter_display_name (str): The current display name of the
inviter.
inviter_avatar_url (str): The URL of the inviter's avatar.
Returns:
A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
if self.hs.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
})
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
invite_config
)
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
),
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name))
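    # Editor's illustration (an assumption based on the parsing above): the
    # store-invite endpoint replies with JSON roughly like
    #
    #     {
    #         "token": "<opaque token>",
    #         "public_key": "<base64 key>",
    #         "public_keys": [
    #             {"public_key": "<base64 key>",
    #              "key_validity_url": ".../_matrix/identity/api/v1/pubkey/isvalid"}
    #         ],
    #         "display_name": "<obfuscated name for the invitee>"
    #     }
    #
    # Older identity servers return only the top-level "public_key", which is
    # why the fallback branch above synthesises a public_keys entry from it.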
@defer.inlineCallbacks
def forget(self, user, room_id):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership != Membership.LEAVE:
raise SynapseError(400, "User %s in room %s" % (
user_id, room_id
))
if membership:
yield self.store.forget(user_id, room_id)


@@ -17,9 +17,10 @@
 from ._base import BaseHandler
 from synapse.streams.config import PaginationConfig
 from synapse.api.constants import Membership, EventTypes
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.async import concurrently_execute
+from synapse.util.logcontext import LoggingContext
 from synapse.util.metrics import Measure
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.push.clientformat import format_push_rules_for_user

 from twisted.internet import defer

@@ -35,6 +36,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [
     "user",
     "filter_collection",
     "is_guest",
+    "request_key",
 ])

@@ -136,8 +138,8 @@
         super(SyncHandler, self).__init__(hs)
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
+        self.response_cache = ResponseCache()

-    @defer.inlineCallbacks
     def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
                                full_state=False):
         """Get the sync for a client if we have new data for it now. Otherwise

@@ -146,7 +148,19 @@
         Returns:
             A Deferred SyncResult.
         """
+        result = self.response_cache.get(sync_config.request_key)
+        if not result:
+            result = self.response_cache.set(
+                sync_config.request_key,
+                self._wait_for_sync_for_user(
+                    sync_config, since_token, timeout, full_state
+                )
+            )
+        return result
+
+    @defer.inlineCallbacks
+    def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
+                                full_state):
         context = LoggingContext.current_context()
         if context:
             if since_token is None:

@@ -236,58 +250,50 @@
         joined = []
         invited = []
         archived = []
-        deferreds = []
-        room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
-        for room_list_chunk in room_list_chunks:
-            for event in room_list_chunk:
-                if event.membership == Membership.JOIN:
-                    room_sync_deferred = preserve_fn(
-                        self.full_state_sync_for_joined_room
-                    )(
-                        room_id=event.room_id,
-                        sync_config=sync_config,
-                        now_token=now_token,
-                        timeline_since_token=timeline_since_token,
-                        ephemeral_by_room=ephemeral_by_room,
-                        tags_by_room=tags_by_room,
-                        account_data_by_room=account_data_by_room,
-                    )
-                    room_sync_deferred.addCallback(joined.append)
-                    deferreds.append(room_sync_deferred)
-                elif event.membership == Membership.INVITE:
-                    invite = yield self.store.get_event(event.event_id)
-                    invited.append(InvitedSyncResult(
-                        room_id=event.room_id,
-                        invite=invite,
-                    ))
-                elif event.membership in (Membership.LEAVE, Membership.BAN):
-                    # Always send down rooms we were banned or kicked from.
-                    if not sync_config.filter_collection.include_leave:
-                        if event.membership == Membership.LEAVE:
-                            if sync_config.user.to_string() == event.sender:
-                                continue
-                    leave_token = now_token.copy_and_replace(
-                        "room_key", "s%d" % (event.stream_ordering,)
-                    )
-                    room_sync_deferred = preserve_fn(
-                        self.full_state_sync_for_archived_room
-                    )(
-                        sync_config=sync_config,
-                        room_id=event.room_id,
-                        leave_event_id=event.event_id,
-                        leave_token=leave_token,
-                        timeline_since_token=timeline_since_token,
-                        tags_by_room=tags_by_room,
-                        account_data_by_room=account_data_by_room,
-                    )
-                    room_sync_deferred.addCallback(archived.append)
-                    deferreds.append(room_sync_deferred)
-
-        yield defer.gatherResults(
-            deferreds, consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        user_id = sync_config.user.to_string()
+
+        @defer.inlineCallbacks
+        def _generate_room_entry(event):
+            if event.membership == Membership.JOIN:
+                room_result = yield self.full_state_sync_for_joined_room(
+                    room_id=event.room_id,
+                    sync_config=sync_config,
+                    now_token=now_token,
+                    timeline_since_token=timeline_since_token,
+                    ephemeral_by_room=ephemeral_by_room,
+                    tags_by_room=tags_by_room,
+                    account_data_by_room=account_data_by_room,
+                )
+                joined.append(room_result)
+            elif event.membership == Membership.INVITE:
+                invite = yield self.store.get_event(event.event_id)
+                invited.append(InvitedSyncResult(
+                    room_id=event.room_id,
+                    invite=invite,
+                ))
+            elif event.membership in (Membership.LEAVE, Membership.BAN):
+                # Always send down rooms we were banned or kicked from.
+                if not sync_config.filter_collection.include_leave:
+                    if event.membership == Membership.LEAVE:
+                        if user_id == event.sender:
+                            return
+                leave_token = now_token.copy_and_replace(
+                    "room_key", "s%d" % (event.stream_ordering,)
+                )
+                room_result = yield self.full_state_sync_for_archived_room(
+                    sync_config=sync_config,
+                    room_id=event.room_id,
+                    leave_event_id=event.event_id,
+                    leave_token=leave_token,
+                    timeline_since_token=timeline_since_token,
+                    tags_by_room=tags_by_room,
+                    account_data_by_room=account_data_by_room,
+                )
+                archived.append(room_result)
+
+        yield concurrently_execute(_generate_room_entry, room_list, 10)

         account_data_for_user = sync_config.filter_collection.filter_account_data(
             self.account_data_for_user(account_data)

@@ -657,7 +663,8 @@
     def load_filtered_recents(self, room_id, sync_config, now_token,
                               since_token=None, recents=None, newly_joined_room=False):
         """
-        :returns a Deferred TimelineBatch
+        Returns:
+            a Deferred TimelineBatch
         """
         with Measure(self.clock, "load_filtered_recents"):
             filtering_factor = 2

@@ -824,8 +831,11 @@
         """
         Get the room state after the given event

-        :param synapse.events.EventBase event: event of interest
-        :return: A Deferred map from ((type, state_key)->Event)
+        Args:
+            event(synapse.events.EventBase): event of interest
+
+        Returns:
+            A Deferred map from ((type, state_key)->Event)
         """
         state = yield self.store.get_state_for_event(event.event_id)
         if event.is_state():

@@ -836,9 +846,13 @@
     @defer.inlineCallbacks
     def get_state_at(self, room_id, stream_position):
         """ Get the room state at a particular stream position
-        :param str room_id: room for which to get state
-        :param StreamToken stream_position: point at which to get state
-        :returns: A Deferred map from ((type, state_key)->Event)
+
+        Args:
+            room_id(str): room for which to get state
+            stream_position(StreamToken): point at which to get state
+
+        Returns:
+            A Deferred map from ((type, state_key)->Event)
         """
         last_events, token = yield self.store.get_recent_events_for_room(
             room_id, end_token=stream_position.room_key, limit=1,

@@ -859,15 +873,18 @@
         """ Works out the difference in state between the start of the timeline
         and the previous sync.

-        :param str room_id
-        :param TimelineBatch batch: The timeline batch for the room that will
-            be sent to the user.
-        :param sync_config
-        :param str since_token: Token of the end of the previous batch. May be None.
-        :param str now_token: Token of the end of the current batch.
-        :param bool full_state: Whether to force returning the full state.
+        Args:
+            room_id(str):
+            batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
+                the room that will be sent to the user.
+            sync_config(synapse.handlers.sync.SyncConfig):
+            since_token(str|None): Token of the end of the previous batch. May
+                be None.
+            now_token(str): Token of the end of the current batch.
+            full_state(bool): Whether to force returning the full state.

-        :returns A new event dictionary
+        Returns:
+            A deferred new event dictionary
         """
         # TODO(mjark) Check if the state events were received by the server
         # after the previous sync, since we need to include those state

@@ -939,11 +956,13 @@
         Check if the user has just joined the given room (so should
         be given the full state)

-        :param sync_config:
-        :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the
-            difference in state since the last sync
+        Args:
+            sync_config(synapse.handlers.sync.SyncConfig):
+            state_delta(dict[(str,str), synapse.events.FrozenEvent]): the
+                difference in state since the last sync

-        :returns A deferred Tuple (state_delta, limited)
+        Returns:
+            A deferred Tuple (state_delta, limited)
         """
         join_event = state_delta.get((
             EventTypes.Member, sync_config.user.to_string()), None)
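
Editor's note: concurrently_execute(func, args, limit), adopted in the hunk
above (and in the room-list handler earlier), runs func over args with a
bounded number of calls in flight. A plausible sketch of its semantics (an
assumption; the real helper lives in synapse/util/async.py):

    from twisted.internet import defer

    def concurrently_execute(func, args, limit):
        """Run func(arg) for every arg, at most `limit` at a time."""
        it = iter(args)

        @defer.inlineCallbacks
        def _worker():
            for arg in it:  # workers pull from a shared iterator
                yield func(arg)

        return defer.gatherResults(
            [_worker() for _ in xrange(limit)],
            consumeErrors=True,
        )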


@@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
 import collections
 import logging
 import random
+import time

 logger = logging.getLogger(__name__)

@@ -31,7 +32,7 @@ SERVER_CACHE = {}
 _Server = collections.namedtuple(
-    "_Server", "priority weight host port"
+    "_Server", "priority weight host port expires"
 )

@@ -92,7 +93,8 @@
                 host=domain,
                 port=default_port,
                 priority=0,
-                weight=0
+                weight=0,
+                expires=0,
             )
         else:
             self.default_server = None

@@ -153,7 +155,13 @@
 @defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
+def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
+    cache_entry = cache.get(service_name, None)
+    if cache_entry:
+        if all(s.expires > int(clock.time()) for s in cache_entry):
+            servers = list(cache_entry)
+            defer.returnValue(servers)
+
     servers = []

     try:

@@ -173,27 +181,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
                 continue
             payload = answer.payload
             host = str(payload.target)
+            srv_ttl = answer.ttl

             try:
                 answers, _, _ = yield dns_client.lookupAddress(host)
             except DNSNameError:
                 continue

-            ips = [
-                answer.payload.dottedQuad()
-                for answer in answers
-                if answer.type == dns.A and answer.payload
-            ]
-
-            for ip in ips:
-                servers.append(_Server(
-                    host=ip,
-                    port=int(payload.port),
-                    priority=int(payload.priority),
-                    weight=int(payload.weight)
-                ))
+            for answer in answers:
+                if answer.type == dns.A and answer.payload:
+                    ip = answer.payload.dottedQuad()
+                    host_ttl = min(srv_ttl, answer.ttl)
+
+                    servers.append(_Server(
+                        host=ip,
+                        port=int(payload.port),
+                        priority=int(payload.priority),
+                        weight=int(payload.weight),
+                        expires=int(clock.time()) + host_ttl,
+                    ))

     servers.sort()
     cache[service_name] = list(servers)
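
Editor's note: with the hunk above, each cached _Server carries an absolute
expiry (now plus the smaller of the SRV and A record TTLs), and a cached
answer is only served while every entry is unexpired. Illustrative behaviour
(the service name is hypothetical):

    servers = yield resolve_service("_matrix._tcp.example.com")  # queries DNS
    servers = yield resolve_service("_matrix._tcp.example.com")  # served from
                                        # SERVER_CACHE until a TTL lapses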


@@ -26,14 +26,19 @@ logger = logging.getLogger(__name__)
 def parse_integer(request, name, default=None, required=False):
     """Parse an integer parameter from the request string

-    :param request: the twisted HTTP request.
-    :param name (str): the name of the query parameter.
-    :param default: value to use if the parameter is absent, defaults to None.
-    :param required (bool): whether to raise a 400 SynapseError if the
-        parameter is absent, defaults to False.
-    :return: An int value or the default.
-    :raises
-        SynapseError if the parameter is absent and required, or if the
+    Args:
+        request: the twisted HTTP request.
+        name (str): the name of the query parameter.
+        default (int|None): value to use if the parameter is absent, defaults
+            to None.
+        required (bool): whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+
+    Returns:
+        int|None: An int value or the default.
+
+    Raises:
+        SynapseError: if the parameter is absent and required, or if the
             parameter is present and not an integer.
     """
     if name in request.args:

@@ -53,14 +58,19 @@ def parse_integer(request, name, default=None, required=False):
 def parse_boolean(request, name, default=None, required=False):
     """Parse a boolean parameter from the request query string

-    :param request: the twisted HTTP request.
-    :param name (str): the name of the query parameter.
-    :param default: value to use if the parameter is absent, defaults to None.
-    :param required (bool): whether to raise a 400 SynapseError if the
-        parameter is absent, defaults to False.
-    :return: A bool value or the default.
-    :raises
-        SynapseError if the parameter is absent and required, or if the
+    Args:
+        request: the twisted HTTP request.
+        name (str): the name of the query parameter.
+        default (bool|None): value to use if the parameter is absent, defaults
+            to None.
+        required (bool): whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+
+    Returns:
+        bool|None: A bool value or the default.
+
+    Raises:
+        SynapseError: if the parameter is absent and required, or if the
             parameter is present and not one of "true" or "false".
     """

@@ -88,15 +98,20 @@
                  allowed_values=None, param_type="string"):
     """Parse a string parameter from the request query string.

-    :param request: the twisted HTTP request.
-    :param name (str): the name of the query parameter.
-    :param default: value to use if the parameter is absent, defaults to None.
-    :param required (bool): whether to raise a 400 SynapseError if the
-        parameter is absent, defaults to False.
-    :param allowed_values (list): List of allowed values for the string,
-        or None if any value is allowed, defaults to None
-    :return: A string value or the default.
-    :raises
+    Args:
+        request: the twisted HTTP request.
+        name (str): the name of the query parameter.
+        default (str|None): value to use if the parameter is absent, defaults
+            to None.
+        required (bool): whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        allowed_values (list[str]): List of allowed values for the string,
+            or None if any value is allowed, defaults to None
+
+    Returns:
+        str|None: A string value or the default.
+
+    Raises:
         SynapseError if the parameter is absent and required, or if the
         parameter is present, must be one of a list of allowed values and
         is not one of those allowed values.

@@ -122,9 +137,13 @@
 def parse_json_value_from_request(request):
     """Parse a JSON value from the body of a twisted HTTP request.

-    :param request: the twisted HTTP request.
-    :returns: The JSON value.
-    :raises
+    Args:
+        request: the twisted HTTP request.
+
+    Returns:
+        The JSON value.
+
+    Raises:
         SynapseError if the request body couldn't be decoded as JSON.
     """
     try:

@@ -143,8 +162,10 @@
 def parse_json_object_from_request(request):
     """Parse a JSON object from the body of a twisted HTTP request.

-    :param request: the twisted HTTP request.
-    :raises
+    Args:
+        request: the twisted HTTP request.
+
+    Raises:
         SynapseError if the request body couldn't be decoded as JSON or
         if it wasn't a JSON object.
     """


@@ -503,13 +503,14 @@ class Notifier(object):
     def wait_for_replication(self, callback, timeout):
         """Wait for an event to happen.

-        :param callback:
-            Gets called whenever an event happens. If this returns a truthy
-            value then ``wait_for_replication`` returns, otherwise it waits
-            for another event.
-        :param int timeout:
-            How many milliseconds to wait for callback to return a truthy value.
+        Args:
+            callback: Gets called whenever an event happens. If this returns a
+                truthy value then ``wait_for_replication`` returns, otherwise
+                it waits for another event.
+            timeout: How many milliseconds to wait for callback to return a
+                truthy value.

-        :returns:
+        Returns:
             A deferred that resolves with the value returned by the callback.
         """
         listener = _NotificationListener(None)


@@ -19,9 +19,11 @@ import copy
 def list_with_base_rules(rawrules):
     """Combine the list of rules set by the user with the default push rules

-    :param list rawrules: The rules the user has modified or set.
-    :returns: A new list with the rules set by the user combined with the
-        defaults.
+    Args:
+        rawrules(list): The rules the user has modified or set.
+
+    Returns:
+        A new list with the rules set by the user combined with the defaults.
     """
     ruleslist = []

@@ -160,7 +162,27 @@ BASE_APPEND_OVRRIDE_RULES = [
         'actions': [
             'dont_notify',
         ]
-    }
+    },
+    # Will we sometimes want to know about people joining and leaving?
+    # Perhaps: if so, this could be expanded upon. Seems the most usual case
+    # is that we don't though. We add this override rule so that even if
+    # the room rule is set to notify, we don't get notifications about
+    # join/leave/avatar/displayname events.
+    # See also: https://matrix.org/jira/browse/SYN-607
+    {
+        'rule_id': 'global/override/.m.rule.member_event',
+        'conditions': [
+            {
+                'kind': 'event_match',
+                'key': 'type',
+                'pattern': 'm.room.member',
+                '_id': '_member',
+            }
+        ],
+        'actions': [
+            'dont_notify'
+        ]
+    },
 ]

@@ -261,25 +283,6 @@
             }
         ]
     },
-    # This is too simple: https://matrix.org/jira/browse/SYN-607
-    # Removing for now
-    # {
-    #     'rule_id': 'global/underride/.m.rule.member_event',
-    #     'conditions': [
-    #         {
-    #             'kind': 'event_match',
-    #             'key': 'type',
-    #             'pattern': 'm.room.member',
-    #             '_id': '_member',
-    #         }
-    #     ],
-    #     'actions': [
-    #         'notify', {
-    #             'set_tweak': 'highlight',
-    #             'value': False
-    #         }
-    #     ]
-    # },
     {
         'rule_id': 'global/underride/.m.rule.message',
         'conditions': [


@@ -133,8 +133,9 @@ class PushRuleEvaluator:
             enabled = self.enabled_map.get(r['rule_id'], None)
             if enabled is not None and not enabled:
                 continue
-
-            if not r.get("enabled", True):
+            elif enabled is None and not r.get("enabled", True):
+                # if no override, check enabled on the rule itself
+                # (may have come from a base rule)
                 continue

             conditions = r['conditions']
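
Editor's note: the net effect of the new branch is a two-level enablement
check, roughly:

    # override in enabled_map | rule's own "enabled" | outcome
    # ------------------------+----------------------+----------
    # False                   | (ignored)            | skipped
    # True                    | (ignored)            | evaluated
    # absent (None)           | False                | skipped
    # absent (None)           | True or missing      | evaluated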


@@ -36,6 +36,7 @@ REQUIREMENTS = {
     "sortedcontainers": ["sortedcontainers"],
     "pysaml2==4.0.3": ["saml2==4.0.3"],
     "pymacaroons-pynacl": ["pymacaroons"],
+    "pyjwt": ["jwt"],
 }

 CONDITIONAL_REQUIREMENTS = {
     "web_client": {


@ -38,6 +38,7 @@ STREAM_NAMES = (
("backfill",), ("backfill",),
("push_rules",), ("push_rules",),
("pushers",), ("pushers",),
("state",),
) )
@ -76,7 +77,7 @@ class ReplicationResource(Resource):
The response is a JSON object with keys for each stream with updates. Under The response is a JSON object with keys for each stream with updates. Under
each key is a JSON object with: each key is a JSON object with:
* "postion": The current position of the stream. * "position": The current position of the stream.
* "field_names": The names of the fields in each row. * "field_names": The names of the fields in each row.
* "rows": The updates as an array of arrays. * "rows": The updates as an array of arrays.
@ -123,6 +124,7 @@ class ReplicationResource(Resource):
backfill_token = yield self.store.get_current_backfill_token() backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -133,6 +135,7 @@ class ReplicationResource(Resource):
backfill_token, backfill_token,
push_rules_token, push_rules_token,
pushers_token, pushers_token,
state_token,
)) ))
@request_handler @request_handler
@ -142,31 +145,43 @@ class ReplicationResource(Resource):
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Type", b"application/json")
writer = _Writer(request)
@defer.inlineCallbacks request_streams = {
name: parse_integer(request, name)
for names in STREAM_NAMES for name in names
}
request_streams["streams"] = parse_string(request, "streams")
def replicate(): def replicate():
current_token = yield self.current_replication_token() return self.replicate(request_streams, limit)
logger.info("Replicating up to %r", current_token)
yield self.account_data(writer, current_token, limit) result = yield self.notifier.wait_for_replication(replicate, timeout)
yield self.events(writer, current_token, limit)
yield self.presence(writer, current_token) # TODO: implement limit
yield self.typing(writer, current_token) # TODO: implement limit
yield self.receipts(writer, current_token, limit)
yield self.push_rules(writer, current_token, limit)
yield self.pushers(writer, current_token, limit)
self.streams(writer, current_token)
logger.info("Replicated %d rows", writer.total) request.write(json.dumps(result, ensure_ascii=False))
defer.returnValue(writer.total) finish_request(request)
yield self.notifier.wait_for_replication(replicate, timeout) @defer.inlineCallbacks
def replicate(self, request_streams, limit):
writer = _Writer()
current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token)
writer.finish() yield self.account_data(writer, current_token, limit, request_streams)
yield self.events(writer, current_token, limit, request_streams)
# TODO: implement limit
yield self.presence(writer, current_token, request_streams)
yield self.typing(writer, current_token, request_streams)
yield self.receipts(writer, current_token, limit, request_streams)
yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
def streams(self, writer, current_token): logger.info("Replicated %d rows", writer.total)
request_token = parse_string(writer.request, "streams") defer.returnValue(writer.finish())
def streams(self, writer, current_token, request_streams):
request_token = request_streams.get("streams")
streams = [] streams = []
@ -191,32 +206,43 @@ class ReplicationResource(Resource):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def events(self, writer, current_token, limit): def events(self, writer, current_token, limit, request_streams):
request_events = parse_integer(writer.request, "events") request_events = request_streams.get("events")
request_backfill = parse_integer(writer.request, "backfill") request_backfill = request_streams.get("backfill")
if request_events is not None or request_backfill is not None: if request_events is not None or request_backfill is not None:
if request_events is None: if request_events is None:
request_events = current_token.events request_events = current_token.events
if request_backfill is None: if request_backfill is None:
request_backfill = current_token.backfill request_backfill = current_token.backfill
events_rows, backfill_rows = yield self.store.get_all_new_events( res = yield self.store.get_all_new_events(
request_backfill, request_events, request_backfill, request_events,
current_token.backfill, current_token.events, current_token.backfill, current_token.events,
limit limit
) )
writer.write_header_and_rows("events", res.new_forward_events, (
"position", "internal", "json", "state_group"
))
writer.write_header_and_rows("backfill", res.new_backfill_events, (
"position", "internal", "json", "state_group"
))
writer.write_header_and_rows( writer.write_header_and_rows(
"events", events_rows, ("position", "internal", "json") "forward_ex_outliers", res.forward_ex_outliers,
("position", "event_id", "state_group")
) )
writer.write_header_and_rows( writer.write_header_and_rows(
"backfill", backfill_rows, ("position", "internal", "json") "backward_ex_outliers", res.backward_ex_outliers,
("position", "event_id", "state_group")
)
writer.write_header_and_rows(
"state_resets", res.state_resets, ("position",)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
-    def presence(self, writer, current_token):
+    def presence(self, writer, current_token, request_streams):
         current_position = current_token.presence

-        request_presence = parse_integer(writer.request, "presence")
+        request_presence = request_streams.get("presence")

         if request_presence is not None:
             presence_rows = yield self.presence_handler.get_all_presence_updates(
@@ -229,10 +255,10 @@ class ReplicationResource(Resource):
             ))

     @defer.inlineCallbacks
-    def typing(self, writer, current_token):
+    def typing(self, writer, current_token, request_streams):
         current_position = current_token.presence

-        request_typing = parse_integer(writer.request, "typing")
+        request_typing = request_streams.get("typing")

         if request_typing is not None:
             typing_rows = yield self.typing_handler.get_all_typing_updates(
@@ -243,10 +269,10 @@ class ReplicationResource(Resource):
             ))

     @defer.inlineCallbacks
-    def receipts(self, writer, current_token, limit):
+    def receipts(self, writer, current_token, limit, request_streams):
         current_position = current_token.receipts

-        request_receipts = parse_integer(writer.request, "receipts")
+        request_receipts = request_streams.get("receipts")

         if request_receipts is not None:
             receipts_rows = yield self.store.get_all_updated_receipts(
@@ -257,12 +283,12 @@ class ReplicationResource(Resource):
             ))

     @defer.inlineCallbacks
-    def account_data(self, writer, current_token, limit):
+    def account_data(self, writer, current_token, limit, request_streams):
         current_position = current_token.account_data

-        user_account_data = parse_integer(writer.request, "user_account_data")
-        room_account_data = parse_integer(writer.request, "room_account_data")
-        tag_account_data = parse_integer(writer.request, "tag_account_data")
+        user_account_data = request_streams.get("user_account_data")
+        room_account_data = request_streams.get("room_account_data")
+        tag_account_data = request_streams.get("tag_account_data")

         if user_account_data is not None or room_account_data is not None:
             if user_account_data is None:
@@ -288,10 +314,10 @@ class ReplicationResource(Resource):
             ))

     @defer.inlineCallbacks
-    def push_rules(self, writer, current_token, limit):
+    def push_rules(self, writer, current_token, limit, request_streams):
         current_position = current_token.push_rules

-        push_rules = parse_integer(writer.request, "push_rules")
+        push_rules = request_streams.get("push_rules")

         if push_rules is not None:
             rows = yield self.store.get_all_push_rule_updates(
@@ -303,10 +329,11 @@ class ReplicationResource(Resource):
             ))

     @defer.inlineCallbacks
-    def pushers(self, writer, current_token, limit):
+    def pushers(self, writer, current_token, limit, request_streams):
         current_position = current_token.pushers

-        pushers = parse_integer(writer.request, "pushers")
+        pushers = request_streams.get("pushers")

         if pushers is not None:
             updated, deleted = yield self.store.get_all_updated_pushers(
                 pushers, current_position, limit
@@ -320,12 +347,30 @@ class ReplicationResource(Resource):
                 "position", "user_id", "app_id", "pushkey"
             ))

+    @defer.inlineCallbacks
+    def state(self, writer, current_token, limit, request_streams):
+        current_position = current_token.state
+
+        state = request_streams.get("state")
+
+        if state is not None:
+            state_groups, state_group_state = (
+                yield self.store.get_all_new_state_groups(
+                    state, current_position, limit
+                )
+            )
+            writer.write_header_and_rows("state_groups", state_groups, (
+                "position", "room_id", "event_id"
+            ))
+            writer.write_header_and_rows("state_group_state", state_group_state, (
+                "position", "type", "state_key", "event_id"
+            ))
 class _Writer(object):
     """Writes the streams as a JSON object as the response to the request"""
-    def __init__(self, request):
+    def __init__(self):
         self.streams = {}
-        self.request = request
         self.total = 0

     def write_header_and_rows(self, name, rows, fields, position=None):
@@ -344,13 +389,12 @@ class _Writer(object):
         self.total += len(rows)

     def finish(self):
-        self.request.write(json.dumps(self.streams, ensure_ascii=False))
-        finish_request(self.request)
+        return self.streams


 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
     "events", "presence", "typing", "receipts", "account_data", "backfill",
-    "push_rules", "pushers"
+    "push_rules", "pushers", "state"
 ))):
     __slots__ = []
View file
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
View file
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
View file
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage._base import SQLBaseStore
from twisted.internet import defer
class BaseSlavedStore(SQLBaseStore):
    def __init__(self, db_conn, hs):
        super(BaseSlavedStore, self).__init__(hs)

    def stream_positions(self):
        return {}

    def process_replication(self, result):
        return defer.succeed(None)
View file
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker(object):
    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
        self.step = step
        self._current = _load_current_id(db_conn, table, column, step)
        for table, column in extra_tables:
            self.advance(_load_current_id(db_conn, table, column))

    def advance(self, new_id):
        self._current = (max if self.step > 0 else min)(self._current, new_id)

    def get_current_token(self):
        return self._current
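A short usage sketch, given an open db_conn (the table and column are the same ones SlavedEventStore uses below; the advance() call would normally be driven by replication traffic):

tracker = SlavedIdTracker(db_conn, "events", "stream_ordering")
print tracker.get_current_token()  # position loaded from the events table
tracker.advance(1234)              # called when a replication batch arrives
print tracker.get_current_token()  # now at least 1234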
View file
@@ -0,0 +1,204 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.storage import DataStore
from synapse.storage.room import RoomStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.state import StateStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
#
# Rather than write duplicate versions of those functions, or lift them to
# a common base class, we're going to grab the underlying __func__ object from
# the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(BaseSlavedStore):

    def __init__(self, db_conn, hs):
        super(SlavedEventStore, self).__init__(db_conn, hs)
        self._stream_id_gen = SlavedIdTracker(
            db_conn, "events", "stream_ordering",
        )
        self._backfill_id_gen = SlavedIdTracker(
            db_conn, "events", "stream_ordering", step=-1
        )
        events_max = self._stream_id_gen.get_current_token()
        event_cache_prefill, min_event_val = self._get_cache_dict(
            db_conn, "events",
            entity_column="room_id",
            stream_column="stream_ordering",
            max_value=events_max,
        )
        self._events_stream_cache = StreamChangeCache(
            "EventsRoomStreamChangeCache", min_event_val,
            prefilled_cache=event_cache_prefill,
        )

    # Cached functions can't be accessed through a class instance so we need
    # to reach inside the __dict__ to extract them.
    get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
    get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
    get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
    get_latest_event_ids_in_room = EventFederationStore.__dict__[
        "get_latest_event_ids_in_room"
    ]
    _get_current_state_for_key = StateStore.__dict__[
        "_get_current_state_for_key"
    ]

    get_event = DataStore.get_event.__func__
    get_current_state = DataStore.get_current_state.__func__
    get_current_state_for_key = DataStore.get_current_state_for_key.__func__
    get_rooms_for_user_where_membership_is = (
        DataStore.get_rooms_for_user_where_membership_is.__func__
    )
    get_membership_changes_for_user = (
        DataStore.get_membership_changes_for_user.__func__
    )
    get_room_events_max_id = DataStore.get_room_events_max_id.__func__
    get_room_events_stream_for_room = (
        DataStore.get_room_events_stream_for_room.__func__
    )

    _set_before_and_after = DataStore._set_before_and_after

    _get_events = DataStore._get_events.__func__
    _get_events_from_cache = DataStore._get_events_from_cache.__func__

    _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
    _parse_events_txn = DataStore._parse_events_txn.__func__
    _get_events_txn = DataStore._get_events_txn.__func__
    _enqueue_events = DataStore._enqueue_events.__func__
    _do_fetch = DataStore._do_fetch.__func__
    _fetch_events_txn = DataStore._fetch_events_txn.__func__
    _fetch_event_rows = DataStore._fetch_event_rows.__func__
    _get_event_from_row = DataStore._get_event_from_row.__func__
    _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__
    _get_rooms_for_user_where_membership_is_txn = (
        DataStore._get_rooms_for_user_where_membership_is_txn.__func__
    )
    _get_members_rows_txn = DataStore._get_members_rows_txn.__func__

    def stream_positions(self):
        result = super(SlavedEventStore, self).stream_positions()
        result["events"] = self._stream_id_gen.get_current_token()
        result["backfill"] = self._backfill_id_gen.get_current_token()
        return result

    def process_replication(self, result):
        state_resets = set(
            r[0] for r in result.get("state_resets", {"rows": []})["rows"]
        )

        stream = result.get("events")
        if stream:
            self._stream_id_gen.advance(stream["position"])
            for row in stream["rows"]:
                self._process_replication_row(
                    row, backfilled=False, state_resets=state_resets
                )

        stream = result.get("backfill")
        if stream:
            self._backfill_id_gen.advance(stream["position"])
            for row in stream["rows"]:
                self._process_replication_row(
                    row, backfilled=True, state_resets=state_resets
                )

        stream = result.get("forward_ex_outliers")
        if stream:
            for row in stream["rows"]:
                event_id = row[1]
                self._invalidate_get_event_cache(event_id)

        stream = result.get("backward_ex_outliers")
        if stream:
            for row in stream["rows"]:
                event_id = row[1]
                self._invalidate_get_event_cache(event_id)

        return super(SlavedEventStore, self).process_replication(result)

    def _process_replication_row(self, row, backfilled, state_resets):
        position = row[0]
        internal = json.loads(row[1])
        event_json = json.loads(row[2])

        event = FrozenEvent(event_json, internal_metadata_dict=internal)
        self._invalidate_caches_for_event(
            event, backfilled, reset_state=position in state_resets
        )

    def _invalidate_caches_for_event(self, event, backfilled, reset_state):
        if reset_state:
            self._get_current_state_for_key.invalidate_all()
            self.get_rooms_for_user.invalidate_all()
            self.get_users_in_room.invalidate((event.room_id,))
            # self.get_joined_hosts_for_room.invalidate((event.room_id,))
            self.get_room_name_and_aliases.invalidate((event.room_id,))

        self._invalidate_get_event_cache(event.event_id)

        self.get_latest_event_ids_in_room.invalidate((event.room_id,))

        if not backfilled:
            self._events_stream_cache.entity_has_changed(
                event.room_id, event.internal_metadata.stream_ordering
            )

        # self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
        #     (event.room_id,)
        # )

        if event.type == EventTypes.Redaction:
            self._invalidate_get_event_cache(event.redacts)

        if event.type == EventTypes.Member:
            self.get_rooms_for_user.invalidate((event.state_key,))
            # self.get_joined_hosts_for_room.invalidate((event.room_id,))
            self.get_users_in_room.invalidate((event.room_id,))
            # self._membership_stream_cache.entity_has_changed(
            #     event.state_key, event.internal_metadata.stream_ordering
            # )

        if not event.is_state():
            return

        if backfilled:
            return

        if (not event.internal_metadata.is_invite_from_remote()
                and event.internal_metadata.is_outlier()):
            return

        self._get_current_state_for_key.invalidate((
            event.room_id, event.type, event.state_key
        ))

        if event.type in [EventTypes.Name, EventTypes.Aliases]:
            self.get_room_name_and_aliases.invalidate(
                (event.room_id,)
            )
        pass
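The __func__ trick above, reduced to a standalone example (illustrative classes, not synapse code): on Python 2, ClassName.method is an unbound method, and its underlying plain function can be pulled out via __func__ and reattached to an unrelated class, provided that class supplies the attributes the function touches.

class Reader(object):
    def describe(self):
        return "value is %r" % (self.value,)

class Borrower(object):
    def __init__(self, value):
        self.value = value
    # Reuse Reader.describe without inheriting from Reader: __func__ has
    # no class check, so it binds happily to Borrower instances.
    describe = Reader.describe.__func__

print Borrower(42).describe()  # prints: value is 42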
View file
@@ -33,6 +33,9 @@ from saml2.client import Saml2Client

 import xml.etree.ElementTree as ET

+import jwt
+from jwt.exceptions import InvalidTokenError
+
 logger = logging.getLogger(__name__)
@@ -43,12 +46,16 @@ class LoginRestServlet(ClientV1RestServlet):
     SAML2_TYPE = "m.login.saml2"
     CAS_TYPE = "m.login.cas"
     TOKEN_TYPE = "m.login.token"
+    JWT_TYPE = "m.login.jwt"

     def __init__(self, hs):
         super(LoginRestServlet, self).__init__(hs)
         self.idp_redirect_url = hs.config.saml2_idp_redirect_url
         self.password_enabled = hs.config.password_enabled
         self.saml2_enabled = hs.config.saml2_enabled
+        self.jwt_enabled = hs.config.jwt_enabled
+        self.jwt_secret = hs.config.jwt_secret
+        self.jwt_algorithm = hs.config.jwt_algorithm
         self.cas_enabled = hs.config.cas_enabled
         self.cas_server_url = hs.config.cas_server_url
         self.cas_required_attributes = hs.config.cas_required_attributes
@@ -57,6 +64,8 @@ class LoginRestServlet(ClientV1RestServlet):
     def on_GET(self, request):
         flows = []
+        if self.jwt_enabled:
+            flows.append({"type": LoginRestServlet.JWT_TYPE})
         if self.saml2_enabled:
             flows.append({"type": LoginRestServlet.SAML2_TYPE})
         if self.cas_enabled:
@@ -98,6 +107,10 @@ class LoginRestServlet(ClientV1RestServlet):
                 "uri": "%s%s" % (self.idp_redirect_url, relay_state)
             }
             defer.returnValue((200, result))
+        elif self.jwt_enabled and (login_submission["type"] ==
+                                   LoginRestServlet.JWT_TYPE):
+            result = yield self.do_jwt_login(login_submission)
+            defer.returnValue(result)
         # TODO Delete this after all CAS clients switch to token login instead
         elif self.cas_enabled and (login_submission["type"] ==
                                    LoginRestServlet.CAS_TYPE):
@@ -209,6 +222,46 @@ class LoginRestServlet(ClientV1RestServlet):
             defer.returnValue((200, result))

+    @defer.inlineCallbacks
+    def do_jwt_login(self, login_submission):
+        token = login_submission['token']
+        if token is None:
+            raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+        try:
+            payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
+        except InvalidTokenError:
+            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+
+        user = payload['user']
+        if user is None:
+            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+
+        user_id = UserID.create(user, self.hs.hostname).to_string()
+        auth_handler = self.handlers.auth_handler
+        user_exists = yield auth_handler.does_user_exist(user_id)
+        if user_exists:
+            user_id, access_token, refresh_token = (
+                yield auth_handler.get_login_tuple_for_user_id(user_id)
+            )
+            result = {
+                "user_id": user_id,  # may have changed
+                "access_token": access_token,
+                "refresh_token": refresh_token,
+                "home_server": self.hs.hostname,
+            }
+        else:
+            user_id, access_token = (
+                yield self.handlers.registration_handler.register(localpart=user)
+            )
+            result = {
+                "user_id": user_id,  # may have changed
+                "access_token": access_token,
+                "home_server": self.hs.hostname,
+            }
+
+        defer.returnValue((200, result))
+
     # TODO Delete this after all CAS clients switch to token login instead
     def parse_cas_response(self, cas_response_body):
         root = ET.fromstring(cas_response_body)
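For reference, a sketch of the matching token-minting side using the same PyJWT library imported above. The secret, algorithm and localpart are placeholders; they must agree with jwt_secret and jwt_algorithm in the homeserver config, and do_jwt_login reads the localpart from the "user" claim:

import jwt

token = jwt.encode({"user": "alice"}, "shared-secret", algorithm="HS256")
login_submission = {"type": "m.login.jwt", "token": token}
# POSTing login_submission to the v1 /login endpoint then logs in
# (or registers) @alice:<server_name>.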
View file
@@ -405,6 +405,42 @@ class RoomEventContext(ClientV1RestServlet):
         defer.returnValue((200, results))


+class RoomForgetRestServlet(ClientV1RestServlet):
+    def register(self, http_server):
+        PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
+        register_txn_path(self, PATTERNS, http_server)
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, room_id, txn_id=None):
+        requester = yield self.auth.get_user_by_req(
+            request,
+            allow_guest=False,
+        )
+
+        yield self.handlers.room_member_handler.forget(
+            user=requester.user,
+            room_id=room_id,
+        )
+
+        defer.returnValue((200, {}))
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, room_id, txn_id):
+        try:
+            defer.returnValue(
+                self.txns.get_client_transaction(request, txn_id)
+            )
+        except KeyError:
+            pass
+
+        response = yield self.on_POST(
+            request, room_id, txn_id
+        )
+
+        self.txns.store_client_transaction(request, txn_id, response)
+        defer.returnValue(response)
+
+
 # TODO: Needs unit testing
 class RoomMembershipRestServlet(ClientV1RestServlet):
@@ -624,6 +660,7 @@ def register_servlets(hs, http_server):
     RoomMemberListRestServlet(hs).register(http_server)
     RoomMessageListRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
+    RoomForgetRestServlet(hs).register(http_server)
     RoomMembershipRestServlet(hs).register(http_server)
     RoomSendEventRestServlet(hs).register(http_server)
     PublicRoomListRestServlet(hs).register(http_server)
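A sketch of exercising the new endpoint from a client (the homeserver URL and access token are placeholders; under Matrix semantics the user would normally have left the room before forgetting it):

import json
import urllib
import urllib2

def forget_room(base_url, access_token, room_id):
    url = "%s/_matrix/client/api/v1/rooms/%s/forget?access_token=%s" % (
        base_url, urllib.quote(room_id), urllib.quote(access_token)
    )
    # A non-None data body makes urllib2 issue a POST.
    req = urllib2.Request(url, data=json.dumps({}))
    return json.loads(urllib2.urlopen(req).read())  # {} on success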
View file
@@ -115,6 +115,8 @@ class SyncRestServlet(RestServlet):
             )
         )

+        request_key = (user, timeout, since, filter_id, full_state)
+
         if filter_id:
             if filter_id.startswith('{'):
                 try:
@@ -134,6 +136,7 @@ class SyncRestServlet(RestServlet):
             user=user,
             filter_collection=filter,
             is_guest=requester.is_guest,
+            request_key=request_key,
         )

         if since is not None:
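The point of request_key: two identical /sync long-polls (same user, timeout, since token, filter and full_state flag) can share one piece of work. A toy model of that idea (a real implementation would wrap the deferred in something like an ObservableDeferred so several callers can each receive the result; this sketch only shows the keying):

pending = {}

def get_or_compute(request_key, compute):
    # compute() returns a Twisted Deferred; callers arriving with the same
    # key while it is in flight share the result instead of recomputing it.
    if request_key not in pending:
        d = compute()
        pending[request_key] = d

        def remove(res):
            pending.pop(request_key, None)
            return res
        d.addBoth(remove)
    return pending[request_key]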
@@ -196,15 +199,17 @@ class SyncRestServlet(RestServlet):
         """
         Encode the joined rooms in a sync result

-        :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
-            results for rooms this user is joined to
-        :param int time_now: current time - used as a baseline for age
-            calculations
-        :param int token_id: ID of the user's auth token - used for namespacing
-            of transaction IDs
-
-        :return: the joined rooms list, in our response format
-        :rtype: dict[str, dict[str, object]]
+        Args:
+            rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
+                results for rooms this user is joined to
+            time_now(int): current time - used as a baseline for age
+                calculations
+            token_id(int): ID of the user's auth token - used for namespacing
+                of transaction IDs
+
+        Returns:
+            dict[str, dict[str, object]]: the joined rooms list, in our
+                response format
         """
         joined = {}
         for room in rooms:
@@ -218,15 +223,17 @@ class SyncRestServlet(RestServlet):
         """
         Encode the invited rooms in a sync result

-        :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
-            sync results for rooms this user is joined to
-        :param int time_now: current time - used as a baseline for age
-            calculations
-        :param int token_id: ID of the user's auth token - used for namespacing
-            of transaction IDs
-
-        :return: the invited rooms list, in our response format
-        :rtype: dict[str, dict[str, object]]
+        Args:
+            rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
+                sync results for rooms this user is joined to
+            time_now(int): current time - used as a baseline for age
+                calculations
+            token_id(int): ID of the user's auth token - used for namespacing
+                of transaction IDs
+
+        Returns:
+            dict[str, dict[str, object]]: the invited rooms list, in our
+                response format
         """
         invited = {}
         for room in rooms:
@@ -248,15 +255,17 @@ class SyncRestServlet(RestServlet):
         """
         Encode the archived rooms in a sync result

-        :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
-            sync results for rooms this user is joined to
-        :param int time_now: current time - used as a baseline for age
-            calculations
-        :param int token_id: ID of the user's auth token - used for namespacing
-            of transaction IDs
-
-        :return: the invited rooms list, in our response format
-        :rtype: dict[str, dict[str, object]]
+        Args:
+            rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
+                sync results for rooms this user is joined to
+            time_now(int): current time - used as a baseline for age
+                calculations
+            token_id(int): ID of the user's auth token - used for namespacing
+                of transaction IDs
+
+        Returns:
+            dict[str, dict[str, object]]: The invited rooms list, in our
+                response format
         """
         joined = {}
         for room in rooms:
@@ -269,17 +278,18 @@ class SyncRestServlet(RestServlet):
     @staticmethod
     def encode_room(room, time_now, token_id, joined=True):
         """
-        :param JoinedSyncResult|ArchivedSyncResult room: sync result for a
-            single room
-        :param int time_now: current time - used as a baseline for age
-            calculations
-        :param int token_id: ID of the user's auth token - used for namespacing
-            of transaction IDs
-        :param joined: True if the user is joined to this room - will mean
-            we handle ephemeral events
-
-        :return: the room, encoded in our response format
-        :rtype: dict[str, object]
+        Args:
+            room (JoinedSyncResult|ArchivedSyncResult): sync result for a
+                single room
+            time_now (int): current time - used as a baseline for age
+                calculations
+            token_id (int): ID of the user's auth token - used for namespacing
+                of transaction IDs
+            joined (bool): True if the user is joined to this room - will mean
+                we handle ephemeral events
+
+        Returns:
+            dict[str, object]: the room, encoded in our response format
         """
         def serialize(event):
             # TODO(mjark): Respect formatting requirements in the filter.
View file
@@ -75,7 +75,8 @@ class StateHandler(object):
         self._state_cache.start()

     @defer.inlineCallbacks
-    def get_current_state(self, room_id, event_type=None, state_key=""):
+    def get_current_state(self, room_id, event_type=None, state_key="",
+                          latest_event_ids=None):
         """ Retrieves the current state for the room. This is done by
         calling `get_latest_events_in_room` to get the leading edges of the
         event graph and then resolving any of the state conflicts.
@@ -86,11 +87,13 @@ class StateHandler(object):
         If `event_type` is specified, then the method returns only the one
         event (or None) with that `event_type` and `state_key`.

-        :returns map from (type, state_key) to event
+        Returns:
+            map from (type, state_key) to event
         """
-        event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+        if not latest_event_ids:
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)

-        res = yield self.resolve_state_groups(room_id, event_ids)
+        res = yield self.resolve_state_groups(room_id, latest_event_ids)
         state = res[1]

         if event_type:
@@ -100,7 +103,7 @@ class StateHandler(object):
         defer.returnValue(state)

     @defer.inlineCallbacks
-    def compute_event_context(self, event, old_state=None, outlier=False):
+    def compute_event_context(self, event, old_state=None):
         """ Fills out the context with the `current state` of the graph. The
         `current state` here is defined to be the state of the event graph
         just before the event - i.e. it never includes `event`
@@ -115,7 +118,7 @@ class StateHandler(object):
         """
         context = EventContext()

-        if outlier:
+        if event.internal_metadata.is_outlier():
             # If this is an outlier, then we know it shouldn't have any current
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
@@ -176,10 +179,11 @@ class StateHandler(object):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.

-        :returns a Deferred tuple of (`state_group`, `state`, `prev_state`).
-        `state_group` is the name of a state group if one and only one is
-        involved. `state` is a map from (type, state_key) to event, and
-        `prev_state` is a list of event ids.
+        Returns:
+            a Deferred tuple of (`state_group`, `state`, `prev_state`).
+            `state_group` is the name of a state group if one and only one is
+            involved. `state` is a map from (type, state_key) to event, and
+            `prev_state` is a list of event ids.
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -251,9 +255,10 @@ class StateHandler(object):
     def _resolve_events(self, state_sets, event_type=None, state_key=""):
         """
-        :returns a tuple (new_state, prev_states). new_state is a map
-        from (type, state_key) to event. prev_states is a list of event_ids.
-        :rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str])
+        Returns
+            (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
+            (new_state, prev_states). new_state is a map from (type, state_key)
+            to event. prev_states is a list of event_ids.
         """
         with Measure(self.clock, "state._resolve_events"):
             state = {}
View file
@@ -88,22 +88,17 @@ class DataStore(RoomMemberStore, RoomStore,
         self.hs = hs
         self.database_engine = hs.database_engine

-        cur = db_conn.cursor()
-        try:
-            cur.execute("SELECT MIN(stream_ordering) FROM events",)
-            rows = cur.fetchall()
-            self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
-            self.min_stream_token = min(self.min_stream_token, -1)
-        finally:
-            cur.close()
-
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
         )

         self._stream_id_gen = StreamIdGenerator(
-            db_conn, "events", "stream_ordering"
+            db_conn, "events", "stream_ordering",
+            extra_tables=[("local_invites", "stream_id")]
+        )
+        self._backfill_id_gen = StreamIdGenerator(
+            db_conn, "events", "stream_ordering", step=-1
         )
         self._receipts_id_gen = StreamIdGenerator(
             db_conn, "receipts_linearized", "stream_id"
@@ -116,7 +111,7 @@ class DataStore(RoomMemberStore, RoomStore,
         )
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
-        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
+        self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
@@ -129,7 +124,7 @@ class DataStore(RoomMemberStore, RoomStore,
             extra_tables=[("deleted_pushers", "stream_id")],
         )

-        events_max = self._stream_id_gen.get_max_token()
+        events_max = self._stream_id_gen.get_current_token()
         event_cache_prefill, min_event_val = self._get_cache_dict(
             db_conn, "events",
             entity_column="room_id",
@@ -145,7 +140,7 @@ class DataStore(RoomMemberStore, RoomStore,
             "MembershipStreamChangeCache", events_max,
         )

-        account_max = self._account_data_id_gen.get_max_token()
+        account_max = self._account_data_id_gen.get_current_token()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max,
         )
@@ -156,7 +151,7 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "presence_stream",
             entity_column="user_id",
             stream_column="stream_id",
-            max_value=self._presence_id_gen.get_max_token(),
+            max_value=self._presence_id_gen.get_current_token(),
         )
         self.presence_stream_cache = StreamChangeCache(
             "PresenceStreamChangeCache", min_presence_val,
@@ -167,7 +162,7 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "push_rules_stream",
             entity_column="user_id",
             stream_column="stream_id",
-            max_value=self._push_rules_stream_id_gen.get_max_token()[0],
+            max_value=self._push_rules_stream_id_gen.get_current_token()[0],
         )
         self.push_rules_stream_cache = StreamChangeCache(
@@ -182,39 +177,6 @@ class DataStore(RoomMemberStore, RoomStore,
         self.__presence_on_startup = None
         return active_on_startup

-    def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
-        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
-        # It doesn't really matter how many we get, the StreamChangeCache will
-        # do the right thing to ensure it respects the max size of cache.
-        sql = (
-            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
-            " WHERE %(stream)s > ? - 100000"
-            " GROUP BY %(entity)s"
-        ) % {
-            "table": table,
-            "entity": entity_column,
-            "stream": stream_column,
-        }
-
-        sql = self.database_engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
-        txn.execute(sql, (int(max_value),))
-        rows = txn.fetchall()
-        txn.close()
-
-        cache = {
-            row[0]: int(row[1])
-            for row in rows
-        }
-
-        if cache:
-            min_val = min(cache.values())
-        else:
-            min_val = max_value
-
-        return cache, min_val
-
     def _get_active_presence(self, db_conn):
         """Fetch non-offline presence from the database so that we can register
         the appropriate time outs.
View file
@@ -810,11 +810,39 @@ class SQLBaseStore(object):
         return txn.execute(sql, keyvalues.values())

-    def get_next_stream_id(self):
-        with self._next_stream_id_lock:
-            i = self._next_stream_id
-            self._next_stream_id += 1
-            return i
+    def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
+                        max_value):
+        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
+        # It doesn't really matter how many we get, the StreamChangeCache will
+        # do the right thing to ensure it respects the max size of cache.
+        sql = (
+            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+            " WHERE %(stream)s > ? - 100000"
+            " GROUP BY %(entity)s"
+        ) % {
+            "table": table,
+            "entity": entity_column,
+            "stream": stream_column,
+        }
+
+        sql = self.database_engine.convert_param_style(sql)
+
+        txn = db_conn.cursor()
+        txn.execute(sql, (int(max_value),))
+        rows = txn.fetchall()
+        txn.close()
+
+        cache = {
+            row[0]: int(row[1])
+            for row in rows
+        }
+
+        if cache:
+            min_val = min(cache.values())
+        else:
+            min_val = max_value
+
+        return cache, min_val


 class _RollbackButIsFineException(Exception):
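A sketch of how the relocated helper is consumed, mirroring the store constructors elsewhere in this commit (build_events_cache is an illustrative wrapper, not synapse code):

from synapse.util.caches.stream_change_cache import StreamChangeCache

def build_events_cache(store, db_conn, events_max):
    # store: any SQLBaseStore; events_max: current top of the events stream.
    prefill, min_val = store._get_cache_dict(
        db_conn, "events",
        entity_column="room_id",
        stream_column="stream_ordering",
        max_value=events_max,
    )
    # Rooms absent from the prefill are only known to be unchanged since
    # min_val; "changed since?" queries below that must fall back to the DB.
    return StreamChangeCache(
        "EventsRoomStreamChangeCache", min_val,
        prefilled_cache=prefill,
    )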
View file
@@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore):
             "add_room_account_data", add_account_data_txn, next_id
         )

-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)

     @defer.inlineCallbacks
@@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore):
             "add_user_account_data", add_account_data_txn, next_id
         )

-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)

     def _update_max_stream_id(self, txn, next_id):
View file
@@ -26,13 +26,13 @@ SUPPORTED_MODULE = {
 }


-def create_engine(config):
-    name = config.database_config["name"]
+def create_engine(database_config):
+    name = database_config["name"]
     engine_class = SUPPORTED_MODULE.get(name, None)

     if engine_class:
         module = importlib.import_module(name)
-        return engine_class(module, config=config)
+        return engine_class(module)

     raise RuntimeError(
         "Unsupported database engine '%s'" % (name,)
View file
@@ -13,18 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.storage.prepare_database import prepare_database
-
 from ._base import IncorrectDatabaseSetup


 class PostgresEngine(object):
     single_threaded = False

-    def __init__(self, database_module, config):
+    def __init__(self, database_module):
         self.module = database_module
         self.module.extensions.register_type(self.module.extensions.UNICODE)
-        self.config = config

     def check_database(self, txn):
         txn.execute("SHOW SERVER_ENCODING")
@@ -44,9 +41,6 @@ class PostgresEngine(object):
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )

-    def prepare_database(self, db_conn):
-        prepare_database(db_conn, self, config=self.config)
-
     def is_deadlock(self, error):
         if isinstance(error, self.module.DatabaseError):
             return error.pgcode in ["40001", "40P01"]
View file
@@ -13,9 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.storage.prepare_database import (
-    prepare_database, prepare_sqlite3_database
-)
+from synapse.storage.prepare_database import prepare_database

 import struct

@@ -23,9 +21,8 @@ import struct
 class Sqlite3Engine(object):
     single_threaded = True

-    def __init__(self, database_module, config):
+    def __init__(self, database_module):
         self.module = database_module
-        self.config = config

     def check_database(self, txn):
         pass
@@ -34,13 +31,9 @@ class Sqlite3Engine(object):
         return sql

     def on_new_connection(self, db_conn):
-        self.prepare_database(db_conn)
+        prepare_database(db_conn, self, config=None)
         db_conn.create_function("rank", 1, _rank)

-    def prepare_database(self, db_conn):
-        prepare_sqlite3_database(db_conn)
-        prepare_database(db_conn, self, config=self.config)
-
     def is_deadlock(self, error):
         return False
View file
@@ -163,6 +163,22 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )

+    @defer.inlineCallbacks
+    def get_max_depth_of_events(self, event_ids):
+        sql = (
+            "SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
+        ) % (",".join(["?"] * len(event_ids)),)
+
+        rows = yield self._execute(
+            "get_max_depth_of_events", None,
+            sql, *event_ids
+        )
+
+        if rows:
+            defer.returnValue(rows[0][0])
+        else:
+            defer.returnValue(1)
+
     def _get_min_depth_interaction(self, txn, room_id):
         min_depth = self._simple_select_one_onecol_txn(
             txn,
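The IN (%s) expansion above, in isolation: SQL has no list-valued parameter, so one "?" placeholder is interpolated per value while the values themselves are still passed separately (and therefore escaped). A runnable sketch against sqlite3:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT, depth BIGINT)")
conn.executemany(
    "INSERT INTO events VALUES (?, ?)",
    [("$a", 1), ("$b", 5), ("$c", 3)],
)

event_ids = ["$a", "$c"]
sql = "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" % (
    ",".join(["?"] * len(event_ids)),
)
print conn.execute(sql, event_ids).fetchone()[0]  # prints: 3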
View file
@@ -26,8 +26,9 @@ logger = logging.getLogger(__name__)
 class EventPushActionsStore(SQLBaseStore):
     def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
         """
-        :param event: the event set actions for
-        :param tuples: list of tuples of (user_id, actions)
+        Args:
+            event: the event set actions for
+            tuples: list of tuples of (user_id, actions)
         """
         values = []
         for uid, actions in tuples:
View file
@@ -24,7 +24,7 @@ from synapse.util.logutils import log_function
 from synapse.api.constants import EventTypes
 from canonicaljson import encode_canonical_json

-from contextlib import contextmanager
+from collections import namedtuple

 import logging
 import math
@@ -60,64 +60,71 @@ class EventsStore(SQLBaseStore):
         )

     @defer.inlineCallbacks
-    def persist_events(self, events_and_contexts, backfilled=False,
-                       is_new_state=True):
+    def persist_events(self, events_and_contexts, backfilled=False):
         if not events_and_contexts:
             return

         if backfilled:
-            start = self.min_stream_token - 1
-            self.min_stream_token -= len(events_and_contexts) + 1
-            stream_orderings = range(start, self.min_stream_token, -1)
-
-            @contextmanager
-            def stream_ordering_manager():
-                yield stream_orderings
-            stream_ordering_manager = stream_ordering_manager()
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+                len(events_and_contexts)
+            )
         else:
             stream_ordering_manager = self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
             )

+        state_group_id_manager = self._state_groups_id_gen.get_next_mult(
+            len(events_and_contexts)
+        )
         with stream_ordering_manager as stream_orderings:
-            for (event, _), stream in zip(events_and_contexts, stream_orderings):
-                event.internal_metadata.stream_ordering = stream
+            with state_group_id_manager as state_group_ids:
+                for (event, context), stream, state_group_id in zip(
+                    events_and_contexts, stream_orderings, state_group_ids
+                ):
+                    event.internal_metadata.stream_ordering = stream
+                    # Assign a state group_id in case a new id is needed for
+                    # this context. In theory we only need to assign this
+                    # for contexts that have current_state and aren't outliers
+                    # but that make the code more complicated. Assigning an ID
+                    # per event only causes the state_group_ids to grow as fast
+                    # as the stream_ordering so in practise shouldn't be a problem.
+                    context.new_state_group_id = state_group_id

-            chunks = [
-                events_and_contexts[x:x + 100]
-                for x in xrange(0, len(events_and_contexts), 100)
-            ]
+                chunks = [
+                    events_and_contexts[x:x + 100]
+                    for x in xrange(0, len(events_and_contexts), 100)
+                ]

-            for chunk in chunks:
-                # We can't easily parallelize these since different chunks
-                # might contain the same event. :(
-                yield self.runInteraction(
-                    "persist_events",
-                    self._persist_events_txn,
-                    events_and_contexts=chunk,
-                    backfilled=backfilled,
-                    is_new_state=is_new_state,
-                )
+                for chunk in chunks:
+                    # We can't easily parallelize these since different chunks
+                    # might contain the same event. :(
+                    yield self.runInteraction(
+                        "persist_events",
+                        self._persist_events_txn,
+                        events_and_contexts=chunk,
+                        backfilled=backfilled,
+                    )

     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, context,
-                      is_new_state=True, current_state=None):
+    def persist_event(self, event, context, current_state=None):
         try:
             with self._stream_id_gen.get_next() as stream_ordering:
-                event.internal_metadata.stream_ordering = stream_ordering
-                yield self.runInteraction(
-                    "persist_event",
-                    self._persist_event_txn,
-                    event=event,
-                    context=context,
-                    is_new_state=is_new_state,
-                    current_state=current_state,
-                )
+                with self._state_groups_id_gen.get_next() as state_group_id:
+                    event.internal_metadata.stream_ordering = stream_ordering
+                    context.new_state_group_id = state_group_id
+                    yield self.runInteraction(
+                        "persist_event",
+                        self._persist_event_txn,
+                        event=event,
+                        context=context,
+                        current_state=current_state,
+                    )
         except _RollbackButIsFineException:
             pass

-        max_persisted_id = yield self._stream_id_gen.get_max_token()
+        max_persisted_id = yield self._stream_id_gen.get_current_token()
         defer.returnValue((stream_ordering, max_persisted_id))

     @defer.inlineCallbacks
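A toy model of what get_next()/get_next_mult() guarantee (illustrative only, not the real StreamIdGenerator): ids are reserved up front, but the publicly visible token only advances past a block once its with-block, i.e. its transaction, has finished, so readers never see a token covering unpersisted rows.

from contextlib import contextmanager

class ToyIdGenerator(object):
    def __init__(self, current):
        self._current = current
        self._unfinished = []  # first id of each in-flight block

    @contextmanager
    def get_next_mult(self, n):
        ids = range(self._current + 1, self._current + n + 1)
        self._current += n
        self._unfinished.append(ids[0])
        try:
            yield ids
        finally:
            self._unfinished.remove(ids[0])

    def get_current_token(self):
        # Stop just short of the earliest block still being written.
        if self._unfinished:
            return min(self._unfinished) - 1
        return self._current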
@@ -177,8 +184,7 @@ class EventsStore(SQLBaseStore):
         defer.returnValue({e.event_id: e for e in events})

     @log_function
-    def _persist_event_txn(self, txn, event, context,
-                           is_new_state=True, current_state=None):
+    def _persist_event_txn(self, txn, event, context, current_state):
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
@@ -186,7 +192,16 @@ class EventsStore(SQLBaseStore):
             txn.call_after(self.get_rooms_for_user.invalidate_all)
             txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
             txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
-            txn.call_after(self.get_room_name_and_aliases, event.room_id)
+            txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
+
+            # Add an entry to the current_state_resets table to record the point
+            # where we clobbered the current state
+            stream_order = event.internal_metadata.stream_ordering
+            self._simple_insert_txn(
+                txn,
+                table="current_state_resets",
+                values={"event_stream_ordering": stream_order}
+            )

             self._simple_delete_txn(
                 txn,
@@ -210,12 +225,10 @@ class EventsStore(SQLBaseStore):
             txn,
             [(event, context)],
             backfilled=False,
-            is_new_state=is_new_state,
         )

     @log_function
-    def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            is_new_state=True):
+    def _persist_events_txn(self, txn, events_and_contexts, backfilled):
         depth_updates = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
@@ -282,9 +295,7 @@ class EventsStore(SQLBaseStore):
                 outlier_persisted = have_persisted[event.event_id]
                 if not event.internal_metadata.is_outlier() and outlier_persisted:
-                    self._store_state_groups_txn(
-                        txn, event, context,
-                    )
+                    self._store_mult_state_groups_txn(txn, ((event, context),))

                     metadata_json = encode_json(
                         event.internal_metadata.get_dict()
@@ -299,6 +310,18 @@ class EventsStore(SQLBaseStore):
                         (metadata_json, event.event_id,)
                     )

+                    stream_order = event.internal_metadata.stream_ordering
+                    state_group_id = context.state_group or context.new_state_group_id
+                    self._simple_insert_txn(
+                        txn,
+                        table="ex_outlier_stream",
+                        values={
+                            "event_stream_ordering": stream_order,
+                            "event_id": event.event_id,
+                            "state_group": state_group_id,
+                        }
+                    )
+
                     sql = (
                         "UPDATE events SET outlier = ?"
                         " WHERE event_id = ?"
@@ -310,19 +333,14 @@ class EventsStore(SQLBaseStore):
                 self._update_extremeties(txn, [event])

-        events_and_contexts = filter(
-            lambda ec: ec[0] not in to_remove,
-            events_and_contexts
-        )
+        events_and_contexts = [
+            ec for ec in events_and_contexts if ec[0] not in to_remove
+        ]

         if not events_and_contexts:
             return

-        self._store_mult_state_groups_txn(txn, [
-            (event, context)
-            for event, context in events_and_contexts
-            if not event.internal_metadata.is_outlier()
-        ])
+        self._store_mult_state_groups_txn(txn, events_and_contexts)

         self._handle_mult_prev_events(
             txn,
@@ -349,7 +367,8 @@ class EventsStore(SQLBaseStore):
                 event
                 for event, _ in events_and_contexts
                 if event.type == EventTypes.Member
-            ]
+            ],
+            backfilled=backfilled,
         )

         def event_dict(event):
@@ -421,10 +440,9 @@ class EventsStore(SQLBaseStore):
             txn, [event for event, _ in events_and_contexts]
         )

-        state_events_and_contexts = filter(
-            lambda i: i[0].is_state(),
-            events_and_contexts,
-        )
+        state_events_and_contexts = [
+            ec for ec in events_and_contexts if ec[0].is_state()
+        ]

         state_values = []
         for event, context in state_events_and_contexts:
@@ -462,32 +480,44 @@ class EventsStore(SQLBaseStore):
             ],
         )

-        if is_new_state:
-            for event, _ in state_events_and_contexts:
-                if not context.rejected:
-                    txn.call_after(
-                        self._get_current_state_for_key.invalidate,
-                        (event.room_id, event.type, event.state_key,)
-                    )
-
-                    if event.type in [EventTypes.Name, EventTypes.Aliases]:
-                        txn.call_after(
-                            self.get_room_name_and_aliases.invalidate,
-                            (event.room_id,)
-                        )
-
-                    self._simple_upsert_txn(
-                        txn,
-                        "current_state_events",
-                        keyvalues={
-                            "room_id": event.room_id,
-                            "type": event.type,
-                            "state_key": event.state_key,
-                        },
-                        values={
-                            "event_id": event.event_id,
-                        }
-                    )
+        if backfilled:
+            # Backfilled events come before the current state so we don't need
+            # to update the current state table
+            return
+
+        for event, _ in state_events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                # Outlier events shouldn't clobber the current state.
+                continue
+
+            if context.rejected:
+                # If the event failed its auth checks then it shouldn't
+                # clobber the current state.
+                continue
+
+            txn.call_after(
+                self._get_current_state_for_key.invalidate,
+                (event.room_id, event.type, event.state_key,)
+            )
+
+            if event.type in [EventTypes.Name, EventTypes.Aliases]:
+                txn.call_after(
+                    self.get_room_name_and_aliases.invalidate,
+                    (event.room_id,)
+                )
+
+            self._simple_upsert_txn(
+                txn,
+                "current_state_events",
+                keyvalues={
+                    "room_id": event.room_id,
+                    "type": event.type,
+                    "state_key": event.state_key,
+                },
+                values={
+                    "event_id": event.event_id,
+                }
+            )

         return
@@ -1076,10 +1106,7 @@ class EventsStore(SQLBaseStore):
     def get_current_backfill_token(self):
         """The current minimum token that backfilled events have reached"""
-
-        # TODO: Fix race with the persit_event txn by using one of the
-        # stream id managers
-        return -self.min_stream_token
+        return -self._backfill_id_gen.get_current_token()

     def get_all_new_events(self, last_backfill_id, last_forward_id,
                            current_backfill_id, current_forward_id, limit):
@@ -1087,10 +1114,12 @@ class EventsStore(SQLBaseStore):
         new events or as backfilled events"""
         def get_all_new_events_txn(txn):
             sql = (
-                "SELECT e.stream_ordering, ej.internal_metadata, ej.json"
+                "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
                 " FROM events as e"
                 " JOIN event_json as ej"
                 " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
+                " LEFT JOIN event_to_state_groups as eg"
+                " ON e.event_id = eg.event_id"
                 " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
                 " ORDER BY e.stream_ordering ASC"
                 " LIMIT ?"
@@ -1098,14 +1127,43 @@ class EventsStore(SQLBaseStore):
             if last_forward_id != current_forward_id:
                 txn.execute(sql, (last_forward_id, current_forward_id, limit))
                 new_forward_events = txn.fetchall()
+
+                if len(new_forward_events) == limit:
+                    upper_bound = new_forward_events[-1][0]
+                else:
+                    upper_bound = current_forward_id
+
+                sql = (
+                    "SELECT event_stream_ordering FROM current_state_resets"
+                    " WHERE ? < event_stream_ordering"
+                    " AND event_stream_ordering <= ?"
+                    " ORDER BY event_stream_ordering ASC"
+                )
+                txn.execute(sql, (last_forward_id, upper_bound))
+                state_resets = txn.fetchall()
+
+                sql = (
+                    "SELECT event_stream_ordering, event_id, state_group"
+                    " FROM ex_outlier_stream"
+                    " WHERE ? > event_stream_ordering"
+                    " AND event_stream_ordering >= ?"
+                    " ORDER BY event_stream_ordering DESC"
+                )
+                txn.execute(sql, (last_forward_id, upper_bound))
+                forward_ex_outliers = txn.fetchall()
             else:
                 new_forward_events = []
+                state_resets = []
+                forward_ex_outliers = []

             sql = (
-                "SELECT -e.stream_ordering, ej.internal_metadata, ej.json"
+                "SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
+                " eg.state_group"
                 " FROM events as e"
                 " JOIN event_json as ej"
                 " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
+                " LEFT JOIN event_to_state_groups as eg"
+                " ON e.event_id = eg.event_id"
                 " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
                 " ORDER BY e.stream_ordering DESC"
                 " LIMIT ?"
@@ -1113,8 +1171,35 @@ class EventsStore(SQLBaseStore):
             if last_backfill_id != current_backfill_id:
                 txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
                 new_backfill_events = txn.fetchall()
+
+                if len(new_backfill_events) == limit:
+                    upper_bound = new_backfill_events[-1][0]
+                else:
+                    upper_bound = current_backfill_id
+
+                sql = (
+                    "SELECT -event_stream_ordering, event_id, state_group"
+                    " FROM ex_outlier_stream"
+                    " WHERE ? > event_stream_ordering"
+                    " AND event_stream_ordering >= ?"
+                    " ORDER BY event_stream_ordering DESC"
+                )
+                txn.execute(sql, (-last_backfill_id, -upper_bound))
+                backward_ex_outliers = txn.fetchall()
             else:
                 new_backfill_events = []
+                backward_ex_outliers = []

-            return (new_forward_events, new_backfill_events)
+            return AllNewEventsResult(
+                new_forward_events, new_backfill_events,
+                forward_ex_outliers, backward_ex_outliers,
+                state_resets,
+            )
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
+
+
+AllNewEventsResult = namedtuple("AllNewEventsResult", [
+    "new_forward_events", "new_backfill_events",
+    "forward_ex_outliers", "backward_ex_outliers",
+    "state_resets"
+])
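A sketch of consuming the richer result (the token variables are placeholders; the row layout matches the SELECTs above, with state_group appended to each event row):

from twisted.internet import defer

@defer.inlineCallbacks
def stream_new_events(store, last_backfill, last_forward,
                      current_backfill, current_forward, limit=100):
    result = yield store.get_all_new_events(
        last_backfill, last_forward, current_backfill, current_forward, limit
    )
    for stream_ordering, internal, event_json, state_group in result.new_forward_events:
        pass  # hand forward events to replication followers
    if result.state_resets:
        pass  # followers must drop their current-state caches
    defer.returnValue(result)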
View file
@ -25,23 +25,11 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 30 SCHEMA_VERSION = 31
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
def read_schema(path):
""" Read the named database schema.
Args:
path: Path of the database schema.
Returns:
A string containing the database schema.
"""
with open(path) as schema_file:
return schema_file.read()
class PrepareDatabaseException(Exception): class PrepareDatabaseException(Exception):
pass pass
@ -53,6 +41,9 @@ class UpgradeDatabaseException(PrepareDatabaseException):
def prepare_database(db_conn, database_engine, config): def prepare_database(db_conn, database_engine, config):
"""Prepares a database for usage. Will either create all necessary tables """Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version. or upgrade from an older schema version.
If `config` is None then prepare_database will assert that no upgrade is
necessary, *or* will create a fresh database if the database is empty.
""" """
try: try:
cur = db_conn.cursor() cur = db_conn.cursor()
@ -60,13 +51,18 @@ def prepare_database(db_conn, database_engine, config):
if version_info: if version_info:
user_version, delta_files, upgraded = version_info user_version, delta_files, upgraded = version_info
_upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine, config
)
else:
_setup_new_database(cur, database_engine, config)
# cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) if config is None:
if user_version != SCHEMA_VERSION:
# If we don't pass in a config file then we are expecting to
# have already upgraded the DB.
raise UpgradeDatabaseException("Database needs to be upgraded")
else:
_upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine, config
)
else:
_setup_new_database(cur, database_engine)
cur.close() cur.close()
db_conn.commit() db_conn.commit()
@ -75,7 +71,7 @@ def prepare_database(db_conn, database_engine, config):
raise raise
def _setup_new_database(cur, database_engine, config): def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then """Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas. applying any necessary deltas.
@ -148,12 +144,13 @@ def _setup_new_database(cur, database_engine, config):
applied_delta_files=[], applied_delta_files=[],
upgraded=False, upgraded=False,
database_engine=database_engine, database_engine=database_engine,
config=config, config=None,
is_empty=True,
) )
def _upgrade_existing_database(cur, current_version, applied_delta_files, def _upgrade_existing_database(cur, current_version, applied_delta_files,
upgraded, database_engine, config): upgraded, database_engine, config, is_empty=False):
"""Upgrades an existing database. """Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules Delta files can either be SQL stored in *.sql files, or python modules
@ -246,7 +243,9 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine, config=config)
module.run_create(cur, database_engine)
if not is_empty:
module.run_upgrade(cur, database_engine, config=config)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
@ -361,36 +360,3 @@ def _get_or_create_schema_state(txn, database_engine):
return current_version, applied_deltas, upgraded
return None
def prepare_sqlite3_database(db_conn):
"""This function should be called before `prepare_database` on sqlite3
databases.
Since we changed the way we store the current schema version and handle
updates to schemas, we need a way to upgrade from the old method to the
new. This only affects sqlite databases since they were the only ones
supported at the time.
"""
with db_conn:
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
create_schema = read_schema(schema_path)
db_conn.executescript(create_schema)
c = db_conn.execute("SELECT * FROM schema_version")
rows = c.fetchall()
c.close()
if not rows:
c = db_conn.execute("PRAGMA user_version")
row = c.fetchone()
c.close()
if row and row[0]:
db_conn.execute(
"REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(row[0], False)
)

View file

@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore):
self._update_presence_txn, stream_orderings, presence_states,
)
defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token()))
defer.returnValue((
stream_orderings[-1], self._presence_id_gen.get_current_token()
))
def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states):
@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore):
defer.returnValue([UserPresenceState(**row) for row in rows])
def get_current_presence_token(self):
return self._presence_id_gen.get_max_token()
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
@ -174,16 +176,6 @@ class PresenceStore(SQLBaseStore):
desc="disallow_presence_visible",
)
def is_presence_visible(self, observed_localpart, observer_userid):
return self._simple_select_one(
table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
retcols=["observed_user_id"],
allow_none=True,
desc="is_presence_visible",
)
def add_presence_list_pending(self, observer_localpart, observed_userid):
return self._simple_insert(
table="presence_list",

View file

@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore):
"""Get the position of the push rules stream. """Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to.""" room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_max_token() return self._push_rules_stream_id_gen.get_current_token()
def have_push_rules_changed_for_user(self, user_id, last_id): def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):

View file

@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore):
defer.returnValue(rows)
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_max_token()
return self._pushers_id_gen.get_current_token()
def get_all_updated_pushers(self, last_id, current_id, limit):
def get_all_updated_pushers_txn(txn):

View file

@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
@cached(num_args=2)
@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore):
"content": content,
}])
@cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
num_args=3, inlineCallbacks=True)
@cachedList(cached_method_name="get_linearized_receipts_for_room",
list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
defer.returnValue({})
@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue(results)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token()
return self._receipts_id_gen.get_current_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data
)
max_persisted_id = self._stream_id_gen.get_max_token()
max_persisted_id = self._stream_id_gen.get_current_token()
defer.returnValue((stream_id, max_persisted_id))

View file

@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
inlineCallbacks=True)
@cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1,
inlineCallbacks=True)
def are_guests(self, user_ids):
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
@ -458,12 +458,15 @@ class RegistrationStore(SQLBaseStore):
"""
Gets the 3pid's guest access token if exists, else saves access_token.
:param medium (str): Medium of the 3pid. Must be "email".
:param address (str): 3pid address.
:param access_token (str): The access token to persist if none is
already persisted.
:param inviter_user_id (str): User ID of the inviter.
:return (deferred str): Whichever access token is persisted at the end
Args:
medium (str): Medium of the 3pid. Must be "email".
address (str): 3pid address.
access_token (str): The access token to persist if none is
already persisted.
inviter_user_id (str): User ID of the inviter.
Returns:
deferred str: Whichever access token is persisted at the end
of this function call.
"""
def insert(txn):

View file

@ -36,7 +36,7 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore):
def _store_room_members_txn(self, txn, events):
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
@ -62,27 +62,65 @@ class RoomMemberStore(SQLBaseStore):
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
Args:
user_id (str): The member's user ID.
room_id (str): The room the member is in.
Returns:
Deferred: Results in a MembershipEvent or None.
"""
return self.runInteraction(
"get_room_member",
self._get_members_events_txn,
room_id,
user_id=user_id,
).addCallback(
self._get_events
).addCallback(
lambda events: events[0] if events else None
)
# We update the local_invites table only if the event is "current",
# i.e., it's something that has just happened.
# The only current event that can also be an outlier is if it's an
# invite that has come in across federation.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_invite_from_remote()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
}
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(sql, (
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
))
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (
stream_ordering,
True,
room_id,
user_id,
))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cached(max_entries=5000)
def get_users_in_room(self, room_id):
def f(txn):
@ -127,18 +165,23 @@ class RoomMemberStore(SQLBaseStore):
user_id, [Membership.INVITE]
)
def get_leave_and_ban_events_for_user(self, user_id):
""" Get all the leave events for a user
Args:
user_id (str): The user ID.
Returns:
A deferred list of event objects.
"""
return self.get_rooms_for_user_where_membership_is(
user_id, (Membership.LEAVE, Membership.BAN)
).addCallback(lambda leaves: self._get_events([
leave.event_id for leave in leaves
]))
@defer.inlineCallbacks
def get_invite_for_user_in_room(self, user_id, room_id):
"""Gets the invite for the given user and room
Args:
user_id (str)
room_id (str)
Returns:
Deferred: Resolves to either a RoomsForUser or None if no invite was
found.
"""
invites = yield self.get_invited_rooms_for_user(user_id)
for invite in invites:
if invite.room_id == room_id:
defer.returnValue(invite)
defer.returnValue(None)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@ -163,29 +206,55 @@ class RoomMemberStore(SQLBaseStore):
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list):
where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
" OR ".join(["membership = ?" for _ in membership_list]),
)
args = [user_id] do_invite = Membership.INVITE in membership_list
args.extend(membership_list) membership_list = [m for m in membership_list if m != Membership.INVITE]
sql = ( results = []
"SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" if membership_list:
" FROM current_state_events as c" where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
" INNER JOIN room_memberships as m" " OR ".join(["membership = ?" for _ in membership_list]),
" ON m.event_id = c.event_id" )
" INNER JOIN events as e"
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key"
" WHERE %s"
) % (where_clause,)
txn.execute(sql, args) args = [user_id]
return [ args.extend(membership_list)
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
] sql = (
"SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
" FROM current_state_events as c"
" INNER JOIN room_memberships as m"
" ON m.event_id = c.event_id"
" INNER JOIN events as e"
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key"
" WHERE %s"
) % (where_clause,)
txn.execute(sql, args)
results = [
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
if do_invite:
sql = (
"SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
" FROM local_invites as i"
" INNER JOIN events as e USING (event_id)"
" WHERE invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(sql, (user_id,))
results.extend(RoomsForUser(
room_id=r["room_id"],
sender=r["inviter"],
event_id=r["event_id"],
stream_ordering=r["stream_ordering"],
membership=Membership.INVITE,
) for r in self.cursor_to_dict(txn))
return results
@cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):

View file

@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, *args, **kwargs):
def run_create(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:
@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs):
"UPDATE application_services_regex SET regex=? WHERE id=?",
(new_regex, row[0])
)
def run_upgrade(*args, **kwargs):
pass
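# Hedged sketch of the new delta-module contract introduced above: on an
# empty database only `run_create` executes; on an existing database
# `run_create` runs first and `run_upgrade` then migrates old data. The
# table name below is illustrative, not from the diff:
#
#     def run_create(cur, database_engine, *args, **kwargs):
#         cur.execute("CREATE TABLE IF NOT EXISTS example_table (id BIGINT)")
#
#     def run_upgrade(cur, database_engine, config, *args, **kwargs):
#         pass  # backfill from old tables; skipped for fresh databases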

View file

@ -27,7 +27,7 @@ import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
cur.execute("""
CREATE TABLE IF NOT EXISTS pushers2 (
@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute("DROP TABLE pushers")
cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
logger.info("Moved %d pushers to new table", count)
def run_upgrade(*args, **kwargs):
pass

View file

@ -43,7 +43,7 @@ SQLITE_TABLE = (
)
def run_upgrade(cur, database_engine, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
for statement in get_statements(POSTGRES_TABLE.splitlines()):
cur.execute(statement)
@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_search", progress_json))
def run_upgrade(*args, **kwargs):
pass

View file

@ -27,7 +27,7 @@ ALTER_TABLE = (
)
def run_upgrade(cur, database_engine, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
for statement in get_statements(ALTER_TABLE.splitlines()):
cur.execute(statement)
@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_origin_server_ts", progress_json))
def run_upgrade(*args, **kwargs):
pass

View file

@ -18,7 +18,7 @@ from synapse.storage.appservice import ApplicationServiceStore
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, config, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
# NULL indicates user was not registered by an appservice.
try:
cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
# Maybe we already added the column? Hope so...
pass
def run_upgrade(cur, database_engine, config, *args, **kwargs):
cur.execute("SELECT name FROM users")
rows = cur.fetchall()

View file

@ -0,0 +1,38 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* The positions in the event stream_ordering when the current_state was
* replaced by the state at the event.
*/
CREATE TABLE IF NOT EXISTS current_state_resets(
event_stream_ordering BIGINT PRIMARY KEY NOT NULL
);
/* The outlier events that have acquired a state group, typically through
* backfill. This is tracked separately to the events table, as assigning a
* state group changes the position of the existing event in the stream
* ordering.
* However since a stream_ordering is assigned in persist_event for the
* (event, state) pair, we can use that stream_ordering to identify when
* the new state was assigned for the event.
*/
CREATE TABLE IF NOT EXISTS ex_outlier_stream(
event_stream_ordering BIGINT PRIMARY KEY NOT NULL,
event_id TEXT NOT NULL,
state_group BIGINT NOT NULL
);

View file

@ -0,0 +1,42 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE local_invites(
stream_id BIGINT NOT NULL,
inviter TEXT NOT NULL,
invitee TEXT NOT NULL,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
locally_rejected TEXT,
replaced_by TEXT
);
-- Insert all invites for local users into the new `local_invites` table
INSERT INTO local_invites SELECT
stream_ordering as stream_id,
sender as inviter,
state_key as invitee,
event_id,
room_id,
NULL as locally_rejected,
NULL as replaced_by
FROM events
NATURAL JOIN current_state_events
NATURAL JOIN room_memberships
WHERE membership = 'invite' AND state_key IN (SELECT name FROM users);
CREATE INDEX local_invites_id ON local_invites(stream_id);
CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id);
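-- Illustrative only (not part of the migration): the second index above is
-- shaped for the pending-invite lookup the storage layer performs, roughly:
--
--     SELECT room_id, event_id FROM local_invites
--     WHERE invitee = ? AND locally_rejected is NULL AND replaced_by is NULL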

View file

@ -64,12 +64,12 @@ class StateStore(SQLBaseStore):
for group, state_map in group_to_state.items()
})
def _store_state_groups_txn(self, txn, event, context):
return self._store_mult_state_groups_txn(txn, [(event, context)])
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
if context.current_state is None:
continue
@ -82,7 +82,8 @@ class StateStore(SQLBaseStore):
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = self._state_groups_id_gen.get_next()
state_group = context.new_state_group_id
self._simple_insert_txn(
txn,
table="state_groups",
@ -114,11 +115,10 @@ class StateStore(SQLBaseStore):
table="event_to_state_groups",
values=[
{
"state_group": state_groups[event.event_id],
"event_id": event.event_id,
}
for event, context in events_and_contexts
if context.current_state is not None
],
values=[
{
"state_group": state_group_id,
"event_id": event_id,
}
for event_id, state_group_id in state_groups.items()
],
)
@ -249,11 +249,14 @@ class StateStore(SQLBaseStore):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
:param str event_id: event whose state should be returned Args:
:param list[(str, str)]|None types: List of (type, state_key) tuples event_id(str): event whose state should be returned
which are used to filter the state fetched. May be None, which types(list[(str, str)]|None): List of (type, state_key) tuples
matches any key which are used to filter the state fetched. May be None, which
:return: a deferred dict from (type, state_key) -> state_event matches any key
Returns:
A deferred dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_for_events([event_id], types) state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@ -270,8 +273,8 @@ class StateStore(SQLBaseStore):
desc="_get_state_group_for_event", desc="_get_state_group_for_event",
) )
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", @cachedList(cached_method_name="_get_state_group_for_event",
num_args=1, inlineCallbacks=True) list_name="event_ids", num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids): def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group """Returns mapping event_id -> state_group
""" """
@ -429,3 +432,33 @@ class StateStore(SQLBaseStore):
} }
defer.returnValue(results) defer.returnValue(results)
def get_all_new_state_groups(self, last_id, current_id, limit):
def get_all_new_state_groups_txn(txn):
sql = (
"SELECT id, room_id, event_id FROM state_groups"
" WHERE ? < id AND id <= ? ORDER BY id LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
groups = txn.fetchall()
if not groups:
return ([], [])
lower_bound = groups[0][0]
upper_bound = groups[-1][0]
sql = (
"SELECT state_group, type, state_key, event_id"
" FROM state_groups_state"
" WHERE ? <= state_group AND state_group <= ?"
)
txn.execute(sql, (lower_bound, upper_bound))
state_group_state = txn.fetchall()
return (groups, state_group_state)
return self.runInteraction(
"get_all_new_state_groups", get_all_new_state_groups_txn
)
def get_state_stream_token(self):
return self._state_groups_id_gen.get_current_token()

View file

@ -303,96 +303,6 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
def get_room_events_stream(
self,
user_id,
from_key,
to_key,
limit=0,
is_guest=False,
room_ids=None
):
room_ids = room_ids or []
room_ids = [r for r in room_ids]
if is_guest:
current_room_membership_sql = (
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s)"
" AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
current_room_membership_args = room_ids
else:
current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id AND c.state_key = m.user_id"
" WHERE m.user_id = ? AND m.membership = 'join'"
)
current_room_membership_args = [user_id]
# We also want to get any membership events about that user, e.g.
# invites or leave notifications.
membership_sql = (
"SELECT m.event_id FROM room_memberships as m "
"INNER JOIN current_state_events as c ON m.event_id = c.event_id "
"WHERE m.user_id = ? "
)
membership_args = [user_id]
if limit:
limit = max(limit, MAX_STREAM_SIZE)
else:
limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering.
from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = RoomStreamToken.parse_stream_token(to_key)
if from_key == to_key:
return defer.succeed(([], to_key))
sql = (
"SELECT e.event_id, e.stream_ordering FROM events AS e WHERE "
"(e.outlier = ? AND (room_id IN (%(current)s)) OR "
"(event_id IN (%(invites)s))) "
"AND e.stream_ordering > ? AND e.stream_ordering <= ? "
"ORDER BY stream_ordering ASC LIMIT %(limit)d "
) % {
"current": current_room_membership_sql,
"invites": membership_sql,
"limit": limit
}
def f(txn):
args = ([False] + current_room_membership_args + membership_args +
[from_id.stream, to_id.stream])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(ret, rows)
if rows:
key = "s%d" % max(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
key = to_key
return ret, key
return self.runInteraction("get_room_events_stream", f)
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1):
@ -539,7 +449,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'):
token = yield self._stream_id_gen.get_max_token()
token = yield self._stream_id_gen.get_current_token()
if direction != 'b':
defer.returnValue("s%d" % (token,))
else:

View file

@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
Returns:
A deferred int.
"""
return self._account_data_id_gen.get_max_token()
return self._account_data_id_gen.get_current_token()
@cached()
def get_tags_for_user(self, user_id):
@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_max_token()
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
@defer.inlineCallbacks
@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_max_token()
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):

View file

@ -21,7 +21,7 @@ import threading
class IdGenerator(object):
def __init__(self, db_conn, table, column):
self._lock = threading.Lock()
self._next_id = _load_max_id(db_conn, table, column)
self._next_id = _load_current_id(db_conn, table, column)
def get_next(self):
with self._lock:
@ -29,12 +29,16 @@ class IdGenerator(object):
return self._next_id
def _load_max_id(db_conn, table, column):
def _load_current_id(db_conn, table, column, step=1):
cur = db_conn.cursor()
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
val, = cur.fetchone()
cur.close()
return int(val) if val else 1
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
class StreamIdGenerator(object):
@ -45,17 +49,32 @@ class StreamIdGenerator(object):
all ids less than or equal to it have completed. This handles the fact that
persistence of events can complete out of order.
Args:
db_conn(connection): A database connection to use to fetch the
initial value of the generator from.
table(str): A database table to read the initial value of the id
generator from.
column(str): The column of the database table to read the initial
value of the id generator from.
extra_tables(list): List of pairs of database tables and columns to
use to source the initial value of the generator from. The value
with the largest magnitude is used.
step(int): which direction the stream ids grow in. +1 to grow
upwards, -1 to grow downwards.
Usage:
with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
def __init__(self, db_conn, table, column, extra_tables=[]):
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
assert step != 0
self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
self._step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self._current_max = max(
self._current_max,
_load_max_id(db_conn, table, column)
self._current = (max if step > 0 else min)(
self._current,
_load_current_id(db_conn, table, column, step)
)
self._unfinished_ids = deque()
@ -66,8 +85,8 @@ class StreamIdGenerator(object):
# ... persist event ...
"""
with self._lock:
self._current_max += 1
next_id = self._current_max
self._current += self._step
next_id = self._current
self._unfinished_ids.append(next_id)
@ -88,8 +107,12 @@ class StreamIdGenerator(object):
# ... persist events ...
"""
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
next_ids = range(
self._current + self._step,
self._current + self._step * (n + 1),
self._step
)
self._current += n * self._step
for next_id in next_ids:
self._unfinished_ids.append(next_id)
@ -105,15 +128,15 @@ class StreamIdGenerator(object):
return manager()
def get_max_token(self):
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
return self._unfinished_ids[0] - 1
return self._unfinished_ids[0] - self._step
return self._current_max
return self._current
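# A hedged illustration of the `step` parameter above (values invented):
# with step=-1 the generator grows downwards, so a second writer can hand
# out ids that never collide with the forward-growing stream.
#
#     gen = StreamIdGenerator(db_conn, "events", "stream_ordering", step=-1)
#     with gen.get_next() as stream_id:
#         ...  # first id is -1, then -2, and so on
#     # get_current_token() then returns the boundary id: every id between
#     # it and zero has been fully persisted.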
class ChainedIdGenerator(object):
@ -125,7 +148,7 @@ class ChainedIdGenerator(object):
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
self._current_max = _load_current_id(db_conn, table, column)
self._unfinished_ids = deque()
def get_next(self):
@ -137,7 +160,7 @@ class ChainedIdGenerator(object):
with self._lock:
self._current_max += 1
next_id = self._current_max
chained_id = self.chained_generator.get_max_token()
chained_id = self.chained_generator.get_current_token()
self._unfinished_ids.append((next_id, chained_id))
@ -151,7 +174,7 @@ class ChainedIdGenerator(object):
return manager()
def get_max_token(self):
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
@ -160,4 +183,4 @@ class ChainedIdGenerator(object):
stream_id, chained_id = self._unfinished_ids[0]
return (stream_id - 1, chained_id)
return (self._current_max, self.chained_generator.get_max_token())
return (self._current_max, self.chained_generator.get_current_token())

View file

@ -49,9 +49,6 @@ class Clock(object):
l.start(msec / 1000.0, now=False)
return l
def stop_looping_call(self, loop):
loop.stop()
def call_later(self, delay, callback, *args, **kwargs):
"""Call something later

View file

@ -16,7 +16,12 @@
from twisted.internet import defer, reactor
from .logcontext import PreserveLoggingContext
from .logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
from synapse.util import unwrapFirstError
from contextlib import contextmanager
@defer.inlineCallbacks
@ -107,3 +112,76 @@ class ObservableDeferred(object):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)
def concurrently_execute(func, args, limit):
"""Executes the function with each argument conncurrently while limiting
the number of concurrent executions.
Args:
func (func): Function to execute, should return a deferred.
args (list): List of arguments to pass to func, each invocation of func
gets a signle argument.
limit (int): Maximum number of conccurent executions.
Returns:
deferred: Resolved when all function invocations have finished.
"""
it = iter(args)
@defer.inlineCallbacks
def _concurrently_execute_inner():
try:
while True:
yield func(it.next())
except StopIteration:
pass
return defer.gatherResults([
preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit)
], consumeErrors=True).addErrback(unwrapFirstError)
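# Example usage (names are hypothetical, not from the diff): process many
# rooms with at most ten deferreds in flight at once.
#
#     @defer.inlineCallbacks
#     def handle_room(room_id):
#         yield fetch_room_state(room_id)
#
#     yield concurrently_execute(handle_room, room_ids, limit=10)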
class Linearizer(object):
"""Linearizes access to resources based on a key. Useful to ensure only one
thing is happening at a time on a given resource.
Example:
with (yield linearizer.queue("test_key")):
# do some work.
"""
def __init__(self):
self.key_to_defer = {}
@defer.inlineCallbacks
def queue(self, key):
# If there is already a deferred in the queue, we pull it out so that
# we can wait on it later.
# Then we replace it with a deferred that we resolve *after* the
# context manager has exited.
# We only return the context manager after the previous deferred has
# resolved.
# This all has the net effect of creating a chain of deferreds that
# wait for the previous deferred before starting their work.
current_defer = self.key_to_defer.get(key)
new_defer = defer.Deferred()
self.key_to_defer[key] = new_defer
if current_defer:
yield preserve_context_over_deferred(current_defer)
@contextmanager
def _ctx_manager():
try:
yield
finally:
new_defer.callback(None)
current_d = self.key_to_defer.get(key)
if current_d is new_defer:
self.key_to_defer.pop(key, None)
defer.returnValue(_ctx_manager())
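# Sketch of the intended behaviour (assumed from the docstring above): two
# callers queueing on the same key run strictly one after the other, while
# different keys never block each other.
#
#     linearizer = Linearizer()
#
#     @defer.inlineCallbacks
#     def update_room(room_id):
#         with (yield linearizer.queue(room_id)):
#             yield read_modify_write(room_id)  # hypothetical critical section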

View file

@ -167,7 +167,8 @@ class CacheDescriptor(object):
% (orig.__name__,)
)
self.cache = Cache(
def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
@ -175,14 +176,12 @@ class CacheDescriptor(object):
tree=self.tree,
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
cached_result_d = self.cache.get(cache_key)
cached_result_d = cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@ -204,7 +203,7 @@ class CacheDescriptor(object):
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = self.cache.sequence
sequence = cache.sequence
ret = defer.maybeDeferred(
preserve_context_over_fn,
@ -213,20 +212,21 @@ class CacheDescriptor(object):
)
def onErr(f):
self.cache.invalidate(cache_key)
cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
cache.update(sequence, cache_key, ret)
return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.invalidate_many = self.cache.invalidate_many
wrapped.prefill = self.cache.prefill
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
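# My reading of the change above (illustrative, not from the diff): building
# the Cache inside __get__ and stashing the wrapped function on the instance
# gives each object its own cache, rather than one cache shared via the class:
#
#     class Store(object):
#         @cached()
#         def get_thing(self, key):
#             ...
#
#     a, b = Store(), Store()
#     assert a.get_thing.cache is not b.get_thing.cache  # per-instance caches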
@ -240,11 +240,12 @@ class CacheListDescriptor(object):
the list of missing keys to the wrapped function.
"""
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
def __init__(self, orig, cached_method_name, list_name, num_args=1,
inlineCallbacks=False):
"""
Args:
orig (function)
cache (Cache)
cached_method_name (str): The name of the cached method.
list_name (str): Name of the argument which is the bulk lookup list
num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should
@ -263,7 +264,7 @@ class CacheListDescriptor(object):
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
self.cached_method_name = cached_method_name
self.sentinel = object()
@ -277,11 +278,13 @@ class CacheListDescriptor(object):
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
% (self.list_name, cache.name,)
% (self.list_name, cached_method_name,)
)
def __get__(self, obj, objtype=None):
cache = getattr(obj, self.cached_method_name).cache
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
@ -297,14 +300,14 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
try:
res = self.cache.get(tuple(key)).observe()
res = cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
except KeyError:
missing.append(arg)
if missing:
sequence = self.cache.sequence
sequence = cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
@ -327,10 +330,10 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
self.cache.update(sequence, tuple(key), observer)
cache.update(sequence, tuple(key), observer)
def invalidate(f, key):
self.cache.invalidate(key)
cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
)
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@ -400,7 +403,7 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
"""
return lambda orig: CacheListDescriptor(
orig,
cache=cache,
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
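# Hedged sketch of the new pairing: `cachedList` now refers to the cached
# method by name and resolves its per-instance cache via
# getattr(obj, cached_method_name).cache. Methods below are illustrative:
#
#     class Store(object):
#         @cached(num_args=1)
#         def is_guest(self, user_id):
#             ...
#
#         @cachedList(cached_method_name="is_guest", list_name="user_ids",
#                     num_args=1, inlineCallbacks=True)
#         def are_guests(self, user_ids):
#             ...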

View file

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.async import ObservableDeferred
class ResponseCache(object):
"""
This caches a deferred response. Until the deferred completes it will be
returned from the cache. This means that if the client retries the request
while the response is still being computed, that original response will be
used rather than trying to compute a new response.
"""
def __init__(self):
self.pending_result_cache = {} # Requests that haven't finished yet.
def get(self, key):
result = self.pending_result_cache.get(key)
if result is not None:
return result.observe()
else:
return None
def set(self, key, deferred):
result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result
def remove(r):
self.pending_result_cache.pop(key, None)
return r
result.addBoth(remove)
return result.observe()
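# Example of the request-deduplication pattern this enables (names assumed):
#
#     response_cache = ResponseCache()
#
#     def on_incoming_request(key):
#         result = response_cache.get(key)
#         if result is None:
#             result = response_cache.set(key, compute_response(key))
#         return result  # retries while in flight share the same deferred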

View file

@ -15,7 +15,9 @@
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_fn
)
from synapse.util import unwrapFirstError
@ -25,6 +27,24 @@ import logging
logger = logging.getLogger(__name__)
def registered_user(distributor, user):
return distributor.fire("registered_user", user)
def user_left_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class Distributor(object):
"""A central dispatch point for loosely-connected pieces of code to
register, observe, and fire signals.

View file

@ -100,20 +100,6 @@ class _PerHostRatelimiter(object):
self.current_processing = set()
self.request_times = []
def is_empty(self):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
return not (
self.ready_request_queue
or self.sleeping_requests
or self.current_processing
or self.request_times
)
@contextlib.contextmanager
def ratelimit(self):
# `contextlib.contextmanager` takes a generator and turns it into a

View file

@ -21,10 +21,6 @@ _string_with_symbols = (
)
def origin_from_ucid(ucid):
return ucid.split("@", 1)[1]
def random_string(length):
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,57 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from synapse.replication.slave.storage.events import SlavedEventStore
from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver
from synapse.replication.resource import ReplicationResource
class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
"blue",
http_client=None,
replication_layer=Mock(),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
)
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
self.replication = ReplicationResource(self.hs)
self.master_store = self.hs.get_datastore()
self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs)
self.event_id = 0
@defer.inlineCallbacks
def replicate(self):
streams = self.slaved_store.stream_positions()
result = yield self.replication.replicate(streams, 100)
yield self.slaved_store.process_replication(result)
@defer.inlineCallbacks
def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args)
slaved_result = yield getattr(self.slaved_store, method)(*args)
if expected_result is not None:
self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result)
self.assertEqual(master_result, slaved_result)

View file

@ -0,0 +1,307 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStoreTestCase
from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
from synapse.storage.roommember import RoomsForUser
from twisted.internet import defer
USER_ID = "@feeling:blue"
USER_ID_2 = "@bright:blue"
OUTLIER = {"outlier": True}
ROOM_ID = "!room:blue"
def dict_equals(self, other):
return self.__dict__ == other.__dict__
def patch__eq__(cls):
eq = getattr(cls, "__eq__", None)
cls.__eq__ = dict_equals
def unpatch():
if eq is not None:
cls.__eq__ = eq
return unpatch
class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def setUp(self):
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEquals
self.unpatches = [
patch__eq__(_EventInternalMetadata),
patch__eq__(FrozenEvent),
]
return super(SlavedEventStoreTestCase, self).setUp()
def tearDown(self):
[unpatch() for unpatch in self.unpatches]
@defer.inlineCallbacks
def test_room_name_and_aliases(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
yield self.persist(type="m.room.name", key="", name="name1")
yield self.persist(
type="m.room.aliases", key="blue", aliases=["#1:blue"]
)
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"])
)
# Set the room name.
yield self.persist(type="m.room.name", key="", name="name2")
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"])
)
# Set the room aliases.
yield self.persist(
type="m.room.aliases", key="blue", aliases=["#2:blue"]
)
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"])
)
# Leave and join the room clobbering the state.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), (None, [])
)
@defer.inlineCallbacks
def test_room_members(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Join the room.
join = yield self.persist(type="m.room.member", key=USER_ID, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
# Leave the room.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Add some other user to the room.
join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
# Join the room clobbering the state.
# This should remove any evidence of the other user being in the room.
yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
yield self.check("get_rooms_for_user", (USER_ID_2,), [])
@defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check(
"get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]
)
join = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
prev_events=[(create.event_id, {})],
)
yield self.replicate()
yield self.check(
"get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
)
@defer.inlineCallbacks
def test_get_current_state(self):
# Create the room.
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
)
# Join the room.
join1 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join1]
)
# Add some other user to the room.
join2 = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[join2]
)
# Leave the room, then rejoin the room clobbering state.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
join3 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[]
)
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join3]
)
@defer.inlineCallbacks
def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
type="m.room.redaction", redacts=msg.event_id
)
yield self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)

    @defer.inlineCallbacks
    def test_backfilled_redactions(self):
        yield self.persist(type="m.room.create", key="", creator=USER_ID)
        yield self.persist(type="m.room.member", key=USER_ID, membership="join")

        msg = yield self.persist(
            type="m.room.message", msgtype="m.text", body="Hello"
        )
        yield self.replicate()
        yield self.check("get_event", [msg.event_id], msg)

        redaction = yield self.persist(
            type="m.room.redaction", redacts=msg.event_id, backfill=True
        )
        yield self.replicate()

        msg_dict = msg.get_dict()
        msg_dict["content"] = {}
        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
        msg_dict["unsigned"]["redacted_because"] = redaction
        redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
        yield self.check("get_event", [msg.event_id], redacted)

    event_id = 0

    @defer.inlineCallbacks
    def persist(
        self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
        state=None, reset_state=False, backfill=False,
        depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
        **content
    ):
        """Builds an event from the given arguments and persists it to the
        master store. Any unrecognised keyword arguments become the event's
        content.

        Returns:
            synapse.events.FrozenEvent: The event that was persisted.
        """
        if depth is None:
            depth = self.event_id

        event_dict = {
            "sender": sender,
            "type": type,
            "content": content,
            "event_id": "$%d:blue" % (self.event_id,),
            "room_id": room_id,
            "depth": depth,
            "origin_server_ts": self.event_id,
            "prev_events": prev_events,
            "auth_events": auth_events,
        }
        if key is not None:
            # Only state events carry a state_key (and prev_state).
            event_dict["state_key"] = key
            event_dict["prev_state"] = prev_state

        if redacts is not None:
            event_dict["redacts"] = redacts

        event = FrozenEvent(event_dict, internal_metadata_dict=internal)

        self.event_id += 1

        context = EventContext(current_state=state)

        ordering = None
        if backfill:
            # Backfilled events are persisted in bulk and are not assigned a
            # new stream ordering here.
            yield self.master_store.persist_events(
                [(event, context)], backfilled=True
            )
        else:
            ordering, _ = yield self.master_store.persist_event(
                event, context, current_state=reset_state
            )

        if ordering:
            event.internal_metadata.stream_ordering = ordering

        defer.returnValue(event)
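
For orientation, the tests in this file all follow the same pattern: write an
event to the master store with persist(), run replication with replicate(),
then call check() (defined earlier in the file, and assumed here to assert
that the master and slave stores agree with the expected result for the named
store method) to verify the slave's view. A minimal sketch of a further test
in that style, with a hypothetical name:

    @defer.inlineCallbacks
    def test_get_event_replication(self):
        create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
        yield self.replicate()
        yield self.check("get_event", [create.event_id], create)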


@@ -58,15 +58,21 @@ class ReplicationResourceCase(unittest.TestCase):
         self.assertEquals(body, {})
 
     @defer.inlineCallbacks
-    def test_events(self):
-        get = self.get(events="-1", timeout="0")
+    def test_events_and_state(self):
+        get = self.get(events="-1", state="-1", timeout="0")
         yield self.hs.get_handlers().room_creation_handler.create_room(
             Requester(self.user, "", False), {}
         )
         code, body = yield get
         self.assertEquals(code, 200)
         self.assertEquals(body["events"]["field_names"], [
-            "position", "internal", "json"
+            "position", "internal", "json", "state_group"
+        ])
+        self.assertEquals(body["state_groups"]["field_names"], [
+            "position", "room_id", "event_id"
+        ])
+        self.assertEquals(body["state_group_state"]["field_names"], [
+            "position", "type", "state_key", "event_id"
         ])
 
     @defer.inlineCallbacks
@@ -132,6 +138,7 @@ class ReplicationResourceCase(unittest.TestCase):
     test_timeout_backfill = _test_timeout("backfill")
     test_timeout_push_rules = _test_timeout("push_rules")
     test_timeout_pushers = _test_timeout("pushers")
+    test_timeout_state = _test_timeout("state")
 
     @defer.inlineCallbacks
     def send_text_message(self, room_id, message):
@@ -182,4 +189,21 @@
         )
         response_body = json.loads(response_json)
 
+        if response_code == 200:
+            self.check_response(response_body)
+
         defer.returnValue((response_code, response_body))
+
+    def check_response(self, response_body):
+        for name, stream in response_body.items():
+            self.assertIn("field_names", stream)
+            field_names = stream["field_names"]
+            self.assertIn("rows", stream)
+            self.assertTrue(stream["rows"])
+            for row in stream["rows"]:
+                self.assertEquals(
+                    len(row), len(field_names),
+                    "%s: len(row = %r) == len(field_names = %r)" % (
+                        name, row, field_names
+                    )
+                )
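
To make check_response's invariant concrete: each stream in the body carries a
"field_names" list and a non-empty "rows" list, and every row must have exactly
one value per field name. A body such as the following (values invented purely
for illustration) would pass:

    response_body = {
        "events": {
            "field_names": ["position", "internal", "json", "state_group"],
            "rows": [[1, "{}", "{\"type\": \"m.room.create\"}", 42]],
        },
    }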


@@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase):
         # set [invite/join/left] of self, set [invite/join/left] of other,
         # expect all 404s because room doesn't exist on any server
         for usr in [self.user_id, self.rmcreator_id]:
-            yield self.join(room=room, user=usr, expect_code=404)
-            yield self.leave(room=room, user=usr, expect_code=404)
+            yield self.join(room=room, user=usr, expect_code=403)
+            yield self.leave(room=room, user=usr, expect_code=403)
 
     @defer.inlineCallbacks
     def test_membership_private_room_perms(self):


@@ -53,7 +53,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             "test",
             db_pool=self.db_pool,
             config=config,
-            database_engine=create_engine(config),
+            database_engine=create_engine(config.database_config),
         )
 
         self.datastore = SQLBaseStore(hs)
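
This is the same signature change that recurs in tests/utils.py at the end of
this commit: create_engine() now receives the database config dict directly
rather than the whole config object. A minimal sketch of the assumed shape
(key names follow Synapse's database config section; values are invented for
illustration):

    config.database_config = {"name": "sqlite3", "args": {"database": ":memory:"}}
    database_engine = create_engine(config.database_config)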


@@ -34,33 +34,6 @@ class PresenceStoreTestCase(unittest.TestCase):
         self.u_apple = UserID.from_string("@apple:test")
         self.u_banana = UserID.from_string("@banana:test")
 
-    @defer.inlineCallbacks
-    def test_visibility(self):
-        self.assertFalse((yield self.store.is_presence_visible(
-            observed_localpart=self.u_apple.localpart,
-            observer_userid=self.u_banana.to_string(),
-        )))
-
-        yield self.store.allow_presence_visible(
-            observed_localpart=self.u_apple.localpart,
-            observer_userid=self.u_banana.to_string(),
-        )
-
-        self.assertTrue((yield self.store.is_presence_visible(
-            observed_localpart=self.u_apple.localpart,
-            observer_userid=self.u_banana.to_string(),
-        )))
-
-        yield self.store.disallow_presence_visible(
-            observed_localpart=self.u_apple.localpart,
-            observer_userid=self.u_banana.to_string(),
-        )
-
-        self.assertFalse((yield self.store.is_presence_visible(
-            observed_localpart=self.u_apple.localpart,
-            observer_userid=self.u_banana.to_string(),
-        )))
-
     @defer.inlineCallbacks
     def test_presence_list(self):
         self.assertEquals(


@@ -110,22 +110,10 @@ class RedactionTestCase(unittest.TestCase):
             self.room1, self.u_alice, Membership.JOIN
         )
 
-        start = yield self.store.get_room_events_max_id()
-
         msg_event = yield self.inject_message(self.room1, self.u_alice, u"t")
 
-        end = yield self.store.get_room_events_max_id()
-
-        results, _ = yield self.store.get_room_events_stream(
-            self.u_alice.to_string(),
-            start,
-            end,
-        )
-
-        self.assertEqual(1, len(results))
-
         # Check event has not been redacted:
-        event = results[0]
+        event = yield self.store.get_event(msg_event.event_id)
 
         self.assertObjectHasAttributes(
             {
@@ -144,17 +132,7 @@
             self.room1, msg_event.event_id, self.u_alice, reason
         )
 
-        results, _ = yield self.store.get_room_events_stream(
-            self.u_alice.to_string(),
-            start,
-            end,
-        )
-
-        self.assertEqual(1, len(results))
-
-        # Check redaction
-        event = results[0]
+        event = yield self.store.get_event(msg_event.event_id)
 
         self.assertEqual(msg_event.event_id, event.event_id)
@@ -184,25 +162,12 @@
             self.room1, self.u_alice, Membership.JOIN
         )
 
-        start = yield self.store.get_room_events_max_id()
-
         msg_event = yield self.inject_room_member(
             self.room1, self.u_bob, Membership.JOIN,
             extra_content={"blue": "red"},
         )
 
-        end = yield self.store.get_room_events_max_id()
-
-        results, _ = yield self.store.get_room_events_stream(
-            self.u_alice.to_string(),
-            start,
-            end,
-        )
-
-        self.assertEqual(1, len(results))
-
-        # Check event has not been redacted:
-        event = results[0]
+        event = yield self.store.get_event(msg_event.event_id)
 
         self.assertObjectHasAttributes(
             {
@@ -221,17 +186,9 @@
             self.room1, msg_event.event_id, self.u_alice, reason
         )
 
-        results, _ = yield self.store.get_room_events_stream(
-            self.u_alice.to_string(),
-            start,
-            end,
-        )
-
-        self.assertEqual(1, len(results))
-
         # Check redaction
-        event = results[0]
+        event = yield self.store.get_event(msg_event.event_id)
 
         self.assertTrue("redacted_because" in event.unsigned)


@@ -70,13 +70,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
     def test_one_member(self):
         yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
 
-        self.assertEquals(
-            Membership.JOIN,
-            (yield self.store.get_room_member(
-                user_id=self.u_alice.to_string(),
-                room_id=self.room.to_string(),
-            )).membership
-        )
         self.assertEquals(
             [self.u_alice.to_string()],
             [m.user_id for m in (


@@ -1,185 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tests import unittest
from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID

from tests.storage.event_injector import EventInjector
from tests.utils import setup_test_homeserver

from mock import Mock


class StreamStoreTestCase(unittest.TestCase):

    @defer.inlineCallbacks
    def setUp(self):
        hs = yield setup_test_homeserver(
            resource_for_federation=Mock(),
            http_client=None,
        )

        self.store = hs.get_datastore()
        self.event_builder_factory = hs.get_event_builder_factory()
        self.event_injector = EventInjector(hs)
        self.handlers = hs.get_handlers()
        self.message_handler = self.handlers.message_handler

        self.u_alice = UserID.from_string("@alice:test")
        self.u_bob = UserID.from_string("@bob:test")

        self.room1 = RoomID.from_string("!abc123:test")
        self.room2 = RoomID.from_string("!xyx987:test")

    @defer.inlineCallbacks
    def test_event_stream_get_other(self):
        # Both bob and alice joins the room
        yield self.event_injector.inject_room_member(
            self.room1, self.u_alice, Membership.JOIN
        )
        yield self.event_injector.inject_room_member(
            self.room1, self.u_bob, Membership.JOIN
        )

        # Initial stream key:
        start = yield self.store.get_room_events_max_id()

        yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")

        end = yield self.store.get_room_events_max_id()

        results, _ = yield self.store.get_room_events_stream(
            self.u_bob.to_string(),
            start,
            end,
        )

        self.assertEqual(1, len(results))

        event = results[0]

        self.assertObjectHasAttributes(
            {
                "type": EventTypes.Message,
                "user_id": self.u_alice.to_string(),
                "content": {"body": "test", "msgtype": "message"},
            },
            event,
        )

    @defer.inlineCallbacks
    def test_event_stream_get_own(self):
        # Both bob and alice joins the room
        yield self.event_injector.inject_room_member(
            self.room1, self.u_alice, Membership.JOIN
        )
        yield self.event_injector.inject_room_member(
            self.room1, self.u_bob, Membership.JOIN
        )

        # Initial stream key:
        start = yield self.store.get_room_events_max_id()

        yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")

        end = yield self.store.get_room_events_max_id()

        results, _ = yield self.store.get_room_events_stream(
            self.u_alice.to_string(),
            start,
            end,
        )

        self.assertEqual(1, len(results))

        event = results[0]

        self.assertObjectHasAttributes(
            {
                "type": EventTypes.Message,
                "user_id": self.u_alice.to_string(),
                "content": {"body": "test", "msgtype": "message"},
            },
            event,
        )

    @defer.inlineCallbacks
    def test_event_stream_join_leave(self):
        # Both bob and alice joins the room
        yield self.event_injector.inject_room_member(
            self.room1, self.u_alice, Membership.JOIN
        )
        yield self.event_injector.inject_room_member(
            self.room1, self.u_bob, Membership.JOIN
        )

        # Then bob leaves again.
        yield self.event_injector.inject_room_member(
            self.room1, self.u_bob, Membership.LEAVE
        )

        # Initial stream key:
        start = yield self.store.get_room_events_max_id()

        yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")

        end = yield self.store.get_room_events_max_id()

        results, _ = yield self.store.get_room_events_stream(
            self.u_bob.to_string(),
            start,
            end,
        )

        # We should not get the message, as it happened *after* bob left.
        self.assertEqual(0, len(results))

    @defer.inlineCallbacks
    def test_event_stream_prev_content(self):
        yield self.event_injector.inject_room_member(
            self.room1, self.u_bob, Membership.JOIN
        )

        yield self.event_injector.inject_room_member(
            self.room1, self.u_alice, Membership.JOIN
        )

        start = yield self.store.get_room_events_max_id()

        yield self.event_injector.inject_room_member(
            self.room1, self.u_alice, Membership.JOIN,
        )

        end = yield self.store.get_room_events_max_id()

        results, _ = yield self.store.get_room_events_stream(
            self.u_bob.to_string(),
            start,
            end,
        )

        # We should not get the message, as it happened *after* bob left.
        self.assertEqual(1, len(results))

        event = results[0]

        self.assertTrue(
            "prev_content" in event.unsigned,
            msg="No prev_content key"
        )


@@ -21,6 +21,8 @@ from mock import Mock
 
 from synapse.http.endpoint import resolve_service
 
+from tests.utils import MockClock
+
 
 class DnsTestCase(unittest.TestCase):
@@ -63,14 +65,17 @@ class DnsTestCase(unittest.TestCase):
         self.assertEquals(servers[0].host, ip_address)
 
     @defer.inlineCallbacks
-    def test_from_cache(self):
+    def test_from_cache_expired_and_dns_fail(self):
         dns_client_mock = Mock()
         dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
 
         service_name = "test_service.examle.com"
 
+        entry = Mock(spec_set=["expires"])
+        entry.expires = 0
+
         cache = {
-            service_name: [object()]
+            service_name: [entry]
         }
 
         servers = yield resolve_service(
@@ -82,6 +87,31 @@
         self.assertEquals(len(servers), 1)
         self.assertEquals(servers, cache[service_name])
 
+    @defer.inlineCallbacks
+    def test_from_cache(self):
+        clock = MockClock()
+
+        dns_client_mock = Mock(spec_set=['lookupService'])
+        dns_client_mock.lookupService = Mock(spec_set=[])
+
+        service_name = "test_service.examle.com"
+
+        entry = Mock(spec_set=["expires"])
+        entry.expires = 999999999
+
+        cache = {
+            service_name: [entry]
+        }
+
+        servers = yield resolve_service(
+            service_name, dns_client=dns_client_mock, cache=cache, clock=clock,
+        )
+
+        self.assertFalse(dns_client_mock.lookupService.called)
+
+        self.assertEquals(len(servers), 1)
+        self.assertEquals(servers, cache[service_name])
+
     @defer.inlineCallbacks
     def test_empty_cache(self):
         dns_client_mock = Mock()
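
Taken together, the two cache tests pin down the intended policy: a fresh
cache entry is served without touching DNS at all, while an expired entry
still acts as a fallback when the DNS query fails. A rough sketch of that
policy (a paraphrase for illustration, not the real resolve_service
implementation; the entry and clock attribute names follow the mocks above):

    def resolve_from_cache(service_name, cache, clock, lookup_fn):
        entries = cache.get(service_name, [])
        if entries and all(e.expires > clock.time() for e in entries):
            # Fresh cache hit: no DNS query is made (test_from_cache).
            return entries
        try:
            # Entries expired or missing: go to DNS.
            return lookup_fn(service_name)
        except Exception:
            if entries:
                # DNS failed: fall back to the stale cached entries
                # (test_from_cache_expired_and_dns_fail).
                return entries
            raise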


@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tests import unittest

from twisted.internet import defer

from synapse.util.async import Linearizer


class LinearizerTestCase(unittest.TestCase):

    @defer.inlineCallbacks
    def test_linearizer(self):
        linearizer = Linearizer()

        key = object()

        d1 = linearizer.queue(key)
        cm1 = yield d1

        d2 = linearizer.queue(key)
        self.assertFalse(d2.called)

        with cm1:
            self.assertFalse(d2.called)

        self.assertTrue(d2.called)

        with (yield d2):
            pass
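
The test above captures the Linearizer contract: queue(key) returns a deferred
that eventually yields a context manager, and the next waiter for the same key
is only released when the previous holder's context manager exits. A minimal
usage sketch (the handler function and its names are hypothetical; only
Linearizer.queue is taken from the test above):

    from twisted.internet import defer
    from synapse.util.async import Linearizer

    linearizer = Linearizer()

    @defer.inlineCallbacks
    def update_room(room_id):
        # Work for the same room_id runs strictly one caller at a time.
        with (yield linearizer.queue(room_id)):
            yield do_critical_work(room_id)  # hypothetical coroutine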


@@ -64,7 +64,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         hs = HomeServer(
             name, db_pool=db_pool, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config),
+            database_engine=create_engine(config.database_config),
             get_db_conn=db_pool.get_db_conn,
             **kargs
         )
@@ -73,7 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config),
+            database_engine=create_engine(config.database_config),
             **kargs
         )
@@ -298,7 +298,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
         return conn
 
     def create_engine(self):
-        return create_engine(self.config)
+        return create_engine(self.config.database_config)
 
 
 class MemoryDataStore(object):