Merge branch 'develop' into HEAD

This commit is contained in:
Erik Johnston 2016-04-11 10:40:49 +01:00
commit 58de897c77
84 changed files with 3118 additions and 1876 deletions

View file

@ -57,3 +57,6 @@ Florent Violleau <floviolleau at gmail dot com>
Niklas Riekenbrauck <nikriek at gmail dot.com>
* Add JWT support for registration and login
Christoph Witzany <christoph at web.crofting.com>
* Add LDAP support for authentication

View file

@ -118,7 +118,6 @@ Installing prerequisites on CentOS 7::
python-virtualenv libffi-devel openssl-devel
sudo yum groupinstall "Development Tools"
Installing prerequisites on Mac OS X::
xcode-select --install
@ -150,12 +149,7 @@ In case of problems, please see the _Troubleshooting section below.
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
Another alternative is to install via apt from http://matrix.org/packages/debian/.
Note that these packages do not include a client - choose one from
https://matrix.org/blog/try-matrix-now/ (or build your own with
https://github.com/matrix-org/matrix-js-sdk/).
Finally, Martin Giess has created an auto-deployment process with vagrant/ansible,
Also, Martin Giess has created an auto-deployment process with vagrant/ansible,
tested with VirtualBox/AWS/DigitalOcean - see https://github.com/EMnify/matrix-synapse-auto-deploy
for details.
@ -229,6 +223,19 @@ For information on how to install and use PostgreSQL, please see
Platform Specific Instructions
==============================
Debian
------
Matrix provides official Debian packages via apt from http://matrix.org/packages/debian/.
Note that these packages do not include a client - choose one from
https://matrix.org/blog/try-matrix-now/ (or build your own with one of our SDKs :)
Fedora
------
Oleg Girko provides Fedora RPMs at
https://obs.infoserver.lv/project/monitor/matrix-synapse
ArchLinux
---------
@ -270,11 +277,17 @@ During setup of Synapse you need to call python2.7 directly again::
FreeBSD
-------
Synapse can be installed via FreeBSD Ports or Packages:
Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
- Packages: ``pkg install py27-matrix-synapse``
NixOS
-----
Robin Lambertz has packaged Synapse for NixOS at:
https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix
Windows Install
---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages:

View file

@ -19,6 +19,7 @@ from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
import argparse
import curses
@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = {
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
"presence_stream": ["currently_active"],
}
@ -292,7 +294,7 @@ class Porter(object):
}
)
database_engine.prepare_database(db_conn)
prepare_database(db_conn, database_engine, config=None)
db_conn.commit()
@ -309,8 +311,8 @@ class Porter(object):
**self.postgres_config["args"]
)
sqlite_engine = create_engine(FakeConfig(sqlite_config))
postgres_engine = create_engine(FakeConfig(postgres_config))
sqlite_engine = create_engine(sqlite_config)
postgres_engine = create_engine(postgres_config)
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine)
@ -792,8 +794,3 @@ if __name__ == "__main__":
if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback)
class FakeConfig:
def __init__(self, database_config):
self.database_config = database_config

View file

@ -17,3 +17,6 @@ ignore =
[flake8]
max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
[pep8]
max-line-length = 90

View file

@ -33,7 +33,7 @@ from synapse.python_dependencies import (
from synapse.rest import ClientRestResource
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.server import HomeServer
@ -245,7 +245,7 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
def get_db_conn(self):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
@ -254,7 +254,8 @@ class SynapseHomeServer(HomeServer):
}
db_conn = self.database_engine.module.connect(**db_params)
self.database_engine.on_new_connection(db_conn)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
@ -386,7 +387,7 @@ def setup(config_options):
tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config)
database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer(
@ -402,8 +403,10 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name'])
try:
db_conn = hs.get_db_conn()
database_engine.prepare_database(db_conn)
db_conn = hs.get_db_conn(run_new_connection=False)
prepare_database(db_conn, database_engine, config=config)
database_engine.on_new_connection(db_conn)
hs.run_startup_checks(db_conn, database_engine)
db_conn.commit()

View file

@ -100,11 +100,6 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("push_bulk to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def push(self, service, event, txn_id=None):
response = yield self.push_bulk(service, [event], txn_id)
defer.returnValue(response)
def _serialize(self, events):
time_now = self.clock.time_msec()
return [

View file

@ -29,13 +29,15 @@ from .key import KeyConfig
from .saml2 import SAML2Config
from .cas import CasConfig
from .password import PasswordConfig
from .jwt import JWTConfig
from .ldap import LDAPConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
PasswordConfig,):
JWTConfig, LDAPConfig, PasswordConfig,):
pass

37
synapse/config/jwt.py Normal file
View file

@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Niklas Riekenbrauck
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class JWTConfig(Config):
def read_config(self, config):
jwt_config = config.get("jwt_config", None)
if jwt_config:
self.jwt_enabled = jwt_config.get("enabled", False)
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]
else:
self.jwt_enabled = False
self.jwt_secret = None
self.jwt_algorithm = None
def default_config(self, **kwargs):
return """\
# jwt_config:
# enabled: true
# secret: "a secret"
# algorithm: "HS256"
"""

52
synapse/config/ldap.py Normal file
View file

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Niklas Riekenbrauck
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class LDAPConfig(Config):
def read_config(self, config):
ldap_config = config.get("ldap_config", None)
if ldap_config:
self.ldap_enabled = ldap_config.get("enabled", False)
self.ldap_server = ldap_config["server"]
self.ldap_port = ldap_config["port"]
self.ldap_tls = ldap_config.get("tls", False)
self.ldap_search_base = ldap_config["search_base"]
self.ldap_search_property = ldap_config["search_property"]
self.ldap_email_property = ldap_config["email_property"]
self.ldap_full_name_property = ldap_config["full_name_property"]
else:
self.ldap_enabled = False
self.ldap_server = None
self.ldap_port = None
self.ldap_tls = False
self.ldap_search_base = None
self.ldap_search_property = None
self.ldap_email_property = None
self.ldap_full_name_property = None
def default_config(self, **kwargs):
return """\
# ldap_config:
# enabled: true
# server: "ldap://localhost"
# port: 389
# tls: false
# search_base: "ou=Users,dc=example,dc=com"
# search_property: "cn"
# email_property: "email"
# full_name_property: "givenName"
"""

View file

@ -31,7 +31,10 @@ class _EventInternalMetadata(object):
return dict(self.__dict__)
def is_outlier(self):
return hasattr(self, "outlier") and self.outlier
return getattr(self, "outlier", False)
def is_invite_from_remote(self):
return getattr(self, "invite_from_remote", False)
def _event_dict_property(key):

View file

@ -17,8 +17,9 @@ from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler
from .room import (
RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler,
RoomCreationHandler, RoomListHandler, RoomContextHandler,
)
from .room_member import RoomMemberHandler
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler

View file

@ -37,12 +37,22 @@ VISIBILITY_PRIORITY = (
)
MEMBERSHIP_PRIORITY = (
Membership.JOIN,
Membership.INVITE,
Membership.KNOCK,
Membership.LEAVE,
Membership.BAN,
)
class BaseHandler(object):
"""
Common base class for the event handlers.
:type store: synapse.storage.events.StateStore
:type state_handler: synapse.state.StateHandler
Attributes:
store (synapse.storage.events.StateStore):
state_handler (synapse.state.StateHandler):
"""
def __init__(self, hs):
@ -65,11 +75,13 @@ class BaseHandler(object):
""" Returns dict of user_id -> list of events that user is allowed to
see.
:param (str, bool) user_tuples: (user id, is_peeking) for each
user to be checked. is_peeking should be true if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
Args:
user_tuples (str, bool): (user id, is_peeking) for each user to be
checked. is_peeking should be true if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the
given events
events ([synapse.events.EventBase]): list of events to filter
"""
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
@ -84,6 +96,12 @@ class BaseHandler(object):
)
def allowed(event, user_id, is_peeking):
"""
Args:
event (synapse.events.EventBase): event to check
user_id (str)
is_peeking (bool)
"""
state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
@ -115,17 +133,30 @@ class BaseHandler(object):
if old_priority < new_priority:
visibility = prev_visibility
# get the user's membership at the time of the event. (or rather,
# just *after* the event. Which means that people can see their
# own join events, but not (currently) their own leave events.)
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id in event_id_forgotten:
membership = None
else:
membership = membership_event.membership
else:
membership = None
# likewise, if the event is the user's own membership event, use
# the 'most joined' membership
membership = None
if event.type == EventTypes.Member and event.state_key == user_id:
membership = event.content.get("membership", None)
if membership not in MEMBERSHIP_PRIORITY:
membership = "leave"
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership not in MEMBERSHIP_PRIORITY:
prev_membership = "leave"
new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
if old_priority < new_priority:
membership = prev_membership
# otherwise, get the user's membership at the time of the event.
if membership is None:
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id not in event_id_forgotten:
membership = membership_event.membership
# if the user was a member of the room at the time of the event,
# they can see it.
@ -165,13 +196,16 @@ class BaseHandler(object):
"""
Check which events a user is allowed to see
:param str user_id: user id to be checked
:param [synapse.events.EventBase] events: list of events to be checked
:param bool is_peeking should be True if:
Args:
user_id(str): user id to be checked
events([synapse.events.EventBase]): list of events to be checked
is_peeking(bool): should be True if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
:rtype [synapse.events.EventBase]
Returns:
[synapse.events.EventBase]
"""
types = (
(EventTypes.RoomHistoryVisibility, ""),
@ -199,20 +233,25 @@ class BaseHandler(object):
)
@defer.inlineCallbacks
def _create_new_client_event(self, builder):
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id,
)
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
depth = prev_max_depth + 1
else:
depth = 1
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id,
)
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
depth = 1
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
builder.prev_events = prev_events
builder.depth = depth
@ -221,50 +260,6 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder)
# If we've received an invite over federation, there are no latest
# events in the room, because we don't know enough about the graph
# fragment we received to treat it like a graph, so the above returned
# no relevant events. It may have returned some events (if we have
# joined and left the room), but not useful ones, like the invite.
if (
not self.is_host_in_room(context.current_state) and
builder.type == EventTypes.Member
):
prev_member_event = yield self.store.get_room_member(
builder.sender, builder.room_id
)
# The prev_member_event may already be in context.current_state,
# despite us not being present in the room; in particular, if
# inviting user, and all other local users, have already left.
#
# In that case, we have all the information we need, and we don't
# want to drop "context" - not least because we may need to handle
# the invite locally, which will require us to have the whole
# context (not just prev_member_event) to auth it.
#
context_event_ids = (
e.event_id for e in context.current_state.values()
)
if (
prev_member_event and
prev_member_event.event_id not in context_event_ids
):
# The prev_member_event is missing from context, so it must
# have arrived over federation and is an outlier. We forcibly
# set our context to the invite we received over federation
builder.prev_events = (
prev_member_event.event_id,
prev_member_event.prev_events
)
context = yield state_handler.compute_event_context(
builder,
old_state=(prev_member_event,),
outlier=True
)
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events

View file

@ -49,6 +49,21 @@ class AuthHandler(BaseHandler):
self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled
self.ldap_server = hs.config.ldap_server
self.ldap_port = hs.config.ldap_port
self.ldap_tls = hs.config.ldap_tls
self.ldap_search_base = hs.config.ldap_search_base
self.ldap_search_property = hs.config.ldap_search_property
self.ldap_email_property = hs.config.ldap_email_property
self.ldap_full_name_property = hs.config.ldap_full_name_property
if self.ldap_enabled is True:
import ldap
logger.info("Import ldap version: %s", ldap.__version__)
self.hs = hs # FIXME better possibility to access registrationHandler later?
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
"""
@ -163,9 +178,13 @@ class AuthHandler(BaseHandler):
def get_session_id(self, clientdict):
"""
Gets the session ID for a client given the client dictionary
:param clientdict: The dictionary sent by the client in the request
:return: The string session ID the client sent. If the client did not
send a session ID, returns None.
Args:
clientdict: The dictionary sent by the client in the request
Returns:
str|None: The string session ID the client sent. If the client did
not send a session ID, returns None.
"""
sid = None
if clientdict and 'auth' in clientdict:
@ -179,9 +198,11 @@ class AuthHandler(BaseHandler):
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param value: (any) The data to store
Args:
session_id (string): The ID of this session as returned from check_auth
key (string): The key to store the data under
value (any): The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault('serverdict', {})[key] = value
@ -190,9 +211,11 @@ class AuthHandler(BaseHandler):
def get_session_data(self, session_id, key, default=None):
"""
Retrieve data stored with set_session_data
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param default: (any) Value to return if the key has not been set
Args:
session_id (string): The ID of this session as returned from check_auth
key (string): The key to store the data under
default (any): Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
@ -207,8 +230,10 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks
@ -332,8 +357,10 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
@ -399,11 +426,60 @@ class AuthHandler(BaseHandler):
else:
defer.returnValue(user_infos.popitem())
def _check_password(self, user_id, password, stored_hash):
"""Checks that user_id has passed password, raises LoginError if not."""
if not self.validate_hash(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def _check_password(self, user_id, password):
defer.returnValue(
not (
(yield self._check_ldap_password(user_id, password))
or
(yield self._check_local_password(user_id, password))
))
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
try:
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(not self.validate_hash(password, password_hash))
except LoginError:
defer.returnValue(False)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
if self.ldap_enabled is not True:
logger.debug("LDAP not configured")
defer.returnValue(False)
import ldap
logger.info("Authenticating %s with LDAP" % user_id)
try:
ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
logger.debug("Connecting LDAP server at %s" % ldap_url)
l = ldap.initialize(ldap_url)
if self.ldap_tls:
logger.debug("Initiating TLS")
self._connection.start_tls_s()
local_name = UserID.from_string(user_id).localpart
dn = "%s=%s, %s" % (
self.ldap_search_property,
local_name,
self.ldap_search_base)
logger.debug("DN for LDAP authentication: %s" % dn)
l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
if not (yield self.does_user_exist(user_id)):
handler = self.hs.get_handlers().registration_handler
user_id, access_token = (
yield handler.register(localpart=local_name)
)
defer.returnValue(True)
except ldap.LDAPError, e:
logger.warn("LDAP error: %s", e)
defer.returnValue(False)
@defer.inlineCallbacks
def issue_access_token(self, user_id):

View file

@ -40,6 +40,7 @@ from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination
from synapse.push.action_generator import ActionGenerator
from synapse.util.distributor import user_joined_room
from twisted.internet import defer
@ -49,10 +50,6 @@ import logging
logger = logging.getLogger(__name__)
def user_joined_room(distributor, user, room_id):
return distributor.fire("user_joined_room", user, room_id)
class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
Responsible for:
@ -102,8 +99,7 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, state=None,
auth_chain=None):
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler.
"""
@ -174,11 +170,7 @@ class FederationHandler(BaseHandler):
})
seen_ids.add(e.event_id)
yield self._handle_new_events(
origin,
event_infos,
outliers=True
)
yield self._handle_new_events(origin, event_infos)
try:
context, event_stream_id, max_stream_id = yield self._handle_new_event(
@ -289,6 +281,9 @@ class FederationHandler(BaseHandler):
def backfill(self, dest, room_id, limit, extremities=[]):
""" Trigger a backfill request to `dest` for the given `room_id`
"""
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
@ -455,7 +450,7 @@ class FederationHandler(BaseHandler):
likely_domains = [
domain for domain, depth in curr_domains
if domain is not self.server_name
if domain != self.server_name
]
@defer.inlineCallbacks
@ -761,6 +756,7 @@ class FederationHandler(BaseHandler):
event = pdu
event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
event.signatures.update(
compute_event_signature(
@ -788,13 +784,19 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
user_id,
"leave"
)
signed_event = self._sign_event(event)
try:
origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
user_id,
"leave"
)
signed_event = self._sign_event(event)
except SynapseError:
raise
except CodeMessageException as e:
logger.warn("Failed to reject invite: %s", e)
raise SynapseError(500, "Failed to reject invite")
# Try the host we successfully got a response to /make_join/
# request first.
@ -804,10 +806,16 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
yield self.replication_layer.send_leave(
target_hosts,
signed_event
)
try:
yield self.replication_layer.send_leave(
target_hosts,
signed_event
)
except SynapseError:
raise
except CodeMessageException as e:
logger.warn("Failed to reject invite: %s", e)
raise SynapseError(500, "Failed to reject invite")
context = yield self.state_handler.compute_event_context(event)
@ -1069,9 +1077,6 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def _handle_new_event(self, origin, event, state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self._prep_event(
origin, event,
state=state,
@ -1087,14 +1092,12 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
is_new_state=not outlier,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False,
outliers=False):
def _handle_new_events(self, origin, event_infos, backfilled=False):
contexts = yield defer.gatherResults(
[
self._prep_event(
@ -1113,7 +1116,6 @@ class FederationHandler(BaseHandler):
for ev_info, context in itertools.izip(event_infos, contexts)
],
backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
)
@defer.inlineCallbacks
@ -1128,11 +1130,9 @@ class FederationHandler(BaseHandler):
"""
events_to_context = {}
for e in itertools.chain(auth_events, state):
ctx = yield self.state_handler.compute_event_context(
e, outlier=True,
)
events_to_context[e.event_id] = ctx
e.internal_metadata.outlier = True
ctx = yield self.state_handler.compute_event_context(e)
events_to_context[e.event_id] = ctx
event_map = {
e.event_id: e
@ -1176,16 +1176,14 @@ class FederationHandler(BaseHandler):
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
],
is_new_state=False,
)
new_event_context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=False,
event, old_state=state
)
event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context,
is_new_state=True,
current_state=state,
)
@ -1193,10 +1191,9 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=outlier,
event, old_state=state,
)
if not auth_events:
@ -1718,13 +1715,15 @@ class FederationHandler(BaseHandler):
def _check_signature(self, event, auth_events):
"""
Checks that the signature in the event is consistent with its invite.
:param event (Event): The m.room.member event to check
:param auth_events (dict<(event type, state_key), event>)
:raises
AuthError if signature didn't match any keys, or key has been
Args:
event (Event): The m.room.member event to check
auth_events (dict<(event type, state_key), event>):
Raises:
AuthError: if signature didn't match any keys, or key has been
revoked,
SynapseError if a transient error meant a key couldn't be checked
SynapseError: if a transient error meant a key couldn't be checked
for revocation.
"""
signed = event.content["third_party_invite"]["signed"]
@ -1766,12 +1765,13 @@ class FederationHandler(BaseHandler):
"""
Checks whether public_key has been revoked.
:param public_key (str): base-64 encoded public key.
:param url (str): Key revocation URL.
Args:
public_key (str): base-64 encoded public key.
url (str): Key revocation URL.
:raises
AuthError if they key has been revoked.
SynapseError if a transient error meant a key couldn't be checked
Raises:
AuthError: if they key has been revoked.
SynapseError: if a transient error meant a key couldn't be checked
for revocation.
"""
try:

View file

@ -21,6 +21,7 @@ from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.types import UserID, RoomStreamToken, StreamToken
@ -33,10 +34,6 @@ import logging
logger = logging.getLogger(__name__)
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
class MessageHandler(BaseHandler):
def __init__(self, hs):
@ -47,35 +44,6 @@ class MessageHandler(BaseHandler):
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
@defer.inlineCallbacks
def get_message(self, msg_id=None, room_id=None, sender_id=None,
user_id=None):
""" Retrieve a message.
Args:
msg_id (str): The message ID to obtain.
room_id (str): The room where the message resides.
sender_id (str): The user ID of the user who sent the message.
user_id (str): The user ID of the user making this request.
Returns:
The message, or None if no message exists.
Raises:
SynapseError if something went wrong.
"""
yield self.auth.check_joined_room(room_id, user_id)
# Pull out the message from the db
# msg = yield self.store.get_message(
# room_id=room_id,
# msg_id=msg_id,
# user_id=sender_id
# )
# TODO (erikj): Once we work out the correct c-s api we need to think
# on how to do this.
defer.returnValue(None)
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True):
@ -175,7 +143,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def create_event(self, event_dict, token_id=None, txn_id=None):
def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
"""
Given a dict from a client, create a new event.
@ -186,6 +154,9 @@ class MessageHandler(BaseHandler):
Args:
event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
@ -198,12 +169,8 @@ class MessageHandler(BaseHandler):
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership == Membership.JOIN:
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
yield collect_presencelike_data(
self.distributor, target, builder.content
)
elif membership == Membership.INVITE:
profile = self.hs.get_handlers().profile_handler
content = builder.content
@ -224,6 +191,7 @@ class MessageHandler(BaseHandler):
event, context = yield self._create_new_client_event(
builder=builder,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
@ -556,14 +524,7 @@ class MessageHandler(BaseHandler):
except:
logger.exception("Failed to get snapshot")
# Only do N rooms at once
n = 5
d_list = [handle_room(e) for e in room_list]
for i in range(0, len(d_list), n):
yield defer.gatherResults(
d_list[i:i + n],
consumeErrors=True
).addErrback(unwrapFirstError)
yield concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():

View file

@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID, Requester
from synapse.util import unwrapFirstError
from ._base import BaseHandler
@ -27,14 +26,6 @@ import logging
logger = logging.getLogger(__name__)
def changed_presencelike_data(distributor, user, state):
return distributor.fire("changed_presencelike_data", user, state)
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
class ProfileHandler(BaseHandler):
def __init__(self, hs):
@ -46,17 +37,9 @@ class ProfileHandler(BaseHandler):
)
distributor = hs.get_distributor()
self.distributor = distributor
distributor.declare("collect_presencelike_data")
distributor.declare("changed_presencelike_data")
distributor.observe("registered_user", self.registered_user)
distributor.observe(
"collect_presencelike_data", self.collect_presencelike_data
)
def registered_user(self, user):
return self.store.create_profile(user.localpart)
@ -105,10 +88,6 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname
)
yield changed_presencelike_data(self.distributor, target_user, {
"displayname": new_displayname,
})
yield self._update_join_states(requester)
@defer.inlineCallbacks
@ -152,30 +131,8 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url
)
yield changed_presencelike_data(self.distributor, target_user, {
"avatar_url": new_avatar_url,
})
yield self._update_join_states(requester)
@defer.inlineCallbacks
def collect_presencelike_data(self, user, state):
if not self.hs.is_mine(user):
defer.returnValue(None)
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
],
consumeErrors=True
).addErrback(unwrapFirstError)
state["displayname"] = displayname
state["avatar_url"] = avatar_url
defer.returnValue(None)
@defer.inlineCallbacks
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])

View file

@ -23,6 +23,7 @@ from synapse.api.errors import (
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
from synapse.util.distributor import registered_user
import logging
import urllib
@ -30,10 +31,6 @@ import urllib
logger = logging.getLogger(__name__)
def registered_user(distributor, user):
return distributor.fire("registered_user", user)
class RegistrationHandler(BaseHandler):
def __init__(self, hs):

View file

@ -18,19 +18,16 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
EventTypes, JoinRules, RoomCreationPreset,
)
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError
from synapse.util.logcontext import preserve_context_over_fn
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils
from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from collections import OrderedDict
from unpaddedbase64 import decode_base64
import logging
import math
@ -41,20 +38,6 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://"
def user_left_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class RoomCreationHandler(BaseHandler):
PRESETS_DICT = {
@ -356,594 +339,23 @@ class RoomCreationHandler(BaseHandler):
)
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
elif action == "forget":
effective_membership_state = "leave"
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
third_party_signed,
)
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": effective_membership_state}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
# For backwards compatibility:
"membership": effective_membership_state,
},
token_id=requester.access_token_id,
txn_id=txn_id,
)
old_state = context.current_state.get((EventTypes.Member, event.state_key))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(
requester,
event,
context,
ratelimit=ratelimit,
remote_room_hosts=remote_room_hosts,
)
if action == "forget":
yield self.forget(requester.user, room_id)
@defer.inlineCallbacks
def send_membership_event(
self,
requester,
event,
context,
remote_room_hosts=None,
ratelimit=True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
is_guest (bool): Whether the sender is a guest.
room_hosts ([str]): Homeservers which are likely to already be in
the room, and could be danced with in order to join this
homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
action = "send"
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
do_remote_join_dance, remote_room_hosts = self._should_do_dance(
context,
(self.get_inviter(event.state_key, context.current_state)),
remote_room_hosts,
)
if do_remote_join_dance:
action = "remote_join"
elif event.membership == Membership.LEAVE:
is_host_in_room = self.is_host_in_room(context.current_state)
if not is_host_in_room:
# perhaps we've been invited
inviter = self.get_inviter(target_user.to_string(), context.current_state)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
action = "remote_reject"
federation_handler = self.hs.get_handlers().federation_handler
if action == "remote_join":
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield federation_handler.do_invite_join(
remote_room_hosts,
event.room_id,
event.user_id,
event.content,
)
elif action == "remote_reject":
yield federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
event.user_id
)
else:
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
def _should_do_dance(self, context, inviter, room_hosts=None):
# TODO: Shouldn't this be remote_room_host?
room_hosts = room_hosts or []
is_host_in_room = self.is_host_in_room(context.current_state)
if is_host_in_room:
return False, room_hosts
if inviter and not self.hs.is_mine(inviter):
room_hosts.append(inviter.domain)
return True, room_hosts
@defer.inlineCallbacks
def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"]
servers = mapping["servers"]
defer.returnValue((RoomID.from_string(room_id), servers))
def get_inviter(self, user_id, current_state):
prev_state = current_state.get((EventTypes.Member, user_id))
if prev_state and prev_state.membership == Membership.INVITE:
return UserID.from_string(prev_state.user_id)
return None
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
requester,
id_server,
medium,
address,
room_id,
inviter,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
(str) the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id
):
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
member_event = room_state.get((EventTypes.Member, user.to_string()))
if member_event:
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:
canonical_room_alias = canonical_alias_event.content.get("alias", "")
room_name = ""
room_name_event = room_state.get((EventTypes.Name, ""))
if room_name_event:
room_name = room_name_event.content.get("name", "")
room_join_rules = ""
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
room_join_rules = join_rules_event.content.get("join_rule", "")
room_avatar_url = ""
room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server=id_server,
medium=medium,
address=address,
room_id=room_id,
inviter_user_id=user.to_string(),
room_alias=canonical_room_alias,
room_avatar_url=room_avatar_url,
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"public_keys": public_keys,
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url
):
"""
Asks an identity server for a third party invite.
:param id_server (str): hostname + optional port for the identity server.
:param medium (str): The literal string "email".
:param address (str): The third party address being invited.
:param room_id (str): The ID of the room to which the user is invited.
:param inviter_user_id (str): The user ID of the inviter.
:param room_alias (str): An alias for the room, for cosmetic
notifications.
:param room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
:param room_join_rules (str): The join rules of the email
(e.g. "public").
:param room_name (str): The m.room.name of the room.
:param inviter_display_name (str): The current display name of the
inviter.
:param inviter_avatar_url (str): The URL of the inviter's avatar.
:return: A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
if self.hs.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
})
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
invite_config
)
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
),
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name))
def forget(self, user, room_id):
return self.store.forget(user.to_string(), room_id)
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache()
def get_public_room_list(self):
result = self.response_cache.get(())
if not result:
result = self.response_cache.set((), self._get_public_room_list())
return result
@defer.inlineCallbacks
def get_public_room_list(self):
def _get_public_room_list(self):
room_ids = yield self.store.get_public_room_ids()
results = []
@defer.inlineCallbacks
def handle_room(room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
@ -1004,18 +416,12 @@ class RoomListHandler(BaseHandler):
joined_users = yield self.store.get_users_in_room(room_id)
result["num_joined_members"] = len(joined_users)
defer.returnValue(result)
results.append(result)
result = []
for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
chunk_result = yield defer.gatherResults([
handle_room(room_id)
for room_id in chunk
], consumeErrors=True).addErrback(unwrapFirstError)
result.extend(v for v in chunk_result if v)
yield concurrently_execute(handle_room, room_ids, 10)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": result})
defer.returnValue({"start": "START", "end": "END", "chunk": results})
class RoomContextHandler(BaseHandler):

View file

@ -0,0 +1,722 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomID, Requester
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
import logging
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.member_linearizer = Linearizer()
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
prev_event_ids,
txn_id=None,
ratelimit=True,
):
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": membership}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
# For backwards compatibility:
"membership": membership,
},
token_id=requester.access_token_id,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
)
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield self.hs.get_handlers().federation_handler.do_invite_join(
remote_room_hosts,
room_id,
user.to_string(),
content,
)
yield user_joined_room(self.distributor, user, room_id)
def reject_remote_invite(self, user_id, room_id, remote_room_hosts):
return self.hs.get_handlers().federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
user_id
)
@defer.inlineCallbacks
def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
key = (target, room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
requester,
target,
room_id,
action,
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
)
defer.returnValue(result)
@defer.inlineCallbacks
def _update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
third_party_signed,
)
if not remote_room_hosts:
remote_room_hosts = []
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state = yield self.state_handler.get_current_state(
room_id, latest_event_ids=latest_event_ids,
)
old_state = current_state.get((EventTypes.Member, target.to_string()))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
is_host_in_room = self.is_host_in_room(current_state)
if effective_membership_state == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
content = {"membership": Membership.JOIN}
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
ret = yield self.remote_join(
remote_room_hosts, room_id, target, content
)
defer.returnValue(ret)
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
inviter = yield self.get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
try:
ret = yield self.reject_remote_invite(
target.to_string(), room_id, remote_room_hosts
)
defer.returnValue(ret)
except SynapseError as e:
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
yield self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
membership=effective_membership_state,
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
)
@defer.inlineCallbacks
def send_membership_event(
self,
requester,
event,
context,
remote_room_hosts=None,
ratelimit=True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
is_guest (bool): Whether the sender is a guest.
room_hosts ([str]): Homeservers which are likely to already be in
the room, and could be danced with in order to join this
homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
@defer.inlineCallbacks
def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"]
servers = mapping["servers"]
defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks
def get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if invite:
defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
requester,
id_server,
medium,
address,
room_id,
inviter,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id
):
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
member_event = room_state.get((EventTypes.Member, user.to_string()))
if member_event:
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:
canonical_room_alias = canonical_alias_event.content.get("alias", "")
room_name = ""
room_name_event = room_state.get((EventTypes.Name, ""))
if room_name_event:
room_name = room_name_event.content.get("name", "")
room_join_rules = ""
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
room_join_rules = join_rules_event.content.get("join_rule", "")
room_avatar_url = ""
room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server=id_server,
medium=medium,
address=address,
room_id=room_id,
inviter_user_id=user.to_string(),
room_alias=canonical_room_alias,
room_avatar_url=room_avatar_url,
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"public_keys": public_keys,
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url
):
"""
Asks an identity server for a third party invite.
Args:
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
room_id (str): The ID of the room to which the user is invited.
inviter_user_id (str): The user ID of the inviter.
room_alias (str): An alias for the room, for cosmetic notifications.
room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
room_join_rules (str): The join rules of the email (e.g. "public").
room_name (str): The m.room.name of the room.
inviter_display_name (str): The current display name of the
inviter.
inviter_avatar_url (str): The URL of the inviter's avatar.
Returns:
A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
if self.hs.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
})
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
invite_config
)
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
),
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name))
@defer.inlineCallbacks
def forget(self, user, room_id):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership != Membership.LEAVE:
raise SynapseError(400, "User %s in room %s" % (
user_id, room_id
))
if membership:
yield self.store.forget(user_id, room_id)

View file

@ -17,9 +17,10 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.async import concurrently_execute
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
from synapse.util.caches.response_cache import ResponseCache
from synapse.push.clientformat import format_push_rules_for_user
from twisted.internet import defer
@ -35,6 +36,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [
"user",
"filter_collection",
"is_guest",
"request_key",
])
@ -136,8 +138,8 @@ class SyncHandler(BaseHandler):
super(SyncHandler, self).__init__(hs)
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache()
@defer.inlineCallbacks
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
"""Get the sync for a client if we have new data for it now. Otherwise
@ -146,7 +148,19 @@ class SyncHandler(BaseHandler):
Returns:
A Deferred SyncResult.
"""
result = self.response_cache.get(sync_config.request_key)
if not result:
result = self.response_cache.set(
sync_config.request_key,
self._wait_for_sync_for_user(
sync_config, since_token, timeout, full_state
)
)
return result
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
full_state):
context = LoggingContext.current_context()
if context:
if since_token is None:
@ -236,58 +250,50 @@ class SyncHandler(BaseHandler):
joined = []
invited = []
archived = []
deferreds = []
room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
for room_list_chunk in room_list_chunks:
for event in room_list_chunk:
if event.membership == Membership.JOIN:
room_sync_deferred = preserve_fn(
self.full_state_sync_for_joined_room
)(
room_id=event.room_id,
sync_config=sync_config,
now_token=now_token,
timeline_since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
room_sync_deferred.addCallback(joined.append)
deferreds.append(room_sync_deferred)
elif event.membership == Membership.INVITE:
invite = yield self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(
room_id=event.room_id,
invite=invite,
))
elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if sync_config.user.to_string() == event.sender:
continue
user_id = sync_config.user.to_string()
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
room_sync_deferred = preserve_fn(
self.full_state_sync_for_archived_room
)(
sync_config=sync_config,
room_id=event.room_id,
leave_event_id=event.event_id,
leave_token=leave_token,
timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
room_sync_deferred.addCallback(archived.append)
deferreds.append(room_sync_deferred)
@defer.inlineCallbacks
def _generate_room_entry(event):
if event.membership == Membership.JOIN:
room_result = yield self.full_state_sync_for_joined_room(
room_id=event.room_id,
sync_config=sync_config,
now_token=now_token,
timeline_since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
joined.append(room_result)
elif event.membership == Membership.INVITE:
invite = yield self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(
room_id=event.room_id,
invite=invite,
))
elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if user_id == event.sender:
return
yield defer.gatherResults(
deferreds, consumeErrors=True
).addErrback(unwrapFirstError)
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
room_result = yield self.full_state_sync_for_archived_room(
sync_config=sync_config,
room_id=event.room_id,
leave_event_id=event.event_id,
leave_token=leave_token,
timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
archived.append(room_result)
yield concurrently_execute(_generate_room_entry, room_list, 10)
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
@ -657,7 +663,8 @@ class SyncHandler(BaseHandler):
def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None, recents=None, newly_joined_room=False):
"""
:returns a Deferred TimelineBatch
Returns:
a Deferred TimelineBatch
"""
with Measure(self.clock, "load_filtered_recents"):
filtering_factor = 2
@ -824,8 +831,11 @@ class SyncHandler(BaseHandler):
"""
Get the room state after the given event
:param synapse.events.EventBase event: event of interest
:return: A Deferred map from ((type, state_key)->Event)
Args:
event(synapse.events.EventBase): event of interest
Returns:
A Deferred map from ((type, state_key)->Event)
"""
state = yield self.store.get_state_for_event(event.event_id)
if event.is_state():
@ -836,9 +846,13 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
""" Get the room state at a particular stream position
:param str room_id: room for which to get state
:param StreamToken stream_position: point at which to get state
:returns: A Deferred map from ((type, state_key)->Event)
Args:
room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state
Returns:
A Deferred map from ((type, state_key)->Event)
"""
last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=stream_position.room_key, limit=1,
@ -859,15 +873,18 @@ class SyncHandler(BaseHandler):
""" Works out the differnce in state between the start of the timeline
and the previous sync.
:param str room_id
:param TimelineBatch batch: The timeline batch for the room that will
be sent to the user.
:param sync_config
:param str since_token: Token of the end of the previous batch. May be None.
:param str now_token: Token of the end of the current batch.
:param bool full_state: Whether to force returning the full state.
Args:
room_id(str):
batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
the room that will be sent to the user.
sync_config(synapse.handlers.sync.SyncConfig):
since_token(str|None): Token of the end of the previous batch. May
be None.
now_token(str): Token of the end of the current batch.
full_state(bool): Whether to force returning the full state.
:returns A new event dictionary
Returns:
A deferred new event dictionary
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
@ -939,11 +956,13 @@ class SyncHandler(BaseHandler):
Check if the user has just joined the given room (so should
be given the full state)
:param sync_config:
:param dict[(str,str), synapse.events.FrozenEvent] state_delta: the
difference in state since the last sync
Args:
sync_config(synapse.handlers.sync.SyncConfig):
state_delta(dict[(str,str), synapse.events.FrozenEvent]): the
difference in state since the last sync
:returns A deferred Tuple (state_delta, limited)
Returns:
A deferred Tuple (state_delta, limited)
"""
join_event = state_delta.get((
EventTypes.Member, sync_config.user.to_string()), None)

View file

@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
import collections
import logging
import random
import time
logger = logging.getLogger(__name__)
@ -31,7 +32,7 @@ SERVER_CACHE = {}
_Server = collections.namedtuple(
"_Server", "priority weight host port"
"_Server", "priority weight host port expires"
)
@ -92,7 +93,8 @@ class SRVClientEndpoint(object):
host=domain,
port=default_port,
priority=0,
weight=0
weight=0,
expires=0,
)
else:
self.default_server = None
@ -153,7 +155,13 @@ class SRVClientEndpoint(object):
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
servers = []
try:
@ -173,27 +181,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
continue
payload = answer.payload
host = str(payload.target)
srv_ttl = answer.ttl
try:
answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError:
continue
ips = [
answer.payload.dottedQuad()
for answer in answers
if answer.type == dns.A and answer.payload
]
for answer in answers:
if answer.type == dns.A and answer.payload:
ip = answer.payload.dottedQuad()
host_ttl = min(srv_ttl, answer.ttl)
for ip in ips:
servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight)
))
servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + host_ttl,
))
servers.sort()
cache[service_name] = list(servers)

View file

@ -26,14 +26,19 @@ logger = logging.getLogger(__name__)
def parse_integer(request, name, default=None, required=False):
"""Parse an integer parameter from the request string
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:return: An int value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
Args:
request: the twisted HTTP request.
name (str): the name of the query parameter.
default (int|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
Returns:
int|None: An int value or the default.
Raises:
SynapseError: if the parameter is absent and required, or if the
parameter is present and not an integer.
"""
if name in request.args:
@ -53,14 +58,19 @@ def parse_integer(request, name, default=None, required=False):
def parse_boolean(request, name, default=None, required=False):
"""Parse a boolean parameter from the request query string
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:return: A bool value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
Args:
request: the twisted HTTP request.
name (str): the name of the query parameter.
default (bool|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
Returns:
bool|None: A bool value or the default.
Raises:
SynapseError: if the parameter is absent and required, or if the
parameter is present and not one of "true" or "false".
"""
@ -88,15 +98,20 @@ def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"):
"""Parse a string parameter from the request query string.
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:param allowed_values (list): List of allowed values for the string,
or None if any value is allowed, defaults to None
:return: A string value or the default.
:raises
Args:
request: the twisted HTTP request.
name (str): the name of the query parameter.
default (str|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values (list[str]): List of allowed values for the string,
or None if any value is allowed, defaults to None
Returns:
str|None: A string value or the default.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
@ -122,9 +137,13 @@ def parse_string(request, name, default=None, required=False,
def parse_json_value_from_request(request):
"""Parse a JSON value from the body of a twisted HTTP request.
:param request: the twisted HTTP request.
:returns: The JSON value.
:raises
Args:
request: the twisted HTTP request.
Returns:
The JSON value.
Raises:
SynapseError if the request body couldn't be decoded as JSON.
"""
try:
@ -143,8 +162,10 @@ def parse_json_value_from_request(request):
def parse_json_object_from_request(request):
"""Parse a JSON object from the body of a twisted HTTP request.
:param request: the twisted HTTP request.
:raises
Args:
request: the twisted HTTP request.
Raises:
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""

View file

@ -503,13 +503,14 @@ class Notifier(object):
def wait_for_replication(self, callback, timeout):
"""Wait for an event to happen.
:param callback:
Gets called whenever an event happens. If this returns a truthy
value then ``wait_for_replication`` returns, otherwise it waits
for another event.
:param int timeout:
How many milliseconds to wait for callback return a truthy value.
:returns:
Args:
callback: Gets called whenever an event happens. If this returns a
truthy value then ``wait_for_replication`` returns, otherwise
it waits for another event.
timeout: How many milliseconds to wait for callback return a truthy
value.
Returns:
A deferred that resolves with the value returned by the callback.
"""
listener = _NotificationListener(None)

View file

@ -19,9 +19,11 @@ import copy
def list_with_base_rules(rawrules):
"""Combine the list of rules set by the user with the default push rules
:param list rawrules: The rules the user has modified or set.
:returns: A new list with the rules set by the user combined with the
defaults.
Args:
rawrules(list): The rules the user has modified or set.
Returns:
A new list with the rules set by the user combined with the defaults.
"""
ruleslist = []
@ -160,7 +162,27 @@ BASE_APPEND_OVRRIDE_RULES = [
'actions': [
'dont_notify',
]
}
},
# Will we sometimes want to know about people joining and leaving?
# Perhaps: if so, this could be expanded upon. Seems the most usual case
# is that we don't though. We add this override rule so that even if
# the room rule is set to notify, we don't get notifications about
# join/leave/avatar/displayname events.
# See also: https://matrix.org/jira/browse/SYN-607
{
'rule_id': 'global/override/.m.rule.member_event',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
'_id': '_member',
}
],
'actions': [
'dont_notify'
]
},
]
@ -261,25 +283,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
}
]
},
# This is too simple: https://matrix.org/jira/browse/SYN-607
# Removing for now
# {
# 'rule_id': 'global/underride/.m.rule.member_event',
# 'conditions': [
# {
# 'kind': 'event_match',
# 'key': 'type',
# 'pattern': 'm.room.member',
# '_id': '_member',
# }
# ],
# 'actions': [
# 'notify', {
# 'set_tweak': 'highlight',
# 'value': False
# }
# ]
# },
{
'rule_id': 'global/underride/.m.rule.message',
'conditions': [

View file

@ -133,8 +133,9 @@ class PushRuleEvaluator:
enabled = self.enabled_map.get(r['rule_id'], None)
if enabled is not None and not enabled:
continue
if not r.get("enabled", True):
elif enabled is None and not r.get("enabled", True):
# if no override, check enabled on the rule itself
# (may have come from a base rule)
continue
conditions = r['conditions']

View file

@ -36,6 +36,7 @@ REQUIREMENTS = {
"sortedcontainers": ["sortedcontainers"],
"pysaml2==4.0.3": ["saml2==4.0.3"],
"pymacaroons-pynacl": ["pymacaroons"],
"pyjwt": ["jwt"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {

View file

@ -38,6 +38,7 @@ STREAM_NAMES = (
("backfill",),
("push_rules",),
("pushers",),
("state",),
)
@ -76,7 +77,7 @@ class ReplicationResource(Resource):
The response is a JSON object with keys for each stream with updates. Under
each key is a JSON object with:
* "postion": The current position of the stream.
* "position": The current position of the stream.
* "field_names": The names of the fields in each row.
* "rows": The updates as an array of arrays.
@ -123,6 +124,7 @@ class ReplicationResource(Resource):
backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@ -133,6 +135,7 @@ class ReplicationResource(Resource):
backfill_token,
push_rules_token,
pushers_token,
state_token,
))
@request_handler
@ -142,31 +145,43 @@ class ReplicationResource(Resource):
timeout = parse_integer(request, "timeout", 10 * 1000)
request.setHeader(b"Content-Type", b"application/json")
writer = _Writer(request)
@defer.inlineCallbacks
request_streams = {
name: parse_integer(request, name)
for names in STREAM_NAMES for name in names
}
request_streams["streams"] = parse_string(request, "streams")
def replicate():
current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token)
return self.replicate(request_streams, limit)
yield self.account_data(writer, current_token, limit)
yield self.events(writer, current_token, limit)
yield self.presence(writer, current_token) # TODO: implement limit
yield self.typing(writer, current_token) # TODO: implement limit
yield self.receipts(writer, current_token, limit)
yield self.push_rules(writer, current_token, limit)
yield self.pushers(writer, current_token, limit)
self.streams(writer, current_token)
result = yield self.notifier.wait_for_replication(replicate, timeout)
logger.info("Replicated %d rows", writer.total)
defer.returnValue(writer.total)
request.write(json.dumps(result, ensure_ascii=False))
finish_request(request)
yield self.notifier.wait_for_replication(replicate, timeout)
@defer.inlineCallbacks
def replicate(self, request_streams, limit):
writer = _Writer()
current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token)
writer.finish()
yield self.account_data(writer, current_token, limit, request_streams)
yield self.events(writer, current_token, limit, request_streams)
# TODO: implement limit
yield self.presence(writer, current_token, request_streams)
yield self.typing(writer, current_token, request_streams)
yield self.receipts(writer, current_token, limit, request_streams)
yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
def streams(self, writer, current_token):
request_token = parse_string(writer.request, "streams")
logger.info("Replicated %d rows", writer.total)
defer.returnValue(writer.finish())
def streams(self, writer, current_token, request_streams):
request_token = request_streams.get("streams")
streams = []
@ -191,32 +206,43 @@ class ReplicationResource(Resource):
)
@defer.inlineCallbacks
def events(self, writer, current_token, limit):
request_events = parse_integer(writer.request, "events")
request_backfill = parse_integer(writer.request, "backfill")
def events(self, writer, current_token, limit, request_streams):
request_events = request_streams.get("events")
request_backfill = request_streams.get("backfill")
if request_events is not None or request_backfill is not None:
if request_events is None:
request_events = current_token.events
if request_backfill is None:
request_backfill = current_token.backfill
events_rows, backfill_rows = yield self.store.get_all_new_events(
res = yield self.store.get_all_new_events(
request_backfill, request_events,
current_token.backfill, current_token.events,
limit
)
writer.write_header_and_rows("events", res.new_forward_events, (
"position", "internal", "json", "state_group"
))
writer.write_header_and_rows("backfill", res.new_backfill_events, (
"position", "internal", "json", "state_group"
))
writer.write_header_and_rows(
"events", events_rows, ("position", "internal", "json")
"forward_ex_outliers", res.forward_ex_outliers,
("position", "event_id", "state_group")
)
writer.write_header_and_rows(
"backfill", backfill_rows, ("position", "internal", "json")
"backward_ex_outliers", res.backward_ex_outliers,
("position", "event_id", "state_group")
)
writer.write_header_and_rows(
"state_resets", res.state_resets, ("position",)
)
@defer.inlineCallbacks
def presence(self, writer, current_token):
def presence(self, writer, current_token, request_streams):
current_position = current_token.presence
request_presence = parse_integer(writer.request, "presence")
request_presence = request_streams.get("presence")
if request_presence is not None:
presence_rows = yield self.presence_handler.get_all_presence_updates(
@ -229,10 +255,10 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def typing(self, writer, current_token):
def typing(self, writer, current_token, request_streams):
current_position = current_token.presence
request_typing = parse_integer(writer.request, "typing")
request_typing = request_streams.get("typing")
if request_typing is not None:
typing_rows = yield self.typing_handler.get_all_typing_updates(
@ -243,10 +269,10 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def receipts(self, writer, current_token, limit):
def receipts(self, writer, current_token, limit, request_streams):
current_position = current_token.receipts
request_receipts = parse_integer(writer.request, "receipts")
request_receipts = request_streams.get("receipts")
if request_receipts is not None:
receipts_rows = yield self.store.get_all_updated_receipts(
@ -257,12 +283,12 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def account_data(self, writer, current_token, limit):
def account_data(self, writer, current_token, limit, request_streams):
current_position = current_token.account_data
user_account_data = parse_integer(writer.request, "user_account_data")
room_account_data = parse_integer(writer.request, "room_account_data")
tag_account_data = parse_integer(writer.request, "tag_account_data")
user_account_data = request_streams.get("user_account_data")
room_account_data = request_streams.get("room_account_data")
tag_account_data = request_streams.get("tag_account_data")
if user_account_data is not None or room_account_data is not None:
if user_account_data is None:
@ -288,10 +314,10 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def push_rules(self, writer, current_token, limit):
def push_rules(self, writer, current_token, limit, request_streams):
current_position = current_token.push_rules
push_rules = parse_integer(writer.request, "push_rules")
push_rules = request_streams.get("push_rules")
if push_rules is not None:
rows = yield self.store.get_all_push_rule_updates(
@ -303,10 +329,11 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def pushers(self, writer, current_token, limit):
def pushers(self, writer, current_token, limit, request_streams):
current_position = current_token.pushers
pushers = parse_integer(writer.request, "pushers")
pushers = request_streams.get("pushers")
if pushers is not None:
updated, deleted = yield self.store.get_all_updated_pushers(
pushers, current_position, limit
@ -320,12 +347,30 @@ class ReplicationResource(Resource):
"position", "user_id", "app_id", "pushkey"
))
@defer.inlineCallbacks
def state(self, writer, current_token, limit, request_streams):
current_position = current_token.state
state = request_streams.get("state")
if state is not None:
state_groups, state_group_state = (
yield self.store.get_all_new_state_groups(
state, current_position, limit
)
)
writer.write_header_and_rows("state_groups", state_groups, (
"position", "room_id", "event_id"
))
writer.write_header_and_rows("state_group_state", state_group_state, (
"position", "type", "state_key", "event_id"
))
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
def __init__(self, request):
def __init__(self):
self.streams = {}
self.request = request
self.total = 0
def write_header_and_rows(self, name, rows, fields, position=None):
@ -344,13 +389,12 @@ class _Writer(object):
self.total += len(rows)
def finish(self):
self.request.write(json.dumps(self.streams, ensure_ascii=False))
finish_request(self.request)
return self.streams
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers"
"push_rules", "pushers", "state"
))):
__slots__ = []

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage._base import SQLBaseStore
from twisted.internet import defer
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs)
def stream_positions(self):
return {}
def process_replication(self, result):
return defer.succeed(None)

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker(object):
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self.advance(_load_current_id(db_conn, table, column))
def advance(self, new_id):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):
return self._current

View file

@ -0,0 +1,204 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.storage import DataStore
from synapse.storage.room import RoomStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.state import StateStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
#
# Rather than write duplicate versions of those functions, or lift them to
# a common base class, we going to grab the underlying __func__ object from
# the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
_get_current_state_for_key = StateStore.__dict__[
"_get_current_state_for_key"
]
get_event = DataStore.get_event.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
get_membership_changes_for_user = (
DataStore.get_membership_changes_for_user.__func__
)
get_room_events_max_id = DataStore.get_room_events_max_id.__func__
get_room_events_stream_for_room = (
DataStore.get_room_events_stream_for_room.__func__
)
_set_before_and_after = DataStore._set_before_and_after
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
_parse_events_txn = DataStore._parse_events_txn.__func__
_get_events_txn = DataStore._get_events_txn.__func__
_enqueue_events = DataStore._enqueue_events.__func__
_do_fetch = DataStore._do_fetch.__func__
_fetch_events_txn = DataStore._fetch_events_txn.__func__
_fetch_event_rows = DataStore._fetch_event_rows.__func__
_get_event_from_row = DataStore._get_event_from_row.__func__
_get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__
_get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfill"] = self._backfill_id_gen.get_current_token()
return result
def process_replication(self, result):
state_resets = set(
r[0] for r in result.get("state_resets", {"rows": []})["rows"]
)
stream = result.get("events")
if stream:
self._stream_id_gen.advance(stream["position"])
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False, state_resets=state_resets
)
stream = result.get("backfill")
if stream:
self._backfill_id_gen.advance(stream["position"])
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=True, state_resets=state_resets
)
stream = result.get("forward_ex_outliers")
if stream:
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
stream = result.get("backward_ex_outliers")
if stream:
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
return super(SlavedEventStore, self).process_replication(result)
def _process_replication_row(self, row, backfilled, state_resets):
position = row[0]
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
self._invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets
)
def _invalidate_caches_for_event(self, event, backfilled, reset_state):
if reset_state:
self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_room_name_and_aliases.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id)
self.get_latest_event_ids_in_room.invalidate((event.room_id,))
if not backfilled:
self._events_stream_cache.entity_has_changed(
event.room_id, event.internal_metadata.stream_ordering
)
# self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
# (event.room_id,)
# )
if event.type == EventTypes.Redaction:
self._invalidate_get_event_cache(event.redacts)
if event.type == EventTypes.Member:
self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,))
# self._membership_stream_cache.entity_has_changed(
# event.state_key, event.internal_metadata.stream_ordering
# )
if not event.is_state():
return
if backfilled:
return
if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()):
return
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))
if event.type in [EventTypes.Name, EventTypes.Aliases]:
self.get_room_name_and_aliases.invalidate(
(event.room_id,)
)
pass

View file

@ -33,6 +33,9 @@ from saml2.client import Saml2Client
import xml.etree.ElementTree as ET
import jwt
from jwt.exceptions import InvalidTokenError
logger = logging.getLogger(__name__)
@ -43,12 +46,16 @@ class LoginRestServlet(ClientV1RestServlet):
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "m.login.jwt"
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes
@ -57,6 +64,8 @@ class LoginRestServlet(ClientV1RestServlet):
def on_GET(self, request):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled:
@ -98,6 +107,10 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
}
defer.returnValue((200, result))
elif self.jwt_enabled and (login_submission["type"] ==
LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission)
defer.returnValue(result)
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
@ -209,6 +222,46 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
token = login_submission['token']
if token is None:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
try:
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
except InvalidTokenError:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user = payload['user']
if user is None:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)

View file

@ -405,6 +405,42 @@ class RoomEventContext(ClientV1RestServlet):
defer.returnValue((200, results))
class RoomForgetRestServlet(ClientV1RestServlet):
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, txn_id=None):
requester = yield self.auth.get_user_by_req(
request,
allow_guest=False,
)
yield self.handlers.room_member_handler.forget(
user=requester.user,
room_id=room_id,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(
request, room_id, txn_id
)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
@ -624,6 +660,7 @@ def register_servlets(hs, http_server):
RoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
RoomForgetRestServlet(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)
RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)

View file

@ -115,6 +115,8 @@ class SyncRestServlet(RestServlet):
)
)
request_key = (user, timeout, since, filter_id, full_state)
if filter_id:
if filter_id.startswith('{'):
try:
@ -134,6 +136,7 @@ class SyncRestServlet(RestServlet):
user=user,
filter_collection=filter,
is_guest=requester.is_guest,
request_key=request_key,
)
if since is not None:
@ -196,15 +199,17 @@ class SyncRestServlet(RestServlet):
"""
Encode the joined rooms in a sync result
:param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
results for rooms this user is joined to
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
Args:
rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
results for rooms this user is joined to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
:return: the joined rooms list, in our response format
:rtype: dict[str, dict[str, object]]
Returns:
dict[str, dict[str, object]]: the joined rooms list, in our
response format
"""
joined = {}
for room in rooms:
@ -218,15 +223,17 @@ class SyncRestServlet(RestServlet):
"""
Encode the invited rooms in a sync result
:param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
sync results for rooms this user is joined to
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
Args:
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
sync results for rooms this user is joined to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
:return: the invited rooms list, in our response format
:rtype: dict[str, dict[str, object]]
Returns:
dict[str, dict[str, object]]: the invited rooms list, in our
response format
"""
invited = {}
for room in rooms:
@ -248,15 +255,17 @@ class SyncRestServlet(RestServlet):
"""
Encode the archived rooms in a sync result
:param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
sync results for rooms this user is joined to
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
Args:
rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
sync results for rooms this user is joined to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
:return: the invited rooms list, in our response format
:rtype: dict[str, dict[str, object]]
Returns:
dict[str, dict[str, object]]: The invited rooms list, in our
response format
"""
joined = {}
for room in rooms:
@ -269,17 +278,18 @@ class SyncRestServlet(RestServlet):
@staticmethod
def encode_room(room, time_now, token_id, joined=True):
"""
:param JoinedSyncResult|ArchivedSyncResult room: sync result for a
single room
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:param joined: True if the user is joined to this room - will mean
we handle ephemeral events
Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
single room
time_now (int): current time - used as a baseline for age
calculations
token_id (int): ID of the user's auth token - used for namespacing
of transaction IDs
joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events
:return: the room, encoded in our response format
:rtype: dict[str, object]
Returns:
dict[str, object]: the room, encoded in our response format
"""
def serialize(event):
# TODO(mjark): Respect formatting requirements in the filter.

View file

@ -75,7 +75,8 @@ class StateHandler(object):
self._state_cache.start()
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
def get_current_state(self, room_id, event_type=None, state_key="",
latest_event_ids=None):
""" Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
@ -86,11 +87,13 @@ class StateHandler(object):
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
:returns map from (type, state_key) to event
Returns:
map from (type, state_key) to event
"""
event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
res = yield self.resolve_state_groups(room_id, event_ids)
res = yield self.resolve_state_groups(room_id, latest_event_ids)
state = res[1]
if event_type:
@ -100,7 +103,7 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None, outlier=False):
def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event`
@ -115,7 +118,7 @@ class StateHandler(object):
"""
context = EventContext()
if outlier:
if event.internal_metadata.is_outlier():
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
@ -176,10 +179,11 @@ class StateHandler(object):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
:returns a Deferred tuple of (`state_group`, `state`, `prev_state`).
`state_group` is the name of a state group if one and only one is
involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids.
Returns:
a Deferred tuple of (`state_group`, `state`, `prev_state`).
`state_group` is the name of a state group if one and only one is
involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids.
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@ -251,9 +255,10 @@ class StateHandler(object):
def _resolve_events(self, state_sets, event_type=None, state_key=""):
"""
:returns a tuple (new_state, prev_states). new_state is a map
from (type, state_key) to event. prev_states is a list of event_ids.
:rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str])
Returns
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
(new_state, prev_states). new_state is a map from (type, state_key)
to event. prev_states is a list of event_ids.
"""
with Measure(self.clock, "state._resolve_events"):
state = {}

View file

@ -88,22 +88,17 @@ class DataStore(RoomMemberStore, RoomStore,
self.hs = hs
self.database_engine = hs.database_engine
cur = db_conn.cursor()
try:
cur.execute("SELECT MIN(stream_ordering) FROM events",)
rows = cur.fetchall()
self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
self.min_stream_token = min(self.min_stream_token, -1)
finally:
cur.close()
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
)
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering"
db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")]
)
self._backfill_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", step=-1
)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
@ -116,7 +111,7 @@ class DataStore(RoomMemberStore, RoomStore,
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
@ -129,7 +124,7 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")],
)
events_max = self._stream_id_gen.get_max_token()
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
@ -145,7 +140,7 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max,
)
account_max = self._account_data_id_gen.get_max_token()
account_max = self._account_data_id_gen.get_current_token()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
@ -156,7 +151,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "presence_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._presence_id_gen.get_max_token(),
max_value=self._presence_id_gen.get_current_token(),
)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", min_presence_val,
@ -167,7 +162,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._push_rules_stream_id_gen.get_max_token()[0],
max_value=self._push_rules_stream_id_gen.get_current_token()[0],
)
self.push_rules_stream_cache = StreamChangeCache(
@ -182,39 +177,6 @@ class DataStore(RoomMemberStore, RoomStore,
self.__presence_on_startup = None
return active_on_startup
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000"
" GROUP BY %(entity)s"
) % {
"table": table,
"entity": entity_column,
"stream": stream_column,
}
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = {
row[0]: int(row[1])
for row in rows
}
if cache:
min_val = min(cache.values())
else:
min_val = max_value
return cache, min_val
def _get_active_presence(self, db_conn):
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.

View file

@ -810,11 +810,39 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values())
def get_next_stream_id(self):
with self._next_stream_id_lock:
i = self._next_stream_id
self._next_stream_id += 1
return i
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000"
" GROUP BY %(entity)s"
) % {
"table": table,
"entity": entity_column,
"stream": stream_column,
}
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = {
row[0]: int(row[1])
for row in rows
}
if cache:
min_val = min(cache.values())
else:
min_val = max_value
return cache, min_val
class _RollbackButIsFineException(Exception):

View file

@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore):
"add_room_account_data", add_account_data_txn, next_id
)
result = self._account_data_id_gen.get_max_token()
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
@defer.inlineCallbacks
@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore):
"add_user_account_data", add_account_data_txn, next_id
)
result = self._account_data_id_gen.get_max_token()
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id):

View file

@ -26,13 +26,13 @@ SUPPORTED_MODULE = {
}
def create_engine(config):
name = config.database_config["name"]
def create_engine(database_config):
name = database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
module = importlib.import_module(name)
return engine_class(module, config=config)
return engine_class(module)
raise RuntimeError(
"Unsupported database engine '%s'" % (name,)

View file

@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import prepare_database
from ._base import IncorrectDatabaseSetup
class PostgresEngine(object):
single_threaded = False
def __init__(self, database_module, config):
def __init__(self, database_module):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
self.config = config
def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING")
@ -44,9 +41,6 @@ class PostgresEngine(object):
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
def prepare_database(self, db_conn):
prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
return error.pgcode in ["40001", "40P01"]

View file

@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import (
prepare_database, prepare_sqlite3_database
)
from synapse.storage.prepare_database import prepare_database
import struct
@ -23,9 +21,8 @@ import struct
class Sqlite3Engine(object):
single_threaded = True
def __init__(self, database_module, config):
def __init__(self, database_module):
self.module = database_module
self.config = config
def check_database(self, txn):
pass
@ -34,13 +31,9 @@ class Sqlite3Engine(object):
return sql
def on_new_connection(self, db_conn):
self.prepare_database(db_conn)
prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank)
def prepare_database(self, db_conn):
prepare_sqlite3_database(db_conn)
prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error):
return False

View file

@ -163,6 +163,22 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
@defer.inlineCallbacks
def get_max_depth_of_events(self, event_ids):
sql = (
"SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
) % (",".join(["?"] * len(event_ids)),)
rows = yield self._execute(
"get_max_depth_of_events", None,
sql, *event_ids
)
if rows:
defer.returnValue(rows[0][0])
else:
defer.returnValue(1)
def _get_min_depth_interaction(self, txn, room_id):
min_depth = self._simple_select_one_onecol_txn(
txn,

View file

@ -26,8 +26,9 @@ logger = logging.getLogger(__name__)
class EventPushActionsStore(SQLBaseStore):
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
"""
:param event: the event set actions for
:param tuples: list of tuples of (user_id, actions)
Args:
event: the event set actions for
tuples: list of tuples of (user_id, actions)
"""
values = []
for uid, actions in tuples:

View file

@ -24,7 +24,7 @@ from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from canonicaljson import encode_canonical_json
from contextlib import contextmanager
from collections import namedtuple
import logging
import math
@ -60,64 +60,71 @@ class EventsStore(SQLBaseStore):
)
@defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False,
is_new_state=True):
def persist_events(self, events_and_contexts, backfilled=False):
if not events_and_contexts:
return
if backfilled:
start = self.min_stream_token - 1
self.min_stream_token -= len(events_and_contexts) + 1
stream_orderings = range(start, self.min_stream_token, -1)
@contextmanager
def stream_ordering_manager():
yield stream_orderings
stream_ordering_manager = stream_ordering_manager()
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
state_group_id_manager = self._state_groups_id_gen.get_next_mult(
len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings:
for (event, _), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
with state_group_id_manager as state_group_ids:
for (event, context), stream, state_group_id in zip(
events_and_contexts, stream_orderings, state_group_ids
):
event.internal_metadata.stream_ordering = stream
# Assign a state group_id in case a new id is needed for
# this context. In theory we only need to assign this
# for contexts that have current_state and aren't outliers
# but that make the code more complicated. Assigning an ID
# per event only causes the state_group_ids to grow as fast
# as the stream_ordering so in practise shouldn't be a problem.
context.new_state_group_id = state_group_id
chunks = [
events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100)
]
chunks = [
events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100)
]
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
is_new_state=is_new_state,
)
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
)
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context,
is_new_state=True, current_state=None):
def persist_event(self, event, context, current_state=None):
try:
with self._stream_id_gen.get_next() as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
is_new_state=is_new_state,
current_state=current_state,
)
with self._state_groups_id_gen.get_next() as state_group_id:
event.internal_metadata.stream_ordering = stream_ordering
context.new_state_group_id = state_group_id
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
current_state=current_state,
)
except _RollbackButIsFineException:
pass
max_persisted_id = yield self._stream_id_gen.get_max_token()
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks
@ -177,8 +184,7 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@log_function
def _persist_event_txn(self, txn, event, context,
is_new_state=True, current_state=None):
def _persist_event_txn(self, txn, event, context, current_state):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
@ -186,7 +192,16 @@ class EventsStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id)
txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
stream_order = event.internal_metadata.stream_ordering
self._simple_insert_txn(
txn,
table="current_state_resets",
values={"event_stream_ordering": stream_order}
)
self._simple_delete_txn(
txn,
@ -210,12 +225,10 @@ class EventsStore(SQLBaseStore):
txn,
[(event, context)],
backfilled=False,
is_new_state=is_new_state,
)
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
is_new_state=True):
def _persist_events_txn(self, txn, events_and_contexts, backfilled):
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
@ -282,9 +295,7 @@ class EventsStore(SQLBaseStore):
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
self._store_state_groups_txn(
txn, event, context,
)
self._store_mult_state_groups_txn(txn, ((event, context),))
metadata_json = encode_json(
event.internal_metadata.get_dict()
@ -299,6 +310,18 @@ class EventsStore(SQLBaseStore):
(metadata_json, event.event_id,)
)
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id
self._simple_insert_txn(
txn,
table="ex_outlier_stream",
values={
"event_stream_ordering": stream_order,
"event_id": event.event_id,
"state_group": state_group_id,
}
)
sql = (
"UPDATE events SET outlier = ?"
" WHERE event_id = ?"
@ -310,19 +333,14 @@ class EventsStore(SQLBaseStore):
self._update_extremeties(txn, [event])
events_and_contexts = filter(
lambda ec: ec[0] not in to_remove,
events_and_contexts
)
events_and_contexts = [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
if not events_and_contexts:
return
self._store_mult_state_groups_txn(txn, [
(event, context)
for event, context in events_and_contexts
if not event.internal_metadata.is_outlier()
])
self._store_mult_state_groups_txn(txn, events_and_contexts)
self._handle_mult_prev_events(
txn,
@ -349,7 +367,8 @@ class EventsStore(SQLBaseStore):
event
for event, _ in events_and_contexts
if event.type == EventTypes.Member
]
],
backfilled=backfilled,
)
def event_dict(event):
@ -421,10 +440,9 @@ class EventsStore(SQLBaseStore):
txn, [event for event, _ in events_and_contexts]
)
state_events_and_contexts = filter(
lambda i: i[0].is_state(),
events_and_contexts,
)
state_events_and_contexts = [
ec for ec in events_and_contexts if ec[0].is_state()
]
state_values = []
for event, context in state_events_and_contexts:
@ -462,32 +480,44 @@ class EventsStore(SQLBaseStore):
],
)
if is_new_state:
for event, _ in state_events_and_contexts:
if not context.rejected:
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
if backfilled:
# Backfilled events come before the current state so we don't need
# to update the current state table
return
if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after(
self.get_room_name_and_aliases.invalidate,
(event.room_id,)
)
for event, _ in state_events_and_contexts:
if event.internal_metadata.is_outlier():
# Outlier events shouldn't clobber the current state.
continue
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
if context.rejected:
# If the event failed it's auth checks then it shouldn't
# clobbler the current state.
continue
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after(
self.get_room_name_and_aliases.invalidate,
(event.room_id,)
)
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
return
@ -1076,10 +1106,7 @@ class EventsStore(SQLBaseStore):
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
# TODO: Fix race with the persit_event txn by using one of the
# stream id managers
return -self.min_stream_token
return -self._backfill_id_gen.get_current_token()
def get_all_new_events(self, last_backfill_id, last_forward_id,
current_backfill_id, current_forward_id, limit):
@ -1087,10 +1114,12 @@ class EventsStore(SQLBaseStore):
new events or as backfilled events"""
def get_all_new_events_txn(txn):
sql = (
"SELECT e.stream_ordering, ej.internal_metadata, ej.json"
"SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
" FROM events as e"
" JOIN event_json as ej"
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
" LEFT JOIN event_to_state_groups as eg"
" ON e.event_id = eg.event_id"
" WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
@ -1098,14 +1127,43 @@ class EventsStore(SQLBaseStore):
if last_forward_id != current_forward_id:
txn.execute(sql, (last_forward_id, current_forward_id, limit))
new_forward_events = txn.fetchall()
if len(new_forward_events) == limit:
upper_bound = new_forward_events[-1][0]
else:
upper_bound = current_forward_id
sql = (
"SELECT event_stream_ordering FROM current_state_resets"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering ASC"
)
txn.execute(sql, (last_forward_id, upper_bound))
state_resets = txn.fetchall()
sql = (
"SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_forward_id, upper_bound))
forward_ex_outliers = txn.fetchall()
else:
new_forward_events = []
state_resets = []
forward_ex_outliers = []
sql = (
"SELECT -e.stream_ordering, ej.internal_metadata, ej.json"
"SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
" eg.state_group"
" FROM events as e"
" JOIN event_json as ej"
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
" LEFT JOIN event_to_state_groups as eg"
" ON e.event_id = eg.event_id"
" WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
" ORDER BY e.stream_ordering DESC"
" LIMIT ?"
@ -1113,8 +1171,35 @@ class EventsStore(SQLBaseStore):
if last_backfill_id != current_backfill_id:
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
new_backfill_events = txn.fetchall()
if len(new_backfill_events) == limit:
upper_bound = new_backfill_events[-1][0]
else:
upper_bound = current_backfill_id
sql = (
"SELECT -event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_backfill_id, -upper_bound))
backward_ex_outliers = txn.fetchall()
else:
new_backfill_events = []
backward_ex_outliers = []
return (new_forward_events, new_backfill_events)
return AllNewEventsResult(
new_forward_events, new_backfill_events,
forward_ex_outliers, backward_ex_outliers,
state_resets,
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
"forward_ex_outliers", "backward_ex_outliers",
"state_resets"
])

View file

@ -25,23 +25,11 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 30
SCHEMA_VERSION = 31
dir_path = os.path.abspath(os.path.dirname(__file__))
def read_schema(path):
""" Read the named database schema.
Args:
path: Path of the database schema.
Returns:
A string containing the database schema.
"""
with open(path) as schema_file:
return schema_file.read()
class PrepareDatabaseException(Exception):
pass
@ -53,6 +41,9 @@ class UpgradeDatabaseException(PrepareDatabaseException):
def prepare_database(db_conn, database_engine, config):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
If `config` is None then prepare_database will assert that no upgrade is
necessary, *or* will create a fresh database if the database is empty.
"""
try:
cur = db_conn.cursor()
@ -60,13 +51,18 @@ def prepare_database(db_conn, database_engine, config):
if version_info:
user_version, delta_files, upgraded = version_info
_upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine, config
)
else:
_setup_new_database(cur, database_engine, config)
# cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
if config is None:
if user_version != SCHEMA_VERSION:
# If we don't pass in a config file then we are expecting to
# have already upgraded the DB.
raise UpgradeDatabaseException("Database needs to be upgraded")
else:
_upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine, config
)
else:
_setup_new_database(cur, database_engine)
cur.close()
db_conn.commit()
@ -75,7 +71,7 @@ def prepare_database(db_conn, database_engine, config):
raise
def _setup_new_database(cur, database_engine, config):
def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas.
@ -148,12 +144,13 @@ def _setup_new_database(cur, database_engine, config):
applied_delta_files=[],
upgraded=False,
database_engine=database_engine,
config=config,
config=None,
is_empty=True,
)
def _upgrade_existing_database(cur, current_version, applied_delta_files,
upgraded, database_engine, config):
upgraded, database_engine, config, is_empty=False):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
@ -246,7 +243,9 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine, config=config)
module.run_create(cur, database_engine)
if not is_empty:
module.run_upgrade(cur, database_engine, config=config)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
@ -361,36 +360,3 @@ def _get_or_create_schema_state(txn, database_engine):
return current_version, applied_deltas, upgraded
return None
def prepare_sqlite3_database(db_conn):
"""This function should be called before `prepare_database` on sqlite3
databases.
Since we changed the way we store the current schema version and handle
updates to schemas, we need a way to upgrade from the old method to the
new. This only affects sqlite databases since they were the only ones
supported at the time.
"""
with db_conn:
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
create_schema = read_schema(schema_path)
db_conn.executescript(create_schema)
c = db_conn.execute("SELECT * FROM schema_version")
rows = c.fetchall()
c.close()
if not rows:
c = db_conn.execute("PRAGMA user_version")
row = c.fetchone()
c.close()
if row and row[0]:
db_conn.execute(
"REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(row[0], False)
)

View file

@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore):
self._update_presence_txn, stream_orderings, presence_states,
)
defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token()))
defer.returnValue((
stream_orderings[-1], self._presence_id_gen.get_current_token()
))
def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states):
@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore):
defer.returnValue([UserPresenceState(**row) for row in rows])
def get_current_presence_token(self):
return self._presence_id_gen.get_max_token()
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
@ -174,16 +176,6 @@ class PresenceStore(SQLBaseStore):
desc="disallow_presence_visible",
)
def is_presence_visible(self, observed_localpart, observer_userid):
return self._simple_select_one(
table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
retcols=["observed_user_id"],
allow_none=True,
desc="is_presence_visible",
)
def add_presence_list_pending(self, observer_localpart, observed_userid):
return self._simple_insert(
table="presence_list",

View file

@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_max_token()
return self._push_rules_stream_id_gen.get_current_token()
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):

View file

@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore):
defer.returnValue(rows)
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_max_token()
return self._pushers_id_gen.get_current_token()
def get_all_updated_pushers(self, last_id, current_id, limit):
def get_all_updated_pushers_txn(txn):

View file

@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
@cached(num_args=2)
@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore):
"content": content,
}])
@cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
num_args=3, inlineCallbacks=True)
@cachedList(cached_method_name="get_linearized_receipts_for_room",
list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
defer.returnValue({})
@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue(results)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token()
return self._receipts_id_gen.get_current_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data
)
max_persisted_id = self._stream_id_gen.get_max_token()
max_persisted_id = self._stream_id_gen.get_current_token()
defer.returnValue((stream_id, max_persisted_id))

View file

@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
@cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1,
inlineCallbacks=True)
def are_guests(self, user_ids):
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
@ -458,12 +458,15 @@ class RegistrationStore(SQLBaseStore):
"""
Gets the 3pid's guest access token if exists, else saves access_token.
:param medium (str): Medium of the 3pid. Must be "email".
:param address (str): 3pid address.
:param access_token (str): The access token to persist if none is
already persisted.
:param inviter_user_id (str): User ID of the inviter.
:return (deferred str): Whichever access token is persisted at the end
Args:
medium (str): Medium of the 3pid. Must be "email".
address (str): 3pid address.
access_token (str): The access token to persist if none is
already persisted.
inviter_user_id (str): User ID of the inviter.
Returns:
deferred str: Whichever access token is persisted at the end
of this function call.
"""
def insert(txn):

View file

@ -36,7 +36,7 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore):
def _store_room_members_txn(self, txn, events):
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
@ -62,27 +62,65 @@ class RoomMemberStore(SQLBaseStore):
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened.
# The only current event that can also be an outlier is if its an
# invite that has come in across federation.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_invite_from_remote()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
}
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
Args:
user_id (str): The member's user ID.
room_id (str): The room the member is in.
Returns:
Deferred: Results in a MembershipEvent or None.
"""
return self.runInteraction(
"get_room_member",
self._get_members_events_txn,
room_id,
user_id=user_id,
).addCallback(
self._get_events
).addCallback(
lambda events: events[0] if events else None
txn.execute(sql, (
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
))
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (
stream_ordering,
True,
room_id,
user_id,
))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cached(max_entries=5000)
def get_users_in_room(self, room_id):
def f(txn):
@ -127,18 +165,23 @@ class RoomMemberStore(SQLBaseStore):
user_id, [Membership.INVITE]
)
def get_leave_and_ban_events_for_user(self, user_id):
""" Get all the leave events for a user
@defer.inlineCallbacks
def get_invite_for_user_in_room(self, user_id, room_id):
"""Gets the invite for the given user and room
Args:
user_id (str): The user ID.
user_id (str)
room_id (str)
Returns:
A deferred list of event objects.
Deferred: Resolves to either a RoomsForUser or None if no invite was
found.
"""
return self.get_rooms_for_user_where_membership_is(
user_id, (Membership.LEAVE, Membership.BAN)
).addCallback(lambda leaves: self._get_events([
leave.event_id for leave in leaves
]))
invites = yield self.get_invited_rooms_for_user(user_id)
for invite in invites:
if invite.room_id == room_id:
defer.returnValue(invite)
defer.returnValue(None)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@ -163,29 +206,55 @@ class RoomMemberStore(SQLBaseStore):
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list):
where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
" OR ".join(["membership = ?" for _ in membership_list]),
)
args = [user_id]
args.extend(membership_list)
do_invite = Membership.INVITE in membership_list
membership_list = [m for m in membership_list if m != Membership.INVITE]
sql = (
"SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
" FROM current_state_events as c"
" INNER JOIN room_memberships as m"
" ON m.event_id = c.event_id"
" INNER JOIN events as e"
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key"
" WHERE %s"
) % (where_clause,)
results = []
if membership_list:
where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
" OR ".join(["membership = ?" for _ in membership_list]),
)
txn.execute(sql, args)
return [
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
args = [user_id]
args.extend(membership_list)
sql = (
"SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
" FROM current_state_events as c"
" INNER JOIN room_memberships as m"
" ON m.event_id = c.event_id"
" INNER JOIN events as e"
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key"
" WHERE %s"
) % (where_clause,)
txn.execute(sql, args)
results = [
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
if do_invite:
sql = (
"SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
" FROM local_invites as i"
" INNER JOIN events as e USING (event_id)"
" WHERE invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(sql, (user_id,))
results.extend(RoomsForUser(
room_id=r["room_id"],
sender=r["inviter"],
event_id=r["event_id"],
stream_ordering=r["stream_ordering"],
membership=Membership.INVITE,
) for r in self.cursor_to_dict(txn))
return results
@cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):

View file

@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, *args, **kwargs):
def run_create(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:
@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs):
"UPDATE application_services_regex SET regex=? WHERE id=?",
(new_regex, row[0])
)
def run_upgrade(*args, **kwargs):
pass

View file

@ -27,7 +27,7 @@ import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
cur.execute("""
CREATE TABLE IF NOT EXISTS pushers2 (
@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute("DROP TABLE pushers")
cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
logger.info("Moved %d pushers to new table", count)
def run_upgrade(*args, **kwargs):
pass

View file

@ -43,7 +43,7 @@ SQLITE_TABLE = (
)
def run_upgrade(cur, database_engine, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
for statement in get_statements(POSTGRES_TABLE.splitlines()):
cur.execute(statement)
@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_search", progress_json))
def run_upgrade(*args, **kwargs):
pass

View file

@ -27,7 +27,7 @@ ALTER_TABLE = (
)
def run_upgrade(cur, database_engine, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
for statement in get_statements(ALTER_TABLE.splitlines()):
cur.execute(statement)
@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_origin_server_ts", progress_json))
def run_upgrade(*args, **kwargs):
pass

View file

@ -18,7 +18,7 @@ from synapse.storage.appservice import ApplicationServiceStore
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, config, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
# NULL indicates user was not registered by an appservice.
try:
cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
# Maybe we already added the column? Hope so...
pass
def run_upgrade(cur, database_engine, config, *args, **kwargs):
cur.execute("SELECT name FROM users")
rows = cur.fetchall()

View file

@ -0,0 +1,38 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* The positions in the event stream_ordering when the current_state was
* replaced by the state at the event.
*/
CREATE TABLE IF NOT EXISTS current_state_resets(
event_stream_ordering BIGINT PRIMARY KEY NOT NULL
);
/* The outlier events that have aquired a state group typically through
* backfill. This is tracked separately to the events table, as assigning a
* state group change the position of the existing event in the stream
* ordering.
* However since a stream_ordering is assigned in persist_event for the
* (event, state) pair, we can use that stream_ordering to identify when
* the new state was assigned for the event.
*/
CREATE TABLE IF NOT EXISTS ex_outlier_stream(
event_stream_ordering BIGINT PRIMARY KEY NOT NULL,
event_id TEXT NOT NULL,
state_group BIGINT NOT NULL
);

View file

@ -0,0 +1,42 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE local_invites(
stream_id BIGINT NOT NULL,
inviter TEXT NOT NULL,
invitee TEXT NOT NULL,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
locally_rejected TEXT,
replaced_by TEXT
);
-- Insert all invites for local users into new `invites` table
INSERT INTO local_invites SELECT
stream_ordering as stream_id,
sender as inviter,
state_key as invitee,
event_id,
room_id,
NULL as locally_rejected,
NULL as replaced_by
FROM events
NATURAL JOIN current_state_events
NATURAL JOIN room_memberships
WHERE membership = 'invite' AND state_key IN (SELECT name FROM users);
CREATE INDEX local_invites_id ON local_invites(stream_id);
CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id);

View file

@ -64,12 +64,12 @@ class StateStore(SQLBaseStore):
for group, state_map in group_to_state.items()
})
def _store_state_groups_txn(self, txn, event, context):
return self._store_mult_state_groups_txn(txn, [(event, context)])
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
if context.current_state is None:
continue
@ -82,7 +82,8 @@ class StateStore(SQLBaseStore):
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = self._state_groups_id_gen.get_next()
state_group = context.new_state_group_id
self._simple_insert_txn(
txn,
table="state_groups",
@ -114,11 +115,10 @@ class StateStore(SQLBaseStore):
table="event_to_state_groups",
values=[
{
"state_group": state_groups[event.event_id],
"event_id": event.event_id,
"state_group": state_group_id,
"event_id": event_id,
}
for event, context in events_and_contexts
if context.current_state is not None
for event_id, state_group_id in state_groups.items()
],
)
@ -249,11 +249,14 @@ class StateStore(SQLBaseStore):
"""
Get the state dict corresponding to a particular event
:param str event_id: event whose state should be returned
:param list[(str, str)]|None types: List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
:return: a deferred dict from (type, state_key) -> state_event
Args:
event_id(str): event whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@ -270,8 +273,8 @@ class StateStore(SQLBaseStore):
desc="_get_state_group_for_event",
)
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
num_args=1, inlineCallbacks=True)
@cachedList(cached_method_name="_get_state_group_for_event",
list_name="event_ids", num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
@ -429,3 +432,33 @@ class StateStore(SQLBaseStore):
}
defer.returnValue(results)
def get_all_new_state_groups(self, last_id, current_id, limit):
def get_all_new_state_groups_txn(txn):
sql = (
"SELECT id, room_id, event_id FROM state_groups"
" WHERE ? < id AND id <= ? ORDER BY id LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
groups = txn.fetchall()
if not groups:
return ([], [])
lower_bound = groups[0][0]
upper_bound = groups[-1][0]
sql = (
"SELECT state_group, type, state_key, event_id"
" FROM state_groups_state"
" WHERE ? <= state_group AND state_group <= ?"
)
txn.execute(sql, (lower_bound, upper_bound))
state_group_state = txn.fetchall()
return (groups, state_group_state)
return self.runInteraction(
"get_all_new_state_groups", get_all_new_state_groups_txn
)
def get_state_stream_token(self):
return self._state_groups_id_gen.get_current_token()

View file

@ -303,96 +303,6 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
def get_room_events_stream(
self,
user_id,
from_key,
to_key,
limit=0,
is_guest=False,
room_ids=None
):
room_ids = room_ids or []
room_ids = [r for r in room_ids]
if is_guest:
current_room_membership_sql = (
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s)"
" AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
current_room_membership_args = room_ids
else:
current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id AND c.state_key = m.user_id"
" WHERE m.user_id = ? AND m.membership = 'join'"
)
current_room_membership_args = [user_id]
# We also want to get any membership events about that user, e.g.
# invites or leave notifications.
membership_sql = (
"SELECT m.event_id FROM room_memberships as m "
"INNER JOIN current_state_events as c ON m.event_id = c.event_id "
"WHERE m.user_id = ? "
)
membership_args = [user_id]
if limit:
limit = max(limit, MAX_STREAM_SIZE)
else:
limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering.
from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = RoomStreamToken.parse_stream_token(to_key)
if from_key == to_key:
return defer.succeed(([], to_key))
sql = (
"SELECT e.event_id, e.stream_ordering FROM events AS e WHERE "
"(e.outlier = ? AND (room_id IN (%(current)s)) OR "
"(event_id IN (%(invites)s))) "
"AND e.stream_ordering > ? AND e.stream_ordering <= ? "
"ORDER BY stream_ordering ASC LIMIT %(limit)d "
) % {
"current": current_room_membership_sql,
"invites": membership_sql,
"limit": limit
}
def f(txn):
args = ([False] + current_room_membership_args + membership_args +
[from_id.stream, to_id.stream])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(ret, rows)
if rows:
key = "s%d" % max(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
key = to_key
return ret, key
return self.runInteraction("get_room_events_stream", f)
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1):
@ -539,7 +449,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'):
token = yield self._stream_id_gen.get_max_token()
token = yield self._stream_id_gen.get_current_token()
if direction != 'b':
defer.returnValue("s%d" % (token,))
else:

View file

@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
Returns:
A deferred int.
"""
return self._account_data_id_gen.get_max_token()
return self._account_data_id_gen.get_current_token()
@cached()
def get_tags_for_user(self, user_id):
@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_max_token()
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
@defer.inlineCallbacks
@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_max_token()
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):

View file

@ -21,7 +21,7 @@ import threading
class IdGenerator(object):
def __init__(self, db_conn, table, column):
self._lock = threading.Lock()
self._next_id = _load_max_id(db_conn, table, column)
self._next_id = _load_current_id(db_conn, table, column)
def get_next(self):
with self._lock:
@ -29,12 +29,16 @@ class IdGenerator(object):
return self._next_id
def _load_max_id(db_conn, table, column):
def _load_current_id(db_conn, table, column, step=1):
cur = db_conn.cursor()
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
val, = cur.fetchone()
cur.close()
return int(val) if val else 1
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
class StreamIdGenerator(object):
@ -45,17 +49,32 @@ class StreamIdGenerator(object):
all ids less than or equal to it have completed. This handles the fact that
persistence of events can complete out of order.
Args:
db_conn(connection): A database connection to use to fetch the
initial value of the generator from.
table(str): A database table to read the initial value of the id
generator from.
column(str): The column of the database table to read the initial
value from the id generator from.
extra_tables(list): List of pairs of database tables and columns to
use to source the initial value of the generator from. The value
with the largest magnitude is used.
step(int): which direction the stream ids grow in. +1 to grow
upwards, -1 to grow downwards.
Usage:
with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
def __init__(self, db_conn, table, column, extra_tables=[]):
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
assert step != 0
self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
self._step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self._current_max = max(
self._current_max,
_load_max_id(db_conn, table, column)
self._current = (max if step > 0 else min)(
self._current,
_load_current_id(db_conn, table, column, step)
)
self._unfinished_ids = deque()
@ -66,8 +85,8 @@ class StreamIdGenerator(object):
# ... persist event ...
"""
with self._lock:
self._current_max += 1
next_id = self._current_max
self._current += self._step
next_id = self._current
self._unfinished_ids.append(next_id)
@ -88,8 +107,12 @@ class StreamIdGenerator(object):
# ... persist events ...
"""
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
next_ids = range(
self._current + self._step,
self._current + self._step * (n + 1),
self._step
)
self._current += n * self._step
for next_id in next_ids:
self._unfinished_ids.append(next_id)
@ -105,15 +128,15 @@ class StreamIdGenerator(object):
return manager()
def get_max_token(self):
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
return self._unfinished_ids[0] - 1
return self._unfinished_ids[0] - self._step
return self._current_max
return self._current
class ChainedIdGenerator(object):
@ -125,7 +148,7 @@ class ChainedIdGenerator(object):
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
self._current_max = _load_current_id(db_conn, table, column)
self._unfinished_ids = deque()
def get_next(self):
@ -137,7 +160,7 @@ class ChainedIdGenerator(object):
with self._lock:
self._current_max += 1
next_id = self._current_max
chained_id = self.chained_generator.get_max_token()
chained_id = self.chained_generator.get_current_token()
self._unfinished_ids.append((next_id, chained_id))
@ -151,7 +174,7 @@ class ChainedIdGenerator(object):
return manager()
def get_max_token(self):
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
@ -160,4 +183,4 @@ class ChainedIdGenerator(object):
stream_id, chained_id = self._unfinished_ids[0]
return (stream_id - 1, chained_id)
return (self._current_max, self.chained_generator.get_max_token())
return (self._current_max, self.chained_generator.get_current_token())

View file

@ -49,9 +49,6 @@ class Clock(object):
l.start(msec / 1000.0, now=False)
return l
def stop_looping_call(self, loop):
loop.stop()
def call_later(self, delay, callback, *args, **kwargs):
"""Call something later

View file

@ -16,7 +16,12 @@
from twisted.internet import defer, reactor
from .logcontext import PreserveLoggingContext
from .logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
from synapse.util import unwrapFirstError
from contextlib import contextmanager
@defer.inlineCallbacks
@ -107,3 +112,76 @@ class ObservableDeferred(object):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)
def concurrently_execute(func, args, limit):
"""Executes the function with each argument conncurrently while limiting
the number of concurrent executions.
Args:
func (func): Function to execute, should return a deferred.
args (list): List of arguments to pass to func, each invocation of func
gets a signle argument.
limit (int): Maximum number of conccurent executions.
Returns:
deferred: Resolved when all function invocations have finished.
"""
it = iter(args)
@defer.inlineCallbacks
def _concurrently_execute_inner():
try:
while True:
yield func(it.next())
except StopIteration:
pass
return defer.gatherResults([
preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit)
], consumeErrors=True).addErrback(unwrapFirstError)
class Linearizer(object):
"""Linearizes access to resources based on a key. Useful to ensure only one
thing is happening at a time on a given resource.
Example:
with (yield linearizer.queue("test_key")):
# do some work.
"""
def __init__(self):
self.key_to_defer = {}
@defer.inlineCallbacks
def queue(self, key):
# If there is already a deferred in the queue, we pull it out so that
# we can wait on it later.
# Then we replace it with a deferred that we resolve *after* the
# context manager has exited.
# We only return the context manager after the previous deferred has
# resolved.
# This all has the net effect of creating a chain of deferreds that
# wait for the previous deferred before starting their work.
current_defer = self.key_to_defer.get(key)
new_defer = defer.Deferred()
self.key_to_defer[key] = new_defer
if current_defer:
yield preserve_context_over_deferred(current_defer)
@contextmanager
def _ctx_manager():
try:
yield
finally:
new_defer.callback(None)
current_d = self.key_to_defer.get(key)
if current_d is new_defer:
self.key_to_defer.pop(key, None)
defer.returnValue(_ctx_manager())

View file

@ -167,7 +167,8 @@ class CacheDescriptor(object):
% (orig.__name__,)
)
self.cache = Cache(
def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
@ -175,14 +176,12 @@ class CacheDescriptor(object):
tree=self.tree,
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
cached_result_d = self.cache.get(cache_key)
cached_result_d = cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@ -204,7 +203,7 @@ class CacheDescriptor(object):
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = self.cache.sequence
sequence = cache.sequence
ret = defer.maybeDeferred(
preserve_context_over_fn,
@ -213,20 +212,21 @@ class CacheDescriptor(object):
)
def onErr(f):
self.cache.invalidate(cache_key)
cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
cache.update(sequence, cache_key, ret)
return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.invalidate_many = self.cache.invalidate_many
wrapped.prefill = self.cache.prefill
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
@ -240,11 +240,12 @@ class CacheListDescriptor(object):
the list of missing keys to the wrapped fucntion.
"""
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
def __init__(self, orig, cached_method_name, list_name, num_args=1,
inlineCallbacks=False):
"""
Args:
orig (function)
cache (Cache)
method_name (str); The name of the chached method.
list_name (str): Name of the argument which is the bulk lookup list
num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should
@ -263,7 +264,7 @@ class CacheListDescriptor(object):
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
self.cached_method_name = cached_method_name
self.sentinel = object()
@ -277,11 +278,13 @@ class CacheListDescriptor(object):
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
% (self.list_name, cache.name,)
% (self.list_name, cached_method_name,)
)
def __get__(self, obj, objtype=None):
cache = getattr(obj, self.cached_method_name).cache
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
@ -297,14 +300,14 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
try:
res = self.cache.get(tuple(key)).observe()
res = cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
except KeyError:
missing.append(arg)
if missing:
sequence = self.cache.sequence
sequence = cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
@ -327,10 +330,10 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
self.cache.update(sequence, tuple(key), observer)
cache.update(sequence, tuple(key), observer)
def invalidate(f, key):
self.cache.invalidate(key)
cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
)
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
"""
return lambda orig: CacheListDescriptor(
orig,
cache=cache,
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,

View file

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.async import ObservableDeferred
class ResponseCache(object):
"""
This caches a deferred response. Until the deferred completes it will be
returned from the cache. This means that if the client retries the request
while the response is still being computed, that original response will be
used rather than trying to compute a new response.
"""
def __init__(self):
self.pending_result_cache = {} # Requests that haven't finished yet.
def get(self, key):
result = self.pending_result_cache.get(key)
if result is not None:
return result.observe()
else:
return None
def set(self, key, deferred):
result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result
def remove(r):
self.pending_result_cache.pop(key, None)
return r
result.addBoth(remove)
return result.observe()

View file

@ -15,7 +15,9 @@
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_fn
)
from synapse.util import unwrapFirstError
@ -25,6 +27,24 @@ import logging
logger = logging.getLogger(__name__)
def registered_user(distributor, user):
return distributor.fire("registered_user", user)
def user_left_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class Distributor(object):
"""A central dispatch point for loosely-connected pieces of code to
register, observe, and fire signals.

View file

@ -100,20 +100,6 @@ class _PerHostRatelimiter(object):
self.current_processing = set()
self.request_times = []
def is_empty(self):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
return not (
self.ready_request_queue
or self.sleeping_requests
or self.current_processing
or self.request_times
)
@contextlib.contextmanager
def ratelimit(self):
# `contextlib.contextmanager` takes a generator and turns it into a

View file

@ -21,10 +21,6 @@ _string_with_symbols = (
)
def origin_from_ucid(ucid):
return ucid.split("@", 1)[1]
def random_string(length):
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,57 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from synapse.replication.slave.storage.events import SlavedEventStore
from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver
from synapse.replication.resource import ReplicationResource
class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
"blue",
http_client=None,
replication_layer=Mock(),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
)
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
self.replication = ReplicationResource(self.hs)
self.master_store = self.hs.get_datastore()
self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs)
self.event_id = 0
@defer.inlineCallbacks
def replicate(self):
streams = self.slaved_store.stream_positions()
result = yield self.replication.replicate(streams, 100)
yield self.slaved_store.process_replication(result)
@defer.inlineCallbacks
def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args)
slaved_result = yield getattr(self.slaved_store, method)(*args)
if expected_result is not None:
self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result)
self.assertEqual(master_result, slaved_result)

View file

@ -0,0 +1,307 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStoreTestCase
from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
from synapse.storage.roommember import RoomsForUser
from twisted.internet import defer
USER_ID = "@feeling:blue"
USER_ID_2 = "@bright:blue"
OUTLIER = {"outlier": True}
ROOM_ID = "!room:blue"
def dict_equals(self, other):
return self.__dict__ == other.__dict__
def patch__eq__(cls):
eq = getattr(cls, "__eq__", None)
cls.__eq__ = dict_equals
def unpatch():
if eq is not None:
cls.__eq__ = eq
return unpatch
class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def setUp(self):
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEquals
self.unpatches = [
patch__eq__(_EventInternalMetadata),
patch__eq__(FrozenEvent),
]
return super(SlavedEventStoreTestCase, self).setUp()
def tearDown(self):
[unpatch() for unpatch in self.unpatches]
@defer.inlineCallbacks
def test_room_name_and_aliases(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
yield self.persist(type="m.room.name", key="", name="name1")
yield self.persist(
type="m.room.aliases", key="blue", aliases=["#1:blue"]
)
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"])
)
# Set the room name.
yield self.persist(type="m.room.name", key="", name="name2")
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"])
)
# Set the room aliases.
yield self.persist(
type="m.room.aliases", key="blue", aliases=["#2:blue"]
)
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"])
)
# Leave and join the room clobbering the state.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), (None, [])
)
@defer.inlineCallbacks
def test_room_members(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Join the room.
join = yield self.persist(type="m.room.member", key=USER_ID, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
# Leave the room.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Add some other user to the room.
join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
# Join the room clobbering the state.
# This should remove any evidence of the other user being in the room.
yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
yield self.check("get_rooms_for_user", (USER_ID_2,), [])
@defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check(
"get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]
)
join = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
prev_events=[(create.event_id, {})],
)
yield self.replicate()
yield self.check(
"get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
)
@defer.inlineCallbacks
def test_get_current_state(self):
# Create the room.
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
)
# Join the room.
join1 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join1]
)
# Add some other user to the room.
join2 = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[join2]
)
# Leave the room, then rejoin the room clobbering state.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
join3 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[]
)
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join3]
)
@defer.inlineCallbacks
def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
type="m.room.redaction", redacts=msg.event_id
)
yield self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_backfilled_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True
)
yield self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
event_id = 0
@defer.inlineCallbacks
def persist(
self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
state=None, reset_state=False, backfill=False,
depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
**content
):
"""
Returns:
synapse.events.FrozenEvent: The event that was persisted.
"""
if depth is None:
depth = self.event_id
event_dict = {
"sender": sender,
"type": type,
"content": content,
"event_id": "$%d:blue" % (self.event_id,),
"room_id": room_id,
"depth": depth,
"origin_server_ts": self.event_id,
"prev_events": prev_events,
"auth_events": auth_events,
}
if key is not None:
event_dict["state_key"] = key
event_dict["prev_state"] = prev_state
if redacts is not None:
event_dict["redacts"] = redacts
event = FrozenEvent(event_dict, internal_metadata_dict=internal)
self.event_id += 1
context = EventContext(current_state=state)
ordering = None
if backfill:
yield self.master_store.persist_events(
[(event, context)], backfilled=True
)
else:
ordering, _ = yield self.master_store.persist_event(
event, context, current_state=reset_state
)
if ordering:
event.internal_metadata.stream_ordering = ordering
defer.returnValue(event)

View file

@ -58,15 +58,21 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body, {})
@defer.inlineCallbacks
def test_events(self):
get = self.get(events="-1", timeout="0")
def test_events_and_state(self):
get = self.get(events="-1", state="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room(
Requester(self.user, "", False), {}
)
code, body = yield get
self.assertEquals(code, 200)
self.assertEquals(body["events"]["field_names"], [
"position", "internal", "json"
"position", "internal", "json", "state_group"
])
self.assertEquals(body["state_groups"]["field_names"], [
"position", "room_id", "event_id"
])
self.assertEquals(body["state_group_state"]["field_names"], [
"position", "type", "state_key", "event_id"
])
@defer.inlineCallbacks
@ -132,6 +138,7 @@ class ReplicationResourceCase(unittest.TestCase):
test_timeout_backfill = _test_timeout("backfill")
test_timeout_push_rules = _test_timeout("push_rules")
test_timeout_pushers = _test_timeout("pushers")
test_timeout_state = _test_timeout("state")
@defer.inlineCallbacks
def send_text_message(self, room_id, message):
@ -182,4 +189,21 @@ class ReplicationResourceCase(unittest.TestCase):
)
response_body = json.loads(response_json)
if response_code == 200:
self.check_response(response_body)
defer.returnValue((response_code, response_body))
def check_response(self, response_body):
for name, stream in response_body.items():
self.assertIn("field_names", stream)
field_names = stream["field_names"]
self.assertIn("rows", stream)
self.assertTrue(stream["rows"])
for row in stream["rows"]:
self.assertEquals(
len(row), len(field_names),
"%s: len(row = %r) == len(field_names = %r)" % (
name, row, field_names
)
)

View file

@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase):
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]:
yield self.join(room=room, user=usr, expect_code=404)
yield self.leave(room=room, user=usr, expect_code=404)
yield self.join(room=room, user=usr, expect_code=403)
yield self.leave(room=room, user=usr, expect_code=403)
@defer.inlineCallbacks
def test_membership_private_room_perms(self):

View file

@ -53,7 +53,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
"test",
db_pool=self.db_pool,
config=config,
database_engine=create_engine(config),
database_engine=create_engine(config.database_config),
)
self.datastore = SQLBaseStore(hs)

View file

@ -34,33 +34,6 @@ class PresenceStoreTestCase(unittest.TestCase):
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
@defer.inlineCallbacks
def test_visibility(self):
self.assertFalse((yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)))
yield self.store.allow_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)
self.assertTrue((yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)))
yield self.store.disallow_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)
self.assertFalse((yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)))
@defer.inlineCallbacks
def test_presence_list(self):
self.assertEquals(

View file

@ -110,22 +110,10 @@ class RedactionTestCase(unittest.TestCase):
self.room1, self.u_alice, Membership.JOIN
)
start = yield self.store.get_room_events_max_id()
msg_event = yield self.inject_message(self.room1, self.u_alice, u"t")
end = yield self.store.get_room_events_max_id()
results, _ = yield self.store.get_room_events_stream(
self.u_alice.to_string(),
start,
end,
)
self.assertEqual(1, len(results))
# Check event has not been redacted:
event = results[0]
event = yield self.store.get_event(msg_event.event_id)
self.assertObjectHasAttributes(
{
@ -144,17 +132,7 @@ class RedactionTestCase(unittest.TestCase):
self.room1, msg_event.event_id, self.u_alice, reason
)
results, _ = yield self.store.get_room_events_stream(
self.u_alice.to_string(),
start,
end,
)
self.assertEqual(1, len(results))
# Check redaction
event = results[0]
event = yield self.store.get_event(msg_event.event_id)
self.assertEqual(msg_event.event_id, event.event_id)
@ -184,25 +162,12 @@ class RedactionTestCase(unittest.TestCase):
self.room1, self.u_alice, Membership.JOIN
)
start = yield self.store.get_room_events_max_id()
msg_event = yield self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN,
extra_content={"blue": "red"},
)
end = yield self.store.get_room_events_max_id()
results, _ = yield self.store.get_room_events_stream(
self.u_alice.to_string(),
start,
end,
)
self.assertEqual(1, len(results))
# Check event has not been redacted:
event = results[0]
event = yield self.store.get_event(msg_event.event_id)
self.assertObjectHasAttributes(
{
@ -221,17 +186,9 @@ class RedactionTestCase(unittest.TestCase):
self.room1, msg_event.event_id, self.u_alice, reason
)
results, _ = yield self.store.get_room_events_stream(
self.u_alice.to_string(),
start,
end,
)
self.assertEqual(1, len(results))
# Check redaction
event = results[0]
event = yield self.store.get_event(msg_event.event_id)
self.assertTrue("redacted_because" in event.unsigned)

View file

@ -70,13 +70,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
def test_one_member(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
self.assertEquals(
Membership.JOIN,
(yield self.store.get_room_member(
user_id=self.u_alice.to_string(),
room_id=self.room.to_string(),
)).membership
)
self.assertEquals(
[self.u_alice.to_string()],
[m.user_id for m in (

View file

@ -1,185 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID
from tests.storage.event_injector import EventInjector
from tests.utils import setup_test_homeserver
from mock import Mock
class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(
resource_for_federation=Mock(),
http_client=None,
)
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_injector = EventInjector(hs)
self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler
self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test")
self.room1 = RoomID.from_string("!abc123:test")
self.room2 = RoomID.from_string("!xyx987:test")
@defer.inlineCallbacks
def test_event_stream_get_other(self):
# Both bob and alice joins the room
yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
# Initial stream key:
start = yield self.store.get_room_events_max_id()
yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id()
results, _ = yield self.store.get_room_events_stream(
self.u_bob.to_string(),
start,
end,
)
self.assertEqual(1, len(results))
event = results[0]
self.assertObjectHasAttributes(
{
"type": EventTypes.Message,
"user_id": self.u_alice.to_string(),
"content": {"body": "test", "msgtype": "message"},
},
event,
)
@defer.inlineCallbacks
def test_event_stream_get_own(self):
# Both bob and alice joins the room
yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
# Initial stream key:
start = yield self.store.get_room_events_max_id()
yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id()
results, _ = yield self.store.get_room_events_stream(
self.u_alice.to_string(),
start,
end,
)
self.assertEqual(1, len(results))
event = results[0]
self.assertObjectHasAttributes(
{
"type": EventTypes.Message,
"user_id": self.u_alice.to_string(),
"content": {"body": "test", "msgtype": "message"},
},
event,
)
@defer.inlineCallbacks
def test_event_stream_join_leave(self):
# Both bob and alice joins the room
yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
# Then bob leaves again.
yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.LEAVE
)
# Initial stream key:
start = yield self.store.get_room_events_max_id()
yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id()
results, _ = yield self.store.get_room_events_stream(
self.u_bob.to_string(),
start,
end,
)
# We should not get the message, as it happened *after* bob left.
self.assertEqual(0, len(results))
@defer.inlineCallbacks
def test_event_stream_prev_content(self):
yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
start = yield self.store.get_room_events_max_id()
yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN,
)
end = yield self.store.get_room_events_max_id()
results, _ = yield self.store.get_room_events_stream(
self.u_bob.to_string(),
start,
end,
)
# We should not get the message, as it happened *after* bob left.
self.assertEqual(1, len(results))
event = results[0]
self.assertTrue(
"prev_content" in event.unsigned,
msg="No prev_content key"
)

View file

@ -21,6 +21,8 @@ from mock import Mock
from synapse.http.endpoint import resolve_service
from tests.utils import MockClock
class DnsTestCase(unittest.TestCase):
@ -63,14 +65,17 @@ class DnsTestCase(unittest.TestCase):
self.assertEquals(servers[0].host, ip_address)
@defer.inlineCallbacks
def test_from_cache(self):
def test_from_cache_expired_and_dns_fail(self):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = "test_service.examle.com"
entry = Mock(spec_set=["expires"])
entry.expires = 0
cache = {
service_name: [object()]
service_name: [entry]
}
servers = yield resolve_service(
@ -82,6 +87,31 @@ class DnsTestCase(unittest.TestCase):
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
@defer.inlineCallbacks
def test_from_cache(self):
clock = MockClock()
dns_client_mock = Mock(spec_set=['lookupService'])
dns_client_mock.lookupService = Mock(spec_set=[])
service_name = "test_service.examle.com"
entry = Mock(spec_set=["expires"])
entry.expires = 999999999
cache = {
service_name: [entry]
}
servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache, clock=clock,
)
self.assertFalse(dns_client_mock.lookupService.called)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
@defer.inlineCallbacks
def test_empty_cache(self):
dns_client_mock = Mock()

View file

@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.util.async import Linearizer
class LinearizerTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_linearizer(self):
linearizer = Linearizer()
key = object()
d1 = linearizer.queue(key)
cm1 = yield d1
d2 = linearizer.queue(key)
self.assertFalse(d2.called)
with cm1:
self.assertFalse(d2.called)
self.assertTrue(d2.called)
with (yield d2):
pass

View file

@ -64,7 +64,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=db_pool, config=config,
version_string="Synapse/tests",
database_engine=create_engine(config),
database_engine=create_engine(config.database_config),
get_db_conn=db_pool.get_db_conn,
**kargs
)
@ -73,7 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
database_engine=create_engine(config),
database_engine=create_engine(config.database_config),
**kargs
)
@ -298,7 +298,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
return conn
def create_engine(self):
return create_engine(self.config)
return create_engine(self.config.database_config)
class MemoryDataStore(object):