Merge branch 'develop' of github.com:matrix-org/synapse into erikj/dns_cache

This commit is contained in:
Erik Johnston 2016-04-07 11:11:17 +01:00
commit a28d066732
60 changed files with 2286 additions and 1242 deletions

View file

@ -19,6 +19,7 @@ from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
import argparse import argparse
import curses import curses
@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = {
"rooms": ["is_public"], "rooms": ["is_public"],
"event_edges": ["is_state"], "event_edges": ["is_state"],
"presence_list": ["accepted"], "presence_list": ["accepted"],
"presence_stream": ["currently_active"],
} }
@ -292,7 +294,7 @@ class Porter(object):
} }
) )
database_engine.prepare_database(db_conn) prepare_database(db_conn, database_engine, config=None)
db_conn.commit() db_conn.commit()
@ -309,8 +311,8 @@ class Porter(object):
**self.postgres_config["args"] **self.postgres_config["args"]
) )
sqlite_engine = create_engine(FakeConfig(sqlite_config)) sqlite_engine = create_engine(sqlite_config)
postgres_engine = create_engine(FakeConfig(postgres_config)) postgres_engine = create_engine(postgres_config)
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine)
@ -792,8 +794,3 @@ if __name__ == "__main__":
if end_error_exec_info: if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback) traceback.print_exception(exc_type, exc_value, exc_traceback)
class FakeConfig:
def __init__(self, database_config):
self.database_config = database_config

View file

@ -17,3 +17,6 @@ ignore =
[flake8] [flake8]
max-line-length = 90 max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
[pep8]
max-line-length = 90

View file

@ -33,7 +33,7 @@ from synapse.python_dependencies import (
from synapse.rest import ClientRestResource from synapse.rest import ClientRestResource
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.server import HomeServer from synapse.server import HomeServer
@ -245,7 +245,7 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e: except IncorrectDatabaseSetup as e:
quit_with_error(e.message) quit_with_error(e.message)
def get_db_conn(self): def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should # Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine. # not be passed to the database engine.
db_params = { db_params = {
@ -254,6 +254,7 @@ class SynapseHomeServer(HomeServer):
} }
db_conn = self.database_engine.module.connect(**db_params) db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn) self.database_engine.on_new_connection(db_conn)
return db_conn return db_conn
@ -386,7 +387,7 @@ def setup(config_options):
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config) database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer( hs = SynapseHomeServer(
@ -402,8 +403,10 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name']) logger.info("Preparing database: %s...", config.database_config['name'])
try: try:
db_conn = hs.get_db_conn() db_conn = hs.get_db_conn(run_new_connection=False)
database_engine.prepare_database(db_conn) prepare_database(db_conn, database_engine, config=config)
database_engine.on_new_connection(db_conn)
hs.run_startup_checks(db_conn, database_engine) hs.run_startup_checks(db_conn, database_engine)
db_conn.commit() db_conn.commit()

View file

@ -33,6 +33,9 @@ class _EventInternalMetadata(object):
def is_outlier(self): def is_outlier(self):
return getattr(self, "outlier", False) return getattr(self, "outlier", False)
def is_invite_from_remote(self):
return getattr(self, "invite_from_remote", False)
def _event_dict_property(key): def _event_dict_property(key):
def getter(self): def getter(self):

View file

@ -17,8 +17,9 @@ from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler from .register import RegistrationHandler
from .room import ( from .room import (
RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler, RoomCreationHandler, RoomListHandler, RoomContextHandler,
) )
from .room_member import RoomMemberHandler
from .message import MessageHandler from .message import MessageHandler
from .events import EventStreamHandler, EventHandler from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler from .federation import FederationHandler

View file

@ -37,12 +37,22 @@ VISIBILITY_PRIORITY = (
) )
MEMBERSHIP_PRIORITY = (
Membership.JOIN,
Membership.INVITE,
Membership.KNOCK,
Membership.LEAVE,
Membership.BAN,
)
class BaseHandler(object): class BaseHandler(object):
""" """
Common base class for the event handlers. Common base class for the event handlers.
:type store: synapse.storage.events.StateStore Attributes:
:type state_handler: synapse.state.StateHandler store (synapse.storage.events.StateStore):
state_handler (synapse.state.StateHandler):
""" """
def __init__(self, hs): def __init__(self, hs):
@ -65,11 +75,13 @@ class BaseHandler(object):
""" Returns dict of user_id -> list of events that user is allowed to """ Returns dict of user_id -> list of events that user is allowed to
see. see.
:param (str, bool) user_tuples: (user id, is_peeking) for each Args:
user to be checked. is_peeking should be true if: user_tuples (str, bool): (user id, is_peeking) for each user to be
checked. is_peeking should be true if:
* the user is not currently a member of the room, and: * the user is not currently a member of the room, and:
* the user has not been a member of the room since the given * the user has not been a member of the room since the
events given events
events ([synapse.events.EventBase]): list of events to filter
""" """
forgotten = yield defer.gatherResults([ forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room( self.store.who_forgot_in_room(
@ -84,6 +96,12 @@ class BaseHandler(object):
) )
def allowed(event, user_id, is_peeking): def allowed(event, user_id, is_peeking):
"""
Args:
event (synapse.events.EventBase): event to check
user_id (str)
is_peeking (bool)
"""
state = event_id_to_state[event.event_id] state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event. # get the room_visibility at the time of the event.
@ -115,17 +133,30 @@ class BaseHandler(object):
if old_priority < new_priority: if old_priority < new_priority:
visibility = prev_visibility visibility = prev_visibility
# get the user's membership at the time of the event. (or rather, # likewise, if the event is the user's own membership event, use
# just *after* the event. Which means that people can see their # the 'most joined' membership
# own join events, but not (currently) their own leave events.) membership = None
if event.type == EventTypes.Member and event.state_key == user_id:
membership = event.content.get("membership", None)
if membership not in MEMBERSHIP_PRIORITY:
membership = "leave"
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership not in MEMBERSHIP_PRIORITY:
prev_membership = "leave"
new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
if old_priority < new_priority:
membership = prev_membership
# otherwise, get the user's membership at the time of the event.
if membership is None:
membership_event = state.get((EventTypes.Member, user_id), None) membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event: if membership_event:
if membership_event.event_id in event_id_forgotten: if membership_event.event_id not in event_id_forgotten:
membership = None
else:
membership = membership_event.membership membership = membership_event.membership
else:
membership = None
# if the user was a member of the room at the time of the event, # if the user was a member of the room at the time of the event,
# they can see it. # they can see it.
@ -165,13 +196,16 @@ class BaseHandler(object):
""" """
Check which events a user is allowed to see Check which events a user is allowed to see
:param str user_id: user id to be checked Args:
:param [synapse.events.EventBase] events: list of events to be checked user_id(str): user id to be checked
:param bool is_peeking should be True if: events([synapse.events.EventBase]): list of events to be checked
is_peeking(bool): should be True if:
* the user is not currently a member of the room, and: * the user is not currently a member of the room, and:
* the user has not been a member of the room since the given * the user has not been a member of the room since the given
events events
:rtype [synapse.events.EventBase]
Returns:
[synapse.events.EventBase]
""" """
types = ( types = (
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
@ -199,7 +233,12 @@ class BaseHandler(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder): def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
depth = prev_max_depth + 1
else:
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id, builder.room_id,
) )
@ -221,50 +260,6 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder) context = yield state_handler.compute_event_context(builder)
# If we've received an invite over federation, there are no latest
# events in the room, because we don't know enough about the graph
# fragment we received to treat it like a graph, so the above returned
# no relevant events. It may have returned some events (if we have
# joined and left the room), but not useful ones, like the invite.
if (
not self.is_host_in_room(context.current_state) and
builder.type == EventTypes.Member
):
prev_member_event = yield self.store.get_room_member(
builder.sender, builder.room_id
)
# The prev_member_event may already be in context.current_state,
# despite us not being present in the room; in particular, if
# inviting user, and all other local users, have already left.
#
# In that case, we have all the information we need, and we don't
# want to drop "context" - not least because we may need to handle
# the invite locally, which will require us to have the whole
# context (not just prev_member_event) to auth it.
#
context_event_ids = (
e.event_id for e in context.current_state.values()
)
if (
prev_member_event and
prev_member_event.event_id not in context_event_ids
):
# The prev_member_event is missing from context, so it must
# have arrived over federation and is an outlier. We forcibly
# set our context to the invite we received over federation
builder.prev_events = (
prev_member_event.event_id,
prev_member_event.prev_events
)
context = yield state_handler.compute_event_context(
builder,
old_state=(prev_member_event,),
outlier=True
)
if builder.is_state(): if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes( builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events context.prev_state_events

View file

@ -163,9 +163,13 @@ class AuthHandler(BaseHandler):
def get_session_id(self, clientdict): def get_session_id(self, clientdict):
""" """
Gets the session ID for a client given the client dictionary Gets the session ID for a client given the client dictionary
:param clientdict: The dictionary sent by the client in the request
:return: The string session ID the client sent. If the client did not Args:
send a session ID, returns None. clientdict: The dictionary sent by the client in the request
Returns:
str|None: The string session ID the client sent. If the client did
not send a session ID, returns None.
""" """
sid = None sid = None
if clientdict and 'auth' in clientdict: if clientdict and 'auth' in clientdict:
@ -179,9 +183,11 @@ class AuthHandler(BaseHandler):
Store a key-value pair into the sessions data associated with this Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by request. This data is stored server-side and cannot be modified by
the client. the client.
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under Args:
:param value: (any) The data to store session_id (string): The ID of this session as returned from check_auth
key (string): The key to store the data under
value (any): The data to store
""" """
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
sess.setdefault('serverdict', {})[key] = value sess.setdefault('serverdict', {})[key] = value
@ -190,9 +196,11 @@ class AuthHandler(BaseHandler):
def get_session_data(self, session_id, key, default=None): def get_session_data(self, session_id, key, default=None):
""" """
Retrieve data stored with set_session_data Retrieve data stored with set_session_data
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under Args:
:param default: (any) Value to return if the key has not been set session_id (string): The ID of this session as returned from check_auth
key (string): The key to store the data under
default (any): Value to return if the key has not been set
""" """
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default) return sess.setdefault('serverdict', {}).get(key, default)

View file

@ -102,8 +102,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, state=None, def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
""" """
@ -174,11 +173,7 @@ class FederationHandler(BaseHandler):
}) })
seen_ids.add(e.event_id) seen_ids.add(e.event_id)
yield self._handle_new_events( yield self._handle_new_events(origin, event_infos)
origin,
event_infos,
outliers=True
)
try: try:
context, event_stream_id, max_stream_id = yield self._handle_new_event( context, event_stream_id, max_stream_id = yield self._handle_new_event(
@ -289,6 +284,9 @@ class FederationHandler(BaseHandler):
def backfill(self, dest, room_id, limit, extremities=[]): def backfill(self, dest, room_id, limit, extremities=[]):
""" Trigger a backfill request to `dest` for the given `room_id` """ Trigger a backfill request to `dest` for the given `room_id`
""" """
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
if not extremities: if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id) extremities = yield self.store.get_oldest_events_in_room(room_id)
@ -455,7 +453,7 @@ class FederationHandler(BaseHandler):
likely_domains = [ likely_domains = [
domain for domain, depth in curr_domains domain for domain, depth in curr_domains
if domain is not self.server_name if domain != self.server_name
] ]
@defer.inlineCallbacks @defer.inlineCallbacks
@ -761,6 +759,7 @@ class FederationHandler(BaseHandler):
event = pdu event = pdu
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
@ -788,6 +787,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id): def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
try:
origin, event = yield self._make_and_verify_event( origin, event = yield self._make_and_verify_event(
target_hosts, target_hosts,
room_id, room_id,
@ -795,6 +795,11 @@ class FederationHandler(BaseHandler):
"leave" "leave"
) )
signed_event = self._sign_event(event) signed_event = self._sign_event(event)
except SynapseError:
raise
except CodeMessageException as e:
logger.warn("Failed to reject invite: %s", e)
raise SynapseError(500, "Failed to reject invite")
# Try the host we successfully got a response to /make_join/ # Try the host we successfully got a response to /make_join/
# request first. # request first.
@ -804,10 +809,16 @@ class FederationHandler(BaseHandler):
except ValueError: except ValueError:
pass pass
try:
yield self.replication_layer.send_leave( yield self.replication_layer.send_leave(
target_hosts, target_hosts,
signed_event signed_event
) )
except SynapseError:
raise
except CodeMessageException as e:
logger.warn("Failed to reject invite: %s", e)
raise SynapseError(500, "Failed to reject invite")
context = yield self.state_handler.compute_event_context(event) context = yield self.state_handler.compute_event_context(event)
@ -1069,9 +1080,6 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _handle_new_event(self, origin, event, state=None, auth_events=None): def _handle_new_event(self, origin, event, state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self._prep_event( context = yield self._prep_event(
origin, event, origin, event,
state=state, state=state,
@ -1087,14 +1095,12 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, event,
context=context, context=context,
is_new_state=not outlier,
) )
defer.returnValue((context, event_stream_id, max_stream_id)) defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False, def _handle_new_events(self, origin, event_infos, backfilled=False):
outliers=False):
contexts = yield defer.gatherResults( contexts = yield defer.gatherResults(
[ [
self._prep_event( self._prep_event(
@ -1113,7 +1119,6 @@ class FederationHandler(BaseHandler):
for ev_info, context in itertools.izip(event_infos, contexts) for ev_info, context in itertools.izip(event_infos, contexts)
], ],
backfilled=backfilled, backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1128,11 +1133,9 @@ class FederationHandler(BaseHandler):
""" """
events_to_context = {} events_to_context = {}
for e in itertools.chain(auth_events, state): for e in itertools.chain(auth_events, state):
ctx = yield self.state_handler.compute_event_context(
e, outlier=True,
)
events_to_context[e.event_id] = ctx
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
ctx = yield self.state_handler.compute_event_context(e)
events_to_context[e.event_id] = ctx
event_map = { event_map = {
e.event_id: e e.event_id: e
@ -1176,16 +1179,14 @@ class FederationHandler(BaseHandler):
(e, events_to_context[e.event_id]) (e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state) for e in itertools.chain(auth_events, state)
], ],
is_new_state=False,
) )
new_event_context = yield self.state_handler.compute_event_context( new_event_context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=False, event, old_state=state
) )
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context, event, new_event_context,
is_new_state=True,
current_state=state, current_state=state,
) )
@ -1193,10 +1194,9 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None): def _prep_event(self, origin, event, state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context( context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=outlier, event, old_state=state,
) )
if not auth_events: if not auth_events:
@ -1718,13 +1718,15 @@ class FederationHandler(BaseHandler):
def _check_signature(self, event, auth_events): def _check_signature(self, event, auth_events):
""" """
Checks that the signature in the event is consistent with its invite. Checks that the signature in the event is consistent with its invite.
:param event (Event): The m.room.member event to check
:param auth_events (dict<(event type, state_key), event>)
:raises Args:
AuthError if signature didn't match any keys, or key has been event (Event): The m.room.member event to check
auth_events (dict<(event type, state_key), event>):
Raises:
AuthError: if signature didn't match any keys, or key has been
revoked, revoked,
SynapseError if a transient error meant a key couldn't be checked SynapseError: if a transient error meant a key couldn't be checked
for revocation. for revocation.
""" """
signed = event.content["third_party_invite"]["signed"] signed = event.content["third_party_invite"]["signed"]
@ -1766,12 +1768,13 @@ class FederationHandler(BaseHandler):
""" """
Checks whether public_key has been revoked. Checks whether public_key has been revoked.
:param public_key (str): base-64 encoded public key. Args:
:param url (str): Key revocation URL. public_key (str): base-64 encoded public key.
url (str): Key revocation URL.
:raises Raises:
AuthError if they key has been revoked. AuthError: if they key has been revoked.
SynapseError if a transient error meant a key couldn't be checked SynapseError: if a transient error meant a key couldn't be checked
for revocation. for revocation.
""" """
try: try:

View file

@ -21,6 +21,7 @@ from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.types import UserID, RoomStreamToken, StreamToken from synapse.types import UserID, RoomStreamToken, StreamToken
@ -175,7 +176,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_event(self, event_dict, token_id=None, txn_id=None): def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
""" """
Given a dict from a client, create a new event. Given a dict from a client, create a new event.
@ -186,6 +187,9 @@ class MessageHandler(BaseHandler):
Args: Args:
event_dict (dict): An entire event event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns: Returns:
Tuple of created event (FrozenEvent), Context Tuple of created event (FrozenEvent), Context
@ -224,6 +228,7 @@ class MessageHandler(BaseHandler):
event, context = yield self._create_new_client_event( event, context = yield self._create_new_client_event(
builder=builder, builder=builder,
prev_event_ids=prev_event_ids,
) )
defer.returnValue((event, context)) defer.returnValue((event, context))
@ -556,14 +561,7 @@ class MessageHandler(BaseHandler):
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
# Only do N rooms at once yield concurrently_execute(handle_room, room_list, 10)
n = 5
d_list = [handle_room(e) for e in room_list]
for i in range(0, len(d_list), n):
yield defer.gatherResults(
d_list[i:i + n],
consumeErrors=True
).addErrback(unwrapFirstError)
account_data_events = [] account_data_events = []
for account_data_type, content in account_data.items(): for account_data_type, content in account_data.items():

View file

@ -18,20 +18,17 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset, EventTypes, JoinRules, RoomCreationPreset,
) )
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils
from synapse.util.async import concurrently_execute
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from collections import OrderedDict from collections import OrderedDict
from unpaddedbase64 import decode_base64
import logging import logging
import math import math
@ -357,599 +354,6 @@ class RoomCreationHandler(BaseHandler):
) )
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
third_party_signed,
)
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": effective_membership_state}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
# For backwards compatibility:
"membership": effective_membership_state,
},
token_id=requester.access_token_id,
txn_id=txn_id,
)
old_state = context.current_state.get((EventTypes.Member, event.state_key))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(
requester,
event,
context,
ratelimit=ratelimit,
remote_room_hosts=remote_room_hosts,
)
@defer.inlineCallbacks
def send_membership_event(
self,
requester,
event,
context,
remote_room_hosts=None,
ratelimit=True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
is_guest (bool): Whether the sender is a guest.
room_hosts ([str]): Homeservers which are likely to already be in
the room, and could be danced with in order to join this
homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
action = "send"
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
do_remote_join_dance, remote_room_hosts = self._should_do_dance(
context,
(self.get_inviter(event.state_key, context.current_state)),
remote_room_hosts,
)
if do_remote_join_dance:
action = "remote_join"
elif event.membership == Membership.LEAVE:
is_host_in_room = self.is_host_in_room(context.current_state)
if not is_host_in_room:
# perhaps we've been invited
inviter = self.get_inviter(target_user.to_string(), context.current_state)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
action = "remote_reject"
federation_handler = self.hs.get_handlers().federation_handler
if action == "remote_join":
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield federation_handler.do_invite_join(
remote_room_hosts,
event.room_id,
event.user_id,
event.content,
)
elif action == "remote_reject":
yield federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
event.user_id
)
else:
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
def _should_do_dance(self, context, inviter, room_hosts=None):
# TODO: Shouldn't this be remote_room_host?
room_hosts = room_hosts or []
is_host_in_room = self.is_host_in_room(context.current_state)
if is_host_in_room:
return False, room_hosts
if inviter and not self.hs.is_mine(inviter):
room_hosts.append(inviter.domain)
return True, room_hosts
@defer.inlineCallbacks
def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"]
servers = mapping["servers"]
defer.returnValue((RoomID.from_string(room_id), servers))
def get_inviter(self, user_id, current_state):
prev_state = current_state.get((EventTypes.Member, user_id))
if prev_state and prev_state.membership == Membership.INVITE:
return UserID.from_string(prev_state.user_id)
return None
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
requester,
id_server,
medium,
address,
room_id,
inviter,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
(str) the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id
):
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
member_event = room_state.get((EventTypes.Member, user.to_string()))
if member_event:
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:
canonical_room_alias = canonical_alias_event.content.get("alias", "")
room_name = ""
room_name_event = room_state.get((EventTypes.Name, ""))
if room_name_event:
room_name = room_name_event.content.get("name", "")
room_join_rules = ""
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
room_join_rules = join_rules_event.content.get("join_rule", "")
room_avatar_url = ""
room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server=id_server,
medium=medium,
address=address,
room_id=room_id,
inviter_user_id=user.to_string(),
room_alias=canonical_room_alias,
room_avatar_url=room_avatar_url,
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"public_keys": public_keys,
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url
):
"""
Asks an identity server for a third party invite.
:param id_server (str): hostname + optional port for the identity server.
:param medium (str): The literal string "email".
:param address (str): The third party address being invited.
:param room_id (str): The ID of the room to which the user is invited.
:param inviter_user_id (str): The user ID of the inviter.
:param room_alias (str): An alias for the room, for cosmetic
notifications.
:param room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
:param room_join_rules (str): The join rules of the email
(e.g. "public").
:param room_name (str): The m.room.name of the room.
:param inviter_display_name (str): The current display name of the
inviter.
:param inviter_avatar_url (str): The URL of the inviter's avatar.
:return: A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
if self.hs.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
})
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
invite_config
)
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
),
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name))
@defer.inlineCallbacks
def forget(self, user, room_id):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership != Membership.LEAVE:
raise SynapseError(400, "User %s in room %s" % (
user_id, room_id
))
if membership:
yield self.store.forget(user_id, room_id)
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(RoomListHandler, self).__init__(hs) super(RoomListHandler, self).__init__(hs)
@ -965,6 +369,8 @@ class RoomListHandler(BaseHandler):
def _get_public_room_list(self): def _get_public_room_list(self):
room_ids = yield self.store.get_public_room_ids() room_ids = yield self.store.get_public_room_ids()
results = []
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_room(room_id): def handle_room(room_id):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
@ -1025,18 +431,12 @@ class RoomListHandler(BaseHandler):
joined_users = yield self.store.get_users_in_room(room_id) joined_users = yield self.store.get_users_in_room(room_id)
result["num_joined_members"] = len(joined_users) result["num_joined_members"] = len(joined_users)
defer.returnValue(result) results.append(result)
result = [] yield concurrently_execute(handle_room, room_ids, 10)
for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
chunk_result = yield defer.gatherResults([
handle_room(room_id)
for room_id in chunk
], consumeErrors=True).addErrback(unwrapFirstError)
result.extend(v for v in chunk_result if v)
# FIXME (erikj): START is no longer a valid value # FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": result}) defer.returnValue({"start": "START", "end": "END", "chunk": results})
class RoomContextHandler(BaseHandler): class RoomContextHandler(BaseHandler):

View file

@ -0,0 +1,718 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomID, Requester
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.util.logcontext import preserve_context_over_fn
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
import logging
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
def user_left_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
prev_event_ids,
txn_id=None,
ratelimit=True,
):
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": membership}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
# For backwards compatibility:
"membership": membership,
},
token_id=requester.access_token_id,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
)
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield self.hs.get_handlers().federation_handler.do_invite_join(
remote_room_hosts,
room_id,
user.to_string(),
content,
)
yield user_joined_room(self.distributor, user, room_id)
def reject_remote_invite(self, user_id, room_id, remote_room_hosts):
return self.hs.get_handlers().federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
user_id
)
@defer.inlineCallbacks
def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
third_party_signed,
)
if not remote_room_hosts:
remote_room_hosts = []
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state = yield self.state_handler.get_current_state(
room_id, latest_event_ids=latest_event_ids,
)
old_state = current_state.get((EventTypes.Member, target.to_string()))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
is_host_in_room = self.is_host_in_room(current_state)
if effective_membership_state == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
content = {"membership": Membership.JOIN}
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
ret = yield self.remote_join(
remote_room_hosts, room_id, target, content
)
defer.returnValue(ret)
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
inviter = yield self.get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
try:
ret = yield self.reject_remote_invite(
target.to_string(), room_id, remote_room_hosts
)
defer.returnValue(ret)
except SynapseError as e:
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
yield self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
membership=effective_membership_state,
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
)
@defer.inlineCallbacks
def send_membership_event(
self,
requester,
event,
context,
remote_room_hosts=None,
ratelimit=True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
is_guest (bool): Whether the sender is a guest.
room_hosts ([str]): Homeservers which are likely to already be in
the room, and could be danced with in order to join this
homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
def _should_do_dance(self, current_state, inviter, room_hosts=None):
# TODO: Shouldn't this be remote_room_host?
room_hosts = room_hosts or []
is_host_in_room = self.is_host_in_room(current_state)
if is_host_in_room:
return False, room_hosts
if inviter and not self.hs.is_mine(inviter):
room_hosts.append(inviter.domain)
return True, room_hosts
@defer.inlineCallbacks
def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"]
servers = mapping["servers"]
defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks
def get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if invite:
defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
requester,
id_server,
medium,
address,
room_id,
inviter,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id
):
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
member_event = room_state.get((EventTypes.Member, user.to_string()))
if member_event:
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:
canonical_room_alias = canonical_alias_event.content.get("alias", "")
room_name = ""
room_name_event = room_state.get((EventTypes.Name, ""))
if room_name_event:
room_name = room_name_event.content.get("name", "")
room_join_rules = ""
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
room_join_rules = join_rules_event.content.get("join_rule", "")
room_avatar_url = ""
room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server=id_server,
medium=medium,
address=address,
room_id=room_id,
inviter_user_id=user.to_string(),
room_alias=canonical_room_alias,
room_avatar_url=room_avatar_url,
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"public_keys": public_keys,
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url
):
"""
Asks an identity server for a third party invite.
Args:
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
room_id (str): The ID of the room to which the user is invited.
inviter_user_id (str): The user ID of the inviter.
room_alias (str): An alias for the room, for cosmetic notifications.
room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
room_join_rules (str): The join rules of the email (e.g. "public").
room_name (str): The m.room.name of the room.
inviter_display_name (str): The current display name of the
inviter.
inviter_avatar_url (str): The URL of the inviter's avatar.
Returns:
A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
if self.hs.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
})
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
invite_config
)
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
),
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name))
@defer.inlineCallbacks
def forget(self, user, room_id):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership != Membership.LEAVE:
raise SynapseError(400, "User %s in room %s" % (
user_id, room_id
))
if membership:
yield self.store.forget(user_id, room_id)

View file

@ -17,8 +17,8 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError from synapse.util.async import concurrently_execute
from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
@ -250,15 +250,13 @@ class SyncHandler(BaseHandler):
joined = [] joined = []
invited = [] invited = []
archived = [] archived = []
deferreds = []
room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)] user_id = sync_config.user.to_string()
for room_list_chunk in room_list_chunks:
for event in room_list_chunk: @defer.inlineCallbacks
def _generate_room_entry(event):
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
room_sync_deferred = preserve_fn( room_result = yield self.full_state_sync_for_joined_room(
self.full_state_sync_for_joined_room
)(
room_id=event.room_id, room_id=event.room_id,
sync_config=sync_config, sync_config=sync_config,
now_token=now_token, now_token=now_token,
@ -267,8 +265,7 @@ class SyncHandler(BaseHandler):
tags_by_room=tags_by_room, tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room, account_data_by_room=account_data_by_room,
) )
room_sync_deferred.addCallback(joined.append) joined.append(room_result)
deferreds.append(room_sync_deferred)
elif event.membership == Membership.INVITE: elif event.membership == Membership.INVITE:
invite = yield self.store.get_event(event.event_id) invite = yield self.store.get_event(event.event_id)
invited.append(InvitedSyncResult( invited.append(InvitedSyncResult(
@ -279,15 +276,13 @@ class SyncHandler(BaseHandler):
# Always send down rooms we were banned or kicked from. # Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave: if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE: if event.membership == Membership.LEAVE:
if sync_config.user.to_string() == event.sender: if user_id == event.sender:
continue return
leave_token = now_token.copy_and_replace( leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,) "room_key", "s%d" % (event.stream_ordering,)
) )
room_sync_deferred = preserve_fn( room_result = yield self.full_state_sync_for_archived_room(
self.full_state_sync_for_archived_room
)(
sync_config=sync_config, sync_config=sync_config,
room_id=event.room_id, room_id=event.room_id,
leave_event_id=event.event_id, leave_event_id=event.event_id,
@ -296,12 +291,9 @@ class SyncHandler(BaseHandler):
tags_by_room=tags_by_room, tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room, account_data_by_room=account_data_by_room,
) )
room_sync_deferred.addCallback(archived.append) archived.append(room_result)
deferreds.append(room_sync_deferred)
yield defer.gatherResults( yield concurrently_execute(_generate_room_entry, room_list, 10)
deferreds, consumeErrors=True
).addErrback(unwrapFirstError)
account_data_for_user = sync_config.filter_collection.filter_account_data( account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data) self.account_data_for_user(account_data)
@ -671,7 +663,8 @@ class SyncHandler(BaseHandler):
def load_filtered_recents(self, room_id, sync_config, now_token, def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None, recents=None, newly_joined_room=False): since_token=None, recents=None, newly_joined_room=False):
""" """
:returns a Deferred TimelineBatch Returns:
a Deferred TimelineBatch
""" """
with Measure(self.clock, "load_filtered_recents"): with Measure(self.clock, "load_filtered_recents"):
filtering_factor = 2 filtering_factor = 2
@ -838,8 +831,11 @@ class SyncHandler(BaseHandler):
""" """
Get the room state after the given event Get the room state after the given event
:param synapse.events.EventBase event: event of interest Args:
:return: A Deferred map from ((type, state_key)->Event) event(synapse.events.EventBase): event of interest
Returns:
A Deferred map from ((type, state_key)->Event)
""" """
state = yield self.store.get_state_for_event(event.event_id) state = yield self.store.get_state_for_event(event.event_id)
if event.is_state(): if event.is_state():
@ -850,9 +846,13 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at(self, room_id, stream_position): def get_state_at(self, room_id, stream_position):
""" Get the room state at a particular stream position """ Get the room state at a particular stream position
:param str room_id: room for which to get state
:param StreamToken stream_position: point at which to get state Args:
:returns: A Deferred map from ((type, state_key)->Event) room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state
Returns:
A Deferred map from ((type, state_key)->Event)
""" """
last_events, token = yield self.store.get_recent_events_for_room( last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=stream_position.room_key, limit=1, room_id, end_token=stream_position.room_key, limit=1,
@ -873,15 +873,18 @@ class SyncHandler(BaseHandler):
""" Works out the differnce in state between the start of the timeline """ Works out the differnce in state between the start of the timeline
and the previous sync. and the previous sync.
:param str room_id Args:
:param TimelineBatch batch: The timeline batch for the room that will room_id(str):
be sent to the user. batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
:param sync_config the room that will be sent to the user.
:param str since_token: Token of the end of the previous batch. May be None. sync_config(synapse.handlers.sync.SyncConfig):
:param str now_token: Token of the end of the current batch. since_token(str|None): Token of the end of the previous batch. May
:param bool full_state: Whether to force returning the full state. be None.
now_token(str): Token of the end of the current batch.
full_state(bool): Whether to force returning the full state.
:returns A new event dictionary Returns:
A deferred new event dictionary
""" """
# TODO(mjark) Check if the state events were received by the server # TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state # after the previous sync, since we need to include those state
@ -953,11 +956,13 @@ class SyncHandler(BaseHandler):
Check if the user has just joined the given room (so should Check if the user has just joined the given room (so should
be given the full state) be given the full state)
:param sync_config: Args:
:param dict[(str,str), synapse.events.FrozenEvent] state_delta: the sync_config(synapse.handlers.sync.SyncConfig):
state_delta(dict[(str,str), synapse.events.FrozenEvent]): the
difference in state since the last sync difference in state since the last sync
:returns A deferred Tuple (state_delta, limited) Returns:
A deferred Tuple (state_delta, limited)
""" """
join_event = state_delta.get(( join_event = state_delta.get((
EventTypes.Member, sync_config.user.to_string()), None) EventTypes.Member, sync_config.user.to_string()), None)

View file

@ -26,14 +26,19 @@ logger = logging.getLogger(__name__)
def parse_integer(request, name, default=None, required=False): def parse_integer(request, name, default=None, required=False):
"""Parse an integer parameter from the request string """Parse an integer parameter from the request string
:param request: the twisted HTTP request. Args:
:param name (str): the name of the query parameter. request: the twisted HTTP request.
:param default: value to use if the parameter is absent, defaults to None. name (str): the name of the query parameter.
:param required (bool): whether to raise a 400 SynapseError if the default (int|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False. parameter is absent, defaults to False.
:return: An int value or the default.
:raises Returns:
SynapseError if the parameter is absent and required, or if the int|None: An int value or the default.
Raises:
SynapseError: if the parameter is absent and required, or if the
parameter is present and not an integer. parameter is present and not an integer.
""" """
if name in request.args: if name in request.args:
@ -53,14 +58,19 @@ def parse_integer(request, name, default=None, required=False):
def parse_boolean(request, name, default=None, required=False): def parse_boolean(request, name, default=None, required=False):
"""Parse a boolean parameter from the request query string """Parse a boolean parameter from the request query string
:param request: the twisted HTTP request. Args:
:param name (str): the name of the query parameter. request: the twisted HTTP request.
:param default: value to use if the parameter is absent, defaults to None. name (str): the name of the query parameter.
:param required (bool): whether to raise a 400 SynapseError if the default (bool|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False. parameter is absent, defaults to False.
:return: A bool value or the default.
:raises Returns:
SynapseError if the parameter is absent and required, or if the bool|None: A bool value or the default.
Raises:
SynapseError: if the parameter is absent and required, or if the
parameter is present and not one of "true" or "false". parameter is present and not one of "true" or "false".
""" """
@ -88,15 +98,20 @@ def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"): allowed_values=None, param_type="string"):
"""Parse a string parameter from the request query string. """Parse a string parameter from the request query string.
:param request: the twisted HTTP request. Args:
:param name (str): the name of the query parameter. request: the twisted HTTP request.
:param default: value to use if the parameter is absent, defaults to None. name (str): the name of the query parameter.
:param required (bool): whether to raise a 400 SynapseError if the default (str|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False. parameter is absent, defaults to False.
:param allowed_values (list): List of allowed values for the string, allowed_values (list[str]): List of allowed values for the string,
or None if any value is allowed, defaults to None or None if any value is allowed, defaults to None
:return: A string value or the default.
:raises Returns:
str|None: A string value or the default.
Raises:
SynapseError if the parameter is absent and required, or if the SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and parameter is present, must be one of a list of allowed values and
is not one of those allowed values. is not one of those allowed values.
@ -122,9 +137,13 @@ def parse_string(request, name, default=None, required=False,
def parse_json_value_from_request(request): def parse_json_value_from_request(request):
"""Parse a JSON value from the body of a twisted HTTP request. """Parse a JSON value from the body of a twisted HTTP request.
:param request: the twisted HTTP request. Args:
:returns: The JSON value. request: the twisted HTTP request.
:raises
Returns:
The JSON value.
Raises:
SynapseError if the request body couldn't be decoded as JSON. SynapseError if the request body couldn't be decoded as JSON.
""" """
try: try:
@ -143,8 +162,10 @@ def parse_json_value_from_request(request):
def parse_json_object_from_request(request): def parse_json_object_from_request(request):
"""Parse a JSON object from the body of a twisted HTTP request. """Parse a JSON object from the body of a twisted HTTP request.
:param request: the twisted HTTP request. Args:
:raises request: the twisted HTTP request.
Raises:
SynapseError if the request body couldn't be decoded as JSON or SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object. if it wasn't a JSON object.
""" """

View file

@ -503,13 +503,14 @@ class Notifier(object):
def wait_for_replication(self, callback, timeout): def wait_for_replication(self, callback, timeout):
"""Wait for an event to happen. """Wait for an event to happen.
:param callback: Args:
Gets called whenever an event happens. If this returns a truthy callback: Gets called whenever an event happens. If this returns a
value then ``wait_for_replication`` returns, otherwise it waits truthy value then ``wait_for_replication`` returns, otherwise
for another event. it waits for another event.
:param int timeout: timeout: How many milliseconds to wait for callback return a truthy
How many milliseconds to wait for callback return a truthy value. value.
:returns:
Returns:
A deferred that resolves with the value returned by the callback. A deferred that resolves with the value returned by the callback.
""" """
listener = _NotificationListener(None) listener = _NotificationListener(None)

View file

@ -19,9 +19,11 @@ import copy
def list_with_base_rules(rawrules): def list_with_base_rules(rawrules):
"""Combine the list of rules set by the user with the default push rules """Combine the list of rules set by the user with the default push rules
:param list rawrules: The rules the user has modified or set. Args:
:returns: A new list with the rules set by the user combined with the rawrules(list): The rules the user has modified or set.
defaults.
Returns:
A new list with the rules set by the user combined with the defaults.
""" """
ruleslist = [] ruleslist = []

View file

@ -133,8 +133,9 @@ class PushRuleEvaluator:
enabled = self.enabled_map.get(r['rule_id'], None) enabled = self.enabled_map.get(r['rule_id'], None)
if enabled is not None and not enabled: if enabled is not None and not enabled:
continue continue
elif enabled is None and not r.get("enabled", True):
if not r.get("enabled", True): # if no override, check enabled on the rule itself
# (may have come from a base rule)
continue continue
conditions = r['conditions'] conditions = r['conditions']

View file

@ -145,32 +145,43 @@ class ReplicationResource(Resource):
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Type", b"application/json")
writer = _Writer(request)
request_streams = {
name: parse_integer(request, name)
for names in STREAM_NAMES for name in names
}
request_streams["streams"] = parse_string(request, "streams")
def replicate():
return self.replicate(request_streams, limit)
result = yield self.notifier.wait_for_replication(replicate, timeout)
request.write(json.dumps(result, ensure_ascii=False))
finish_request(request)
@defer.inlineCallbacks @defer.inlineCallbacks
def replicate(): def replicate(self, request_streams, limit):
writer = _Writer()
current_token = yield self.current_replication_token() current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token) logger.info("Replicating up to %r", current_token)
yield self.account_data(writer, current_token, limit) yield self.account_data(writer, current_token, limit, request_streams)
yield self.events(writer, current_token, limit) yield self.events(writer, current_token, limit, request_streams)
yield self.presence(writer, current_token) # TODO: implement limit # TODO: implement limit
yield self.typing(writer, current_token) # TODO: implement limit yield self.presence(writer, current_token, request_streams)
yield self.receipts(writer, current_token, limit) yield self.typing(writer, current_token, request_streams)
yield self.push_rules(writer, current_token, limit) yield self.receipts(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit) yield self.push_rules(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit) yield self.pushers(writer, current_token, limit, request_streams)
self.streams(writer, current_token) yield self.state(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total) logger.info("Replicated %d rows", writer.total)
defer.returnValue(writer.total) defer.returnValue(writer.finish())
yield self.notifier.wait_for_replication(replicate, timeout) def streams(self, writer, current_token, request_streams):
request_token = request_streams.get("streams")
writer.finish()
def streams(self, writer, current_token):
request_token = parse_string(writer.request, "streams")
streams = [] streams = []
@ -195,32 +206,43 @@ class ReplicationResource(Resource):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def events(self, writer, current_token, limit): def events(self, writer, current_token, limit, request_streams):
request_events = parse_integer(writer.request, "events") request_events = request_streams.get("events")
request_backfill = parse_integer(writer.request, "backfill") request_backfill = request_streams.get("backfill")
if request_events is not None or request_backfill is not None: if request_events is not None or request_backfill is not None:
if request_events is None: if request_events is None:
request_events = current_token.events request_events = current_token.events
if request_backfill is None: if request_backfill is None:
request_backfill = current_token.backfill request_backfill = current_token.backfill
events_rows, backfill_rows = yield self.store.get_all_new_events( res = yield self.store.get_all_new_events(
request_backfill, request_events, request_backfill, request_events,
current_token.backfill, current_token.events, current_token.backfill, current_token.events,
limit limit
) )
writer.write_header_and_rows("events", events_rows, ( writer.write_header_and_rows("events", res.new_forward_events, (
"position", "internal", "json", "state_group" "position", "internal", "json", "state_group"
)) ))
writer.write_header_and_rows("backfill", backfill_rows, ( writer.write_header_and_rows("backfill", res.new_backfill_events, (
"position", "internal", "json", "state_group" "position", "internal", "json", "state_group"
)) ))
writer.write_header_and_rows(
"forward_ex_outliers", res.forward_ex_outliers,
("position", "event_id", "state_group")
)
writer.write_header_and_rows(
"backward_ex_outliers", res.backward_ex_outliers,
("position", "event_id", "state_group")
)
writer.write_header_and_rows(
"state_resets", res.state_resets, ("position",)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def presence(self, writer, current_token): def presence(self, writer, current_token, request_streams):
current_position = current_token.presence current_position = current_token.presence
request_presence = parse_integer(writer.request, "presence") request_presence = request_streams.get("presence")
if request_presence is not None: if request_presence is not None:
presence_rows = yield self.presence_handler.get_all_presence_updates( presence_rows = yield self.presence_handler.get_all_presence_updates(
@ -233,10 +255,10 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def typing(self, writer, current_token): def typing(self, writer, current_token, request_streams):
current_position = current_token.presence current_position = current_token.presence
request_typing = parse_integer(writer.request, "typing") request_typing = request_streams.get("typing")
if request_typing is not None: if request_typing is not None:
typing_rows = yield self.typing_handler.get_all_typing_updates( typing_rows = yield self.typing_handler.get_all_typing_updates(
@ -247,10 +269,10 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def receipts(self, writer, current_token, limit): def receipts(self, writer, current_token, limit, request_streams):
current_position = current_token.receipts current_position = current_token.receipts
request_receipts = parse_integer(writer.request, "receipts") request_receipts = request_streams.get("receipts")
if request_receipts is not None: if request_receipts is not None:
receipts_rows = yield self.store.get_all_updated_receipts( receipts_rows = yield self.store.get_all_updated_receipts(
@ -261,12 +283,12 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def account_data(self, writer, current_token, limit): def account_data(self, writer, current_token, limit, request_streams):
current_position = current_token.account_data current_position = current_token.account_data
user_account_data = parse_integer(writer.request, "user_account_data") user_account_data = request_streams.get("user_account_data")
room_account_data = parse_integer(writer.request, "room_account_data") room_account_data = request_streams.get("room_account_data")
tag_account_data = parse_integer(writer.request, "tag_account_data") tag_account_data = request_streams.get("tag_account_data")
if user_account_data is not None or room_account_data is not None: if user_account_data is not None or room_account_data is not None:
if user_account_data is None: if user_account_data is None:
@ -292,10 +314,10 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def push_rules(self, writer, current_token, limit): def push_rules(self, writer, current_token, limit, request_streams):
current_position = current_token.push_rules current_position = current_token.push_rules
push_rules = parse_integer(writer.request, "push_rules") push_rules = request_streams.get("push_rules")
if push_rules is not None: if push_rules is not None:
rows = yield self.store.get_all_push_rule_updates( rows = yield self.store.get_all_push_rule_updates(
@ -307,10 +329,11 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def pushers(self, writer, current_token, limit): def pushers(self, writer, current_token, limit, request_streams):
current_position = current_token.pushers current_position = current_token.pushers
pushers = parse_integer(writer.request, "pushers") pushers = request_streams.get("pushers")
if pushers is not None: if pushers is not None:
updated, deleted = yield self.store.get_all_updated_pushers( updated, deleted = yield self.store.get_all_updated_pushers(
pushers, current_position, limit pushers, current_position, limit
@ -325,10 +348,11 @@ class ReplicationResource(Resource):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def state(self, writer, current_token, limit): def state(self, writer, current_token, limit, request_streams):
current_position = current_token.state current_position = current_token.state
state = parse_integer(writer.request, "state") state = request_streams.get("state")
if state is not None: if state is not None:
state_groups, state_group_state = ( state_groups, state_group_state = (
yield self.store.get_all_new_state_groups( yield self.store.get_all_new_state_groups(
@ -345,9 +369,8 @@ class ReplicationResource(Resource):
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
def __init__(self, request): def __init__(self):
self.streams = {} self.streams = {}
self.request = request
self.total = 0 self.total = 0
def write_header_and_rows(self, name, rows, fields, position=None): def write_header_and_rows(self, name, rows, fields, position=None):
@ -366,8 +389,7 @@ class _Writer(object):
self.total += len(rows) self.total += len(rows)
def finish(self): def finish(self):
self.request.write(json.dumps(self.streams, ensure_ascii=False)) return self.streams
finish_request(self.request)
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage._base import SQLBaseStore
from twisted.internet import defer
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs)
def stream_positions(self):
return {}
def process_replication(self, result):
return defer.succeed(None)

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker(object):
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self.advance(_load_current_id(db_conn, table, column))
def advance(self, new_id):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):
return self._current

View file

@ -0,0 +1,199 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.storage import DataStore
from synapse.storage.room import RoomStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.state import StateStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
#
# Rather than write duplicate versions of those functions, or lift them to
# a common base class, we going to grab the underlying __func__ object from
# the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
_get_current_state_for_key = StateStore.__dict__[
"_get_current_state_for_key"
]
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
get_membership_changes_for_user = (
DataStore.get_membership_changes_for_user.__func__
)
get_room_events_max_id = DataStore.get_room_events_max_id.__func__
get_room_events_stream_for_room = (
DataStore.get_room_events_stream_for_room.__func__
)
_set_before_and_after = DataStore._set_before_and_after
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
_parse_events_txn = DataStore._parse_events_txn.__func__
_get_events_txn = DataStore._get_events_txn.__func__
_fetch_events_txn = DataStore._fetch_events_txn.__func__
_fetch_event_rows = DataStore._fetch_event_rows.__func__
_get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__
_get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfilled"] = self._backfill_id_gen.get_current_token()
return result
def process_replication(self, result):
state_resets = set(
r[0] for r in result.get("state_resets", {"rows": []})["rows"]
)
stream = result.get("events")
if stream:
self._stream_id_gen.advance(stream["position"])
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False, state_resets=state_resets
)
stream = result.get("backfill")
if stream:
self._backfill_id_gen.advance(stream["position"])
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=True, state_resets=state_resets
)
stream = result.get("forward_ex_outliers")
if stream:
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
stream = result.get("backward_ex_outliers")
if stream:
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
return super(SlavedEventStore, self).process_replication(result)
def _process_replication_row(self, row, backfilled, state_resets):
position = row[0]
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
self._invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets
)
def _invalidate_caches_for_event(self, event, backfilled, reset_state):
if reset_state:
self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_room_name_and_aliases.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id)
if not backfilled:
self._events_stream_cache.entity_has_changed(
event.room_id, event.internal_metadata.stream_ordering
)
# self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
# (event.room_id,)
# )
if event.type == EventTypes.Redaction:
self._invalidate_get_event_cache(event.redacts)
if event.type == EventTypes.Member:
self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,))
# self._membership_stream_cache.entity_has_changed(
# event.state_key, event.internal_metadata.stream_ordering
# )
if not event.is_state():
return
if backfilled:
return
if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()):
return
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))
if event.type in [EventTypes.Name, EventTypes.Aliases]:
self.get_room_name_and_aliases.invalidate(
(event.room_id,)
)
pass

View file

@ -199,15 +199,17 @@ class SyncRestServlet(RestServlet):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
:param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync Args:
rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
results for rooms this user is joined to results for rooms this user is joined to
:param int time_now: current time - used as a baseline for age time_now(int): current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
:return: the joined rooms list, in our response format Returns:
:rtype: dict[str, dict[str, object]] dict[str, dict[str, object]]: the joined rooms list, in our
response format
""" """
joined = {} joined = {}
for room in rooms: for room in rooms:
@ -221,15 +223,17 @@ class SyncRestServlet(RestServlet):
""" """
Encode the invited rooms in a sync result Encode the invited rooms in a sync result
:param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of Args:
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
sync results for rooms this user is joined to sync results for rooms this user is joined to
:param int time_now: current time - used as a baseline for age time_now(int): current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
:return: the invited rooms list, in our response format Returns:
:rtype: dict[str, dict[str, object]] dict[str, dict[str, object]]: the invited rooms list, in our
response format
""" """
invited = {} invited = {}
for room in rooms: for room in rooms:
@ -251,15 +255,17 @@ class SyncRestServlet(RestServlet):
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
:param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of Args:
rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
sync results for rooms this user is joined to sync results for rooms this user is joined to
:param int time_now: current time - used as a baseline for age time_now(int): current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
:return: the invited rooms list, in our response format Returns:
:rtype: dict[str, dict[str, object]] dict[str, dict[str, object]]: The invited rooms list, in our
response format
""" """
joined = {} joined = {}
for room in rooms: for room in rooms:
@ -272,17 +278,18 @@ class SyncRestServlet(RestServlet):
@staticmethod @staticmethod
def encode_room(room, time_now, token_id, joined=True): def encode_room(room, time_now, token_id, joined=True):
""" """
:param JoinedSyncResult|ArchivedSyncResult room: sync result for a Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
single room single room
:param int time_now: current time - used as a baseline for age time_now (int): current time - used as a baseline for age
calculations calculations
:param int token_id: ID of the user's auth token - used for namespacing token_id (int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
:param joined: True if the user is joined to this room - will mean joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events we handle ephemeral events
:return: the room, encoded in our response format Returns:
:rtype: dict[str, object] dict[str, object]: the room, encoded in our response format
""" """
def serialize(event): def serialize(event):
# TODO(mjark): Respect formatting requirements in the filter. # TODO(mjark): Respect formatting requirements in the filter.

View file

@ -75,7 +75,8 @@ class StateHandler(object):
self._state_cache.start() self._state_cache.start()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key="",
latest_event_ids=None):
""" Retrieves the current state for the room. This is done by """ Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts. event graph and then resolving any of the state conflicts.
@ -86,11 +87,13 @@ class StateHandler(object):
If `event_type` is specified, then the method returns only the one If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`. event (or None) with that `event_type` and `state_key`.
:returns map from (type, state_key) to event Returns:
map from (type, state_key) to event
""" """
event_ids = yield self.store.get_latest_event_ids_in_room(room_id) if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
res = yield self.resolve_state_groups(room_id, event_ids) res = yield self.resolve_state_groups(room_id, latest_event_ids)
state = res[1] state = res[1]
if event_type: if event_type:
@ -100,7 +103,7 @@ class StateHandler(object):
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None, outlier=False): def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The """ Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph `current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event` just before the event - i.e. it never includes `event`
@ -115,7 +118,7 @@ class StateHandler(object):
""" """
context = EventContext() context = EventContext()
if outlier: if event.internal_metadata.is_outlier():
# If this is an outlier, then we know it shouldn't have any current # If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and # state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group. # persisting the event won't store the state group.
@ -176,7 +179,8 @@ class StateHandler(object):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
:returns a Deferred tuple of (`state_group`, `state`, `prev_state`). Returns:
a Deferred tuple of (`state_group`, `state`, `prev_state`).
`state_group` is the name of a state group if one and only one is `state_group` is the name of a state group if one and only one is
involved. `state` is a map from (type, state_key) to event, and involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids. `prev_state` is a list of event ids.
@ -251,9 +255,10 @@ class StateHandler(object):
def _resolve_events(self, state_sets, event_type=None, state_key=""): def _resolve_events(self, state_sets, event_type=None, state_key=""):
""" """
:returns a tuple (new_state, prev_states). new_state is a map Returns
from (type, state_key) to event. prev_states is a list of event_ids. (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
:rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str]) (new_state, prev_states). new_state is a map from (type, state_key)
to event. prev_states is a list of event_ids.
""" """
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
state = {} state = {}

View file

@ -88,22 +88,17 @@ class DataStore(RoomMemberStore, RoomStore,
self.hs = hs self.hs = hs
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
cur = db_conn.cursor()
try:
cur.execute("SELECT MIN(stream_ordering) FROM events",)
rows = cur.fetchall()
self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
self.min_stream_token = min(self.min_stream_token, -1)
finally:
cur.close()
self.client_ip_last_seen = Cache( self.client_ip_last_seen = Cache(
name="client_ip_last_seen", name="client_ip_last_seen",
keylen=4, keylen=4,
) )
self._stream_id_gen = StreamIdGenerator( self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering" db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")]
)
self._backfill_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", step=-1
) )
self._receipts_id_gen = StreamIdGenerator( self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id" db_conn, "receipts_linearized", "stream_id"
@ -129,7 +124,7 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")], extra_tables=[("deleted_pushers", "stream_id")],
) )
events_max = self._stream_id_gen.get_max_token() events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict( event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events", db_conn, "events",
entity_column="room_id", entity_column="room_id",
@ -145,7 +140,7 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max, "MembershipStreamChangeCache", events_max,
) )
account_max = self._account_data_id_gen.get_max_token() account_max = self._account_data_id_gen.get_current_token()
self._account_data_stream_cache = StreamChangeCache( self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max, "AccountDataAndTagsChangeCache", account_max,
) )
@ -156,7 +151,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "presence_stream", db_conn, "presence_stream",
entity_column="user_id", entity_column="user_id",
stream_column="stream_id", stream_column="stream_id",
max_value=self._presence_id_gen.get_max_token(), max_value=self._presence_id_gen.get_current_token(),
) )
self.presence_stream_cache = StreamChangeCache( self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", min_presence_val, "PresenceStreamChangeCache", min_presence_val,
@ -167,7 +162,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "push_rules_stream", db_conn, "push_rules_stream",
entity_column="user_id", entity_column="user_id",
stream_column="stream_id", stream_column="stream_id",
max_value=self._push_rules_stream_id_gen.get_max_token()[0], max_value=self._push_rules_stream_id_gen.get_current_token()[0],
) )
self.push_rules_stream_cache = StreamChangeCache( self.push_rules_stream_cache = StreamChangeCache(
@ -182,39 +177,6 @@ class DataStore(RoomMemberStore, RoomStore,
self.__presence_on_startup = None self.__presence_on_startup = None
return active_on_startup return active_on_startup
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000"
" GROUP BY %(entity)s"
) % {
"table": table,
"entity": entity_column,
"stream": stream_column,
}
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = {
row[0]: int(row[1])
for row in rows
}
if cache:
min_val = min(cache.values())
else:
min_val = max_value
return cache, min_val
def _get_active_presence(self, db_conn): def _get_active_presence(self, db_conn):
"""Fetch non-offline presence from the database so that we can register """Fetch non-offline presence from the database so that we can register
the appropriate time outs. the appropriate time outs.

View file

@ -816,6 +816,40 @@ class SQLBaseStore(object):
self._next_stream_id += 1 self._next_stream_id += 1
return i return i
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000"
" GROUP BY %(entity)s"
) % {
"table": table,
"entity": entity_column,
"stream": stream_column,
}
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = {
row[0]: int(row[1])
for row in rows
}
if cache:
min_val = min(cache.values())
else:
min_val = max_value
return cache, min_val
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying """ This exception is used to rollback a transaction without implying

View file

@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore):
"add_room_account_data", add_account_data_txn, next_id "add_room_account_data", add_account_data_txn, next_id
) )
result = self._account_data_id_gen.get_max_token() result = self._account_data_id_gen.get_current_token()
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore):
"add_user_account_data", add_account_data_txn, next_id "add_user_account_data", add_account_data_txn, next_id
) )
result = self._account_data_id_gen.get_max_token() result = self._account_data_id_gen.get_current_token()
defer.returnValue(result) defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id): def _update_max_stream_id(self, txn, next_id):

View file

@ -26,13 +26,13 @@ SUPPORTED_MODULE = {
} }
def create_engine(config): def create_engine(database_config):
name = config.database_config["name"] name = database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None) engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class: if engine_class:
module = importlib.import_module(name) module = importlib.import_module(name)
return engine_class(module, config=config) return engine_class(module)
raise RuntimeError( raise RuntimeError(
"Unsupported database engine '%s'" % (name,) "Unsupported database engine '%s'" % (name,)

View file

@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.prepare_database import prepare_database
from ._base import IncorrectDatabaseSetup from ._base import IncorrectDatabaseSetup
class PostgresEngine(object): class PostgresEngine(object):
single_threaded = False single_threaded = False
def __init__(self, database_module, config): def __init__(self, database_module):
self.module = database_module self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE) self.module.extensions.register_type(self.module.extensions.UNICODE)
self.config = config
def check_database(self, txn): def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING") txn.execute("SHOW SERVER_ENCODING")
@ -44,9 +41,6 @@ class PostgresEngine(object):
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
) )
def prepare_database(self, db_conn):
prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error): def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError): if isinstance(error, self.module.DatabaseError):
return error.pgcode in ["40001", "40P01"] return error.pgcode in ["40001", "40P01"]

View file

@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.prepare_database import ( from synapse.storage.prepare_database import prepare_database
prepare_database, prepare_sqlite3_database
)
import struct import struct
@ -23,9 +21,8 @@ import struct
class Sqlite3Engine(object): class Sqlite3Engine(object):
single_threaded = True single_threaded = True
def __init__(self, database_module, config): def __init__(self, database_module):
self.module = database_module self.module = database_module
self.config = config
def check_database(self, txn): def check_database(self, txn):
pass pass
@ -34,13 +31,9 @@ class Sqlite3Engine(object):
return sql return sql
def on_new_connection(self, db_conn): def on_new_connection(self, db_conn):
self.prepare_database(db_conn) prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank) db_conn.create_function("rank", 1, _rank)
def prepare_database(self, db_conn):
prepare_sqlite3_database(db_conn)
prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error): def is_deadlock(self, error):
return False return False

View file

@ -163,6 +163,22 @@ class EventFederationStore(SQLBaseStore):
room_id, room_id,
) )
@defer.inlineCallbacks
def get_max_depth_of_events(self, event_ids):
sql = (
"SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
) % (",".join(["?"] * len(event_ids)),)
rows = yield self._execute(
"get_max_depth_of_events", None,
sql, *event_ids
)
if rows:
defer.returnValue(rows[0][0])
else:
defer.returnValue(1)
def _get_min_depth_interaction(self, txn, room_id): def _get_min_depth_interaction(self, txn, room_id):
min_depth = self._simple_select_one_onecol_txn( min_depth = self._simple_select_one_onecol_txn(
txn, txn,

View file

@ -26,8 +26,9 @@ logger = logging.getLogger(__name__)
class EventPushActionsStore(SQLBaseStore): class EventPushActionsStore(SQLBaseStore):
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
""" """
:param event: the event set actions for Args:
:param tuples: list of tuples of (user_id, actions) event: the event set actions for
tuples: list of tuples of (user_id, actions)
""" """
values = [] values = []
for uid, actions in tuples: for uid, actions in tuples:

View file

@ -24,8 +24,7 @@ from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from contextlib import contextmanager from collections import namedtuple
import logging import logging
import math import math
@ -61,20 +60,14 @@ class EventsStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False, def persist_events(self, events_and_contexts, backfilled=False):
is_new_state=True):
if not events_and_contexts: if not events_and_contexts:
return return
if backfilled: if backfilled:
start = self.min_stream_token - 1 stream_ordering_manager = self._backfill_id_gen.get_next_mult(
self.min_stream_token -= len(events_and_contexts) + 1 len(events_and_contexts)
stream_orderings = range(start, self.min_stream_token, -1) )
@contextmanager
def stream_ordering_manager():
yield stream_orderings
stream_ordering_manager = stream_ordering_manager()
else: else:
stream_ordering_manager = self._stream_id_gen.get_next_mult( stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts) len(events_and_contexts)
@ -110,13 +103,11 @@ class EventsStore(SQLBaseStore):
self._persist_events_txn, self._persist_events_txn,
events_and_contexts=chunk, events_and_contexts=chunk,
backfilled=backfilled, backfilled=backfilled,
is_new_state=is_new_state,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, def persist_event(self, event, context, current_state=None):
is_new_state=True, current_state=None):
try: try:
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
@ -128,13 +119,12 @@ class EventsStore(SQLBaseStore):
self._persist_event_txn, self._persist_event_txn,
event=event, event=event,
context=context, context=context,
is_new_state=is_new_state,
current_state=current_state, current_state=current_state,
) )
except _RollbackButIsFineException: except _RollbackButIsFineException:
pass pass
max_persisted_id = yield self._stream_id_gen.get_max_token() max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((stream_ordering, max_persisted_id)) defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -194,8 +184,7 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events}) defer.returnValue({e.event_id: e for e in events})
@log_function @log_function
def _persist_event_txn(self, txn, event, context, def _persist_event_txn(self, txn, event, context, current_state):
is_new_state, current_state):
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
if current_state: if current_state:
@ -203,7 +192,16 @@ class EventsStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id) txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
stream_order = event.internal_metadata.stream_ordering
self._simple_insert_txn(
txn,
table="current_state_resets",
values={"event_stream_ordering": stream_order}
)
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
@ -227,12 +225,10 @@ class EventsStore(SQLBaseStore):
txn, txn,
[(event, context)], [(event, context)],
backfilled=False, backfilled=False,
is_new_state=is_new_state,
) )
@log_function @log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled, def _persist_events_txn(self, txn, events_and_contexts, backfilled):
is_new_state):
depth_updates = {} depth_updates = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids # Remove the any existing cache entries for the event_ids
@ -314,6 +310,18 @@ class EventsStore(SQLBaseStore):
(metadata_json, event.event_id,) (metadata_json, event.event_id,)
) )
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id
self._simple_insert_txn(
txn,
table="ex_outlier_stream",
values={
"event_stream_ordering": stream_order,
"event_id": event.event_id,
"state_group": state_group_id,
}
)
sql = ( sql = (
"UPDATE events SET outlier = ?" "UPDATE events SET outlier = ?"
" WHERE event_id = ?" " WHERE event_id = ?"
@ -359,7 +367,8 @@ class EventsStore(SQLBaseStore):
event event
for event, _ in events_and_contexts for event, _ in events_and_contexts
if event.type == EventTypes.Member if event.type == EventTypes.Member
] ],
backfilled=backfilled,
) )
def event_dict(event): def event_dict(event):
@ -431,10 +440,9 @@ class EventsStore(SQLBaseStore):
txn, [event for event, _ in events_and_contexts] txn, [event for event, _ in events_and_contexts]
) )
state_events_and_contexts = filter( state_events_and_contexts = [
lambda i: i[0].is_state(), ec for ec in events_and_contexts if ec[0].is_state()
events_and_contexts, ]
)
state_values = [] state_values = []
for event, context in state_events_and_contexts: for event, context in state_events_and_contexts:
@ -472,9 +480,21 @@ class EventsStore(SQLBaseStore):
], ],
) )
if is_new_state: if backfilled:
# Backfilled events come before the current state so we don't need
# to update the current state table
return
for event, _ in state_events_and_contexts: for event, _ in state_events_and_contexts:
if not context.rejected: if event.internal_metadata.is_outlier():
# Outlier events shouldn't clobber the current state.
continue
if context.rejected:
# If the event failed it's auth checks then it shouldn't
# clobbler the current state.
continue
txn.call_after( txn.call_after(
self._get_current_state_for_key.invalidate, self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,) (event.room_id, event.type, event.state_key,)
@ -1086,10 +1106,7 @@ class EventsStore(SQLBaseStore):
def get_current_backfill_token(self): def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached""" """The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
# TODO: Fix race with the persit_event txn by using one of the
# stream id managers
return -self.min_stream_token
def get_all_new_events(self, last_backfill_id, last_forward_id, def get_all_new_events(self, last_backfill_id, last_forward_id,
current_backfill_id, current_forward_id, limit): current_backfill_id, current_forward_id, limit):
@ -1110,8 +1127,34 @@ class EventsStore(SQLBaseStore):
if last_forward_id != current_forward_id: if last_forward_id != current_forward_id:
txn.execute(sql, (last_forward_id, current_forward_id, limit)) txn.execute(sql, (last_forward_id, current_forward_id, limit))
new_forward_events = txn.fetchall() new_forward_events = txn.fetchall()
if len(new_forward_events) == limit:
upper_bound = new_forward_events[-1][0]
else:
upper_bound = current_forward_id
sql = (
"SELECT event_stream_ordering FROM current_state_resets"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering ASC"
)
txn.execute(sql, (last_forward_id, upper_bound))
state_resets = txn.fetchall()
sql = (
"SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_forward_id, upper_bound))
forward_ex_outliers = txn.fetchall()
else: else:
new_forward_events = [] new_forward_events = []
state_resets = []
forward_ex_outliers = []
sql = ( sql = (
"SELECT -e.stream_ordering, ej.internal_metadata, ej.json," "SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
@ -1128,8 +1171,35 @@ class EventsStore(SQLBaseStore):
if last_backfill_id != current_backfill_id: if last_backfill_id != current_backfill_id:
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
new_backfill_events = txn.fetchall() new_backfill_events = txn.fetchall()
if len(new_backfill_events) == limit:
upper_bound = new_backfill_events[-1][0]
else:
upper_bound = current_backfill_id
sql = (
"SELECT -event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_backfill_id, -upper_bound))
backward_ex_outliers = txn.fetchall()
else: else:
new_backfill_events = [] new_backfill_events = []
backward_ex_outliers = []
return (new_forward_events, new_backfill_events) return AllNewEventsResult(
new_forward_events, new_backfill_events,
forward_ex_outliers, backward_ex_outliers,
state_resets,
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn) return self.runInteraction("get_all_new_events", get_all_new_events_txn)
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
"forward_ex_outliers", "backward_ex_outliers",
"state_resets"
])

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 30 SCHEMA_VERSION = 31
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -53,6 +53,9 @@ class UpgradeDatabaseException(PrepareDatabaseException):
def prepare_database(db_conn, database_engine, config): def prepare_database(db_conn, database_engine, config):
"""Prepares a database for usage. Will either create all necessary tables """Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version. or upgrade from an older schema version.
If `config` is None then prepare_database will assert that no upgrade is
necessary, *or* will create a fresh database if the database is empty.
""" """
try: try:
cur = db_conn.cursor() cur = db_conn.cursor()
@ -60,13 +63,18 @@ def prepare_database(db_conn, database_engine, config):
if version_info: if version_info:
user_version, delta_files, upgraded = version_info user_version, delta_files, upgraded = version_info
if config is None:
if user_version != SCHEMA_VERSION:
# If we don't pass in a config file then we are expecting to
# have already upgraded the DB.
raise UpgradeDatabaseException("Database needs to be upgraded")
else:
_upgrade_existing_database( _upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine, config cur, user_version, delta_files, upgraded, database_engine, config
) )
else: else:
_setup_new_database(cur, database_engine, config) _setup_new_database(cur, database_engine)
# cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
cur.close() cur.close()
db_conn.commit() db_conn.commit()
@ -75,7 +83,7 @@ def prepare_database(db_conn, database_engine, config):
raise raise
def _setup_new_database(cur, database_engine, config): def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then """Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas. applying any necessary deltas.
@ -148,12 +156,13 @@ def _setup_new_database(cur, database_engine, config):
applied_delta_files=[], applied_delta_files=[],
upgraded=False, upgraded=False,
database_engine=database_engine, database_engine=database_engine,
config=config, config=None,
is_empty=True,
) )
def _upgrade_existing_database(cur, current_version, applied_delta_files, def _upgrade_existing_database(cur, current_version, applied_delta_files,
upgraded, database_engine, config): upgraded, database_engine, config, is_empty=False):
"""Upgrades an existing database. """Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules Delta files can either be SQL stored in *.sql files, or python modules
@ -246,6 +255,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file module_name, absolute_path, python_file
) )
logger.debug("Running script %s", relative_path) logger.debug("Running script %s", relative_path)
module.run_create(cur, database_engine)
if not is_empty:
module.run_upgrade(cur, database_engine, config=config) module.run_upgrade(cur, database_engine, config=config)
elif ext == ".pyc": elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've # Sometimes .pyc files turn up anyway even though we've
@ -361,36 +372,3 @@ def _get_or_create_schema_state(txn, database_engine):
return current_version, applied_deltas, upgraded return current_version, applied_deltas, upgraded
return None return None
def prepare_sqlite3_database(db_conn):
"""This function should be called before `prepare_database` on sqlite3
databases.
Since we changed the way we store the current schema version and handle
updates to schemas, we need a way to upgrade from the old method to the
new. This only affects sqlite databases since they were the only ones
supported at the time.
"""
with db_conn:
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
create_schema = read_schema(schema_path)
db_conn.executescript(create_schema)
c = db_conn.execute("SELECT * FROM schema_version")
rows = c.fetchall()
c.close()
if not rows:
c = db_conn.execute("PRAGMA user_version")
row = c.fetchone()
c.close()
if row and row[0]:
db_conn.execute(
"REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(row[0], False)
)

View file

@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore):
self._update_presence_txn, stream_orderings, presence_states, self._update_presence_txn, stream_orderings, presence_states,
) )
defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token())) defer.returnValue((
stream_orderings[-1], self._presence_id_gen.get_current_token()
))
def _update_presence_txn(self, txn, stream_orderings, presence_states): def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states):
@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore):
defer.returnValue([UserPresenceState(**row) for row in rows]) defer.returnValue([UserPresenceState(**row) for row in rows])
def get_current_presence_token(self): def get_current_presence_token(self):
return self._presence_id_gen.get_max_token() return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid): def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert( return self._simple_insert(

View file

@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore):
"""Get the position of the push rules stream. """Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to.""" room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_max_token() return self._push_rules_stream_id_gen.get_current_token()
def have_push_rules_changed_for_user(self, user_id, last_id): def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):

View file

@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
def get_pushers_stream_token(self): def get_pushers_stream_token(self):
return self._pushers_id_gen.get_max_token() return self._pushers_id_gen.get_current_token()
def get_all_updated_pushers(self, last_id, current_id, limit): def get_all_updated_pushers(self, last_id, current_id, limit):
def get_all_updated_pushers_txn(txn): def get_all_updated_pushers_txn(txn):

View file

@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
super(ReceiptsStore, self).__init__(hs) super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = StreamChangeCache( self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
) )
@cached(num_args=2) @cached(num_args=2)
@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore):
"content": content, "content": content,
}]) }])
@cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", @cachedList(cached_method_name="get_linearized_receipts_for_room",
num_args=3, inlineCallbacks=True) list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids: if not room_ids:
defer.returnValue({}) defer.returnValue({})
@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token() return self._receipts_id_gen.get_current_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id): user_id, event_id, data, stream_id):
@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data room_id, receipt_type, user_id, event_ids, data
) )
max_persisted_id = self._stream_id_gen.get_max_token() max_persisted_id = self._stream_id_gen.get_current_token()
defer.returnValue((stream_id, max_persisted_id)) defer.returnValue((stream_id, max_persisted_id))

View file

@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, @cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1,
inlineCallbacks=True) inlineCallbacks=True)
def are_guests(self, user_ids): def are_guests(self, user_ids):
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
@ -458,12 +458,15 @@ class RegistrationStore(SQLBaseStore):
""" """
Gets the 3pid's guest access token if exists, else saves access_token. Gets the 3pid's guest access token if exists, else saves access_token.
:param medium (str): Medium of the 3pid. Must be "email". Args:
:param address (str): 3pid address. medium (str): Medium of the 3pid. Must be "email".
:param access_token (str): The access token to persist if none is address (str): 3pid address.
access_token (str): The access token to persist if none is
already persisted. already persisted.
:param inviter_user_id (str): User ID of the inviter. inviter_user_id (str): User ID of the inviter.
:return (deferred str): Whichever access token is persisted at the end
Returns:
deferred str: Whichever access token is persisted at the end
of this function call. of this function call.
""" """
def insert(txn): def insert(txn):

View file

@ -36,7 +36,7 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore): class RoomMemberStore(SQLBaseStore):
def _store_room_members_txn(self, txn, events): def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database. """Store a room member in the database.
""" """
self._simple_insert_many_txn( self._simple_insert_many_txn(
@ -62,6 +62,64 @@ class RoomMemberStore(SQLBaseStore):
self._membership_stream_cache.entity_has_changed, self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering event.state_key, event.internal_metadata.stream_ordering
) )
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened.
# The only current event that can also be an outlier is if its an
# invite that has come in across federation.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_invite_from_remote()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
}
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(sql, (
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
))
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (
stream_ordering,
True,
room_id,
user_id,
))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.
@ -127,6 +185,24 @@ class RoomMemberStore(SQLBaseStore):
user_id, [Membership.INVITE] user_id, [Membership.INVITE]
) )
@defer.inlineCallbacks
def get_invite_for_user_in_room(self, user_id, room_id):
"""Gets the invite for the given user and room
Args:
user_id (str)
room_id (str)
Returns:
Deferred: Resolves to either a RoomsForUser or None if no invite was
found.
"""
invites = yield self.get_invited_rooms_for_user(user_id)
for invite in invites:
if invite.room_id == room_id:
defer.returnValue(invite)
defer.returnValue(None)
def get_leave_and_ban_events_for_user(self, user_id): def get_leave_and_ban_events_for_user(self, user_id):
""" Get all the leave events for a user """ Get all the leave events for a user
Args: Args:
@ -163,6 +239,12 @@ class RoomMemberStore(SQLBaseStore):
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list): membership_list):
do_invite = Membership.INVITE in membership_list
membership_list = [m for m in membership_list if m != Membership.INVITE]
results = []
if membership_list:
where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
" OR ".join(["membership = ?" for _ in membership_list]), " OR ".join(["membership = ?" for _ in membership_list]),
) )
@ -183,10 +265,30 @@ class RoomMemberStore(SQLBaseStore):
) % (where_clause,) ) % (where_clause,)
txn.execute(sql, args) txn.execute(sql, args)
return [ results = [
RoomsForUser(**r) for r in self.cursor_to_dict(txn) RoomsForUser(**r) for r in self.cursor_to_dict(txn)
] ]
if do_invite:
sql = (
"SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
" FROM local_invites as i"
" INNER JOIN events as e USING (event_id)"
" WHERE invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(sql, (user_id,))
results.extend(RoomsForUser(
room_id=r["room_id"],
sender=r["inviter"],
event_id=r["event_id"],
stream_ordering=r["stream_ordering"],
membership=Membership.INVITE,
) for r in self.cursor_to_dict(txn))
return results
@cached(max_entries=5000) @cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):
return self.runInteraction( return self.runInteraction(

View file

@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_upgrade(cur, *args, **kwargs): def run_create(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex") cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall(): for row in cur.fetchall():
try: try:
@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs):
"UPDATE application_services_regex SET regex=? WHERE id=?", "UPDATE application_services_regex SET regex=? WHERE id=?",
(new_regex, row[0]) (new_regex, row[0])
) )
def run_upgrade(*args, **kwargs):
pass

View file

@ -27,7 +27,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, *args, **kwargs): def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...") logger.info("Porting pushers table...")
cur.execute(""" cur.execute("""
CREATE TABLE IF NOT EXISTS pushers2 ( CREATE TABLE IF NOT EXISTS pushers2 (
@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute("DROP TABLE pushers") cur.execute("DROP TABLE pushers")
cur.execute("ALTER TABLE pushers2 RENAME TO pushers") cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
logger.info("Moved %d pushers to new table", count) logger.info("Moved %d pushers to new table", count)
def run_upgrade(*args, **kwargs):
pass

View file

@ -43,7 +43,7 @@ SQLITE_TABLE = (
) )
def run_upgrade(cur, database_engine, *args, **kwargs): def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine): if isinstance(database_engine, PostgresEngine):
for statement in get_statements(POSTGRES_TABLE.splitlines()): for statement in get_statements(POSTGRES_TABLE.splitlines()):
cur.execute(statement) cur.execute(statement)
@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql) sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_search", progress_json)) cur.execute(sql, ("event_search", progress_json))
def run_upgrade(*args, **kwargs):
pass

View file

@ -27,7 +27,7 @@ ALTER_TABLE = (
) )
def run_upgrade(cur, database_engine, *args, **kwargs): def run_create(cur, database_engine, *args, **kwargs):
for statement in get_statements(ALTER_TABLE.splitlines()): for statement in get_statements(ALTER_TABLE.splitlines()):
cur.execute(statement) cur.execute(statement)
@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql) sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_origin_server_ts", progress_json)) cur.execute(sql, ("event_origin_server_ts", progress_json))
def run_upgrade(*args, **kwargs):
pass

View file

@ -18,7 +18,7 @@ from synapse.storage.appservice import ApplicationServiceStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, config, *args, **kwargs): def run_create(cur, database_engine, *args, **kwargs):
# NULL indicates user was not registered by an appservice. # NULL indicates user was not registered by an appservice.
try: try:
cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
# Maybe we already added the column? Hope so... # Maybe we already added the column? Hope so...
pass pass
def run_upgrade(cur, database_engine, config, *args, **kwargs):
cur.execute("SELECT name FROM users") cur.execute("SELECT name FROM users")
rows = cur.fetchall() rows = cur.fetchall()

View file

@ -0,0 +1,38 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* The positions in the event stream_ordering when the current_state was
* replaced by the state at the event.
*/
CREATE TABLE IF NOT EXISTS current_state_resets(
event_stream_ordering BIGINT PRIMARY KEY NOT NULL
);
/* The outlier events that have aquired a state group typically through
* backfill. This is tracked separately to the events table, as assigning a
* state group change the position of the existing event in the stream
* ordering.
* However since a stream_ordering is assigned in persist_event for the
* (event, state) pair, we can use that stream_ordering to identify when
* the new state was assigned for the event.
*/
CREATE TABLE IF NOT EXISTS ex_outlier_stream(
event_stream_ordering BIGINT PRIMARY KEY NOT NULL,
event_id TEXT NOT NULL,
state_group BIGINT NOT NULL
);

View file

@ -0,0 +1,42 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE local_invites(
stream_id BIGINT NOT NULL,
inviter TEXT NOT NULL,
invitee TEXT NOT NULL,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
locally_rejected TEXT,
replaced_by TEXT
);
-- Insert all invites for local users into new `invites` table
INSERT INTO local_invites SELECT
stream_ordering as stream_id,
sender as inviter,
state_key as invitee,
event_id,
room_id,
NULL as locally_rejected,
NULL as replaced_by
FROM events
NATURAL JOIN current_state_events
NATURAL JOIN room_memberships
WHERE membership = 'invite' AND state_key IN (SELECT name FROM users);
CREATE INDEX local_invites_id ON local_invites(stream_id);
CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id);

View file

@ -249,11 +249,14 @@ class StateStore(SQLBaseStore):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
:param str event_id: event whose state should be returned Args:
:param list[(str, str)]|None types: List of (type, state_key) tuples event_id(str): event whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which which are used to filter the state fetched. May be None, which
matches any key matches any key
:return: a deferred dict from (type, state_key) -> state_event
Returns:
A deferred dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_for_events([event_id], types) state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@ -270,8 +273,8 @@ class StateStore(SQLBaseStore):
desc="_get_state_group_for_event", desc="_get_state_group_for_event",
) )
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", @cachedList(cached_method_name="_get_state_group_for_event",
num_args=1, inlineCallbacks=True) list_name="event_ids", num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids): def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group """Returns mapping event_id -> state_group
""" """
@ -458,4 +461,4 @@ class StateStore(SQLBaseStore):
) )
def get_state_stream_token(self): def get_state_stream_token(self):
return self._state_groups_id_gen.get_max_token() return self._state_groups_id_gen.get_current_token()

View file

@ -539,7 +539,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'): def get_room_events_max_id(self, direction='f'):
token = yield self._stream_id_gen.get_max_token() token = yield self._stream_id_gen.get_current_token()
if direction != 'b': if direction != 'b':
defer.returnValue("s%d" % (token,)) defer.returnValue("s%d" % (token,))
else: else:

View file

@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
Returns: Returns:
A deferred int. A deferred int.
""" """
return self._account_data_id_gen.get_max_token() return self._account_data_id_gen.get_current_token()
@cached() @cached()
def get_tags_for_user(self, user_id): def get_tags_for_user(self, user_id):
@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_max_token() result = self._account_data_id_gen.get_current_token()
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_max_token() result = self._account_data_id_gen.get_current_token()
defer.returnValue(result) defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id): def _update_revision_txn(self, txn, user_id, room_id, next_id):

View file

@ -21,7 +21,7 @@ import threading
class IdGenerator(object): class IdGenerator(object):
def __init__(self, db_conn, table, column): def __init__(self, db_conn, table, column):
self._lock = threading.Lock() self._lock = threading.Lock()
self._next_id = _load_max_id(db_conn, table, column) self._next_id = _load_current_id(db_conn, table, column)
def get_next(self): def get_next(self):
with self._lock: with self._lock:
@ -29,12 +29,16 @@ class IdGenerator(object):
return self._next_id return self._next_id
def _load_max_id(db_conn, table, column): def _load_current_id(db_conn, table, column, step=1):
cur = db_conn.cursor() cur = db_conn.cursor()
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
val, = cur.fetchone() val, = cur.fetchone()
cur.close() cur.close()
return int(val) if val else 1 current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
class StreamIdGenerator(object): class StreamIdGenerator(object):
@ -45,17 +49,32 @@ class StreamIdGenerator(object):
all ids less than or equal to it have completed. This handles the fact that all ids less than or equal to it have completed. This handles the fact that
persistence of events can complete out of order. persistence of events can complete out of order.
Args:
db_conn(connection): A database connection to use to fetch the
initial value of the generator from.
table(str): A database table to read the initial value of the id
generator from.
column(str): The column of the database table to read the initial
value from the id generator from.
extra_tables(list): List of pairs of database tables and columns to
use to source the initial value of the generator from. The value
with the largest magnitude is used.
step(int): which direction the stream ids grow in. +1 to grow
upwards, -1 to grow downwards.
Usage: Usage:
with stream_id_gen.get_next() as stream_id: with stream_id_gen.get_next() as stream_id:
# ... persist event ... # ... persist event ...
""" """
def __init__(self, db_conn, table, column, extra_tables=[]): def __init__(self, db_conn, table, column, extra_tables=[], step=1):
assert step != 0
self._lock = threading.Lock() self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column) self._step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables: for table, column in extra_tables:
self._current_max = max( self._current = (max if step > 0 else min)(
self._current_max, self._current,
_load_max_id(db_conn, table, column) _load_current_id(db_conn, table, column, step)
) )
self._unfinished_ids = deque() self._unfinished_ids = deque()
@ -66,8 +85,8 @@ class StreamIdGenerator(object):
# ... persist event ... # ... persist event ...
""" """
with self._lock: with self._lock:
self._current_max += 1 self._current += self._step
next_id = self._current_max next_id = self._current
self._unfinished_ids.append(next_id) self._unfinished_ids.append(next_id)
@ -88,8 +107,12 @@ class StreamIdGenerator(object):
# ... persist events ... # ... persist events ...
""" """
with self._lock: with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1) next_ids = range(
self._current_max += n self._current + self._step,
self._current + self._step * (n + 1),
self._step
)
self._current += n
for next_id in next_ids: for next_id in next_ids:
self._unfinished_ids.append(next_id) self._unfinished_ids.append(next_id)
@ -105,15 +128,15 @@ class StreamIdGenerator(object):
return manager() return manager()
def get_max_token(self): def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted. equal to it have been successfully persisted.
""" """
with self._lock: with self._lock:
if self._unfinished_ids: if self._unfinished_ids:
return self._unfinished_ids[0] - 1 return self._unfinished_ids[0] - self._step
return self._current_max return self._current
class ChainedIdGenerator(object): class ChainedIdGenerator(object):
@ -125,7 +148,7 @@ class ChainedIdGenerator(object):
def __init__(self, chained_generator, db_conn, table, column): def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator self.chained_generator = chained_generator
self._lock = threading.Lock() self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column) self._current_max = _load_current_id(db_conn, table, column)
self._unfinished_ids = deque() self._unfinished_ids = deque()
def get_next(self): def get_next(self):
@ -137,7 +160,7 @@ class ChainedIdGenerator(object):
with self._lock: with self._lock:
self._current_max += 1 self._current_max += 1
next_id = self._current_max next_id = self._current_max
chained_id = self.chained_generator.get_max_token() chained_id = self.chained_generator.get_current_token()
self._unfinished_ids.append((next_id, chained_id)) self._unfinished_ids.append((next_id, chained_id))
@ -151,7 +174,7 @@ class ChainedIdGenerator(object):
return manager() return manager()
def get_max_token(self): def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted. equal to it have been successfully persisted.
""" """
@ -160,4 +183,4 @@ class ChainedIdGenerator(object):
stream_id, chained_id = self._unfinished_ids[0] stream_id, chained_id = self._unfinished_ids[0]
return (stream_id - 1, chained_id) return (stream_id - 1, chained_id)
return (self._current_max, self.chained_generator.get_max_token()) return (self._current_max, self.chained_generator.get_current_token())

View file

@ -16,7 +16,8 @@
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from .logcontext import PreserveLoggingContext from .logcontext import PreserveLoggingContext, preserve_fn
from synapse.util import unwrapFirstError
@defer.inlineCallbacks @defer.inlineCallbacks
@ -107,3 +108,32 @@ class ObservableDeferred(object):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % ( return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred, id(self), self._result, self._deferred,
) )
def concurrently_execute(func, args, limit):
"""Executes the function with each argument conncurrently while limiting
the number of concurrent executions.
Args:
func (func): Function to execute, should return a deferred.
args (list): List of arguments to pass to func, each invocation of func
gets a signle argument.
limit (int): Maximum number of conccurent executions.
Returns:
deferred: Resolved when all function invocations have finished.
"""
it = iter(args)
@defer.inlineCallbacks
def _concurrently_execute_inner():
try:
while True:
yield func(it.next())
except StopIteration:
pass
return defer.gatherResults([
preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit)
], consumeErrors=True).addErrback(unwrapFirstError)

View file

@ -167,7 +167,8 @@ class CacheDescriptor(object):
% (orig.__name__,) % (orig.__name__,)
) )
self.cache = Cache( def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__, name=self.orig.__name__,
max_entries=self.max_entries, max_entries=self.max_entries,
keylen=self.num_args, keylen=self.num_args,
@ -175,14 +176,12 @@ class CacheDescriptor(object):
tree=self.tree, tree=self.tree,
) )
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try: try:
cached_result_d = self.cache.get(cache_key) cached_result_d = cache.get(cache_key)
observer = cached_result_d.observe() observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
@ -204,7 +203,7 @@ class CacheDescriptor(object):
# Get the sequence number of the cache before reading from the # Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated # database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369) # while the SELECT is executing (SYN-369)
sequence = self.cache.sequence sequence = cache.sequence
ret = defer.maybeDeferred( ret = defer.maybeDeferred(
preserve_context_over_fn, preserve_context_over_fn,
@ -213,20 +212,21 @@ class CacheDescriptor(object):
) )
def onErr(f): def onErr(f):
self.cache.invalidate(cache_key) cache.invalidate(cache_key)
return f return f
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret) cache.update(sequence, cache_key, ret)
return preserve_context_over_deferred(ret.observe()) return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = self.cache.invalidate_many wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = self.cache.prefill wrapped.prefill = cache.prefill
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.orig.__name__] = wrapped
@ -240,11 +240,12 @@ class CacheListDescriptor(object):
the list of missing keys to the wrapped fucntion. the list of missing keys to the wrapped fucntion.
""" """
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): def __init__(self, orig, cached_method_name, list_name, num_args=1,
inlineCallbacks=False):
""" """
Args: Args:
orig (function) orig (function)
cache (Cache) method_name (str); The name of the chached method.
list_name (str): Name of the argument which is the bulk lookup list list_name (str): Name of the argument which is the bulk lookup list
num_args (int) num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should inlineCallbacks (bool): Whether orig is a generator that should
@ -263,7 +264,7 @@ class CacheListDescriptor(object):
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name) self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache self.cached_method_name = cached_method_name
self.sentinel = object() self.sentinel = object()
@ -277,11 +278,13 @@ class CacheListDescriptor(object):
if self.list_name not in self.arg_names: if self.list_name not in self.arg_names:
raise Exception( raise Exception(
"Couldn't see arguments %r for %r." "Couldn't see arguments %r for %r."
% (self.list_name, cache.name,) % (self.list_name, cached_method_name,)
) )
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
cache = getattr(obj, self.cached_method_name).cache
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
@ -297,14 +300,14 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg key[self.list_pos] = arg
try: try:
res = self.cache.get(tuple(key)).observe() res = cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg) res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res cached[arg] = res
except KeyError: except KeyError:
missing.append(arg) missing.append(arg)
if missing: if missing:
sequence = self.cache.sequence sequence = cache.sequence
args_to_call = dict(arg_dict) args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing args_to_call[self.list_name] = missing
@ -327,10 +330,10 @@ class CacheListDescriptor(object):
key = list(keyargs) key = list(keyargs)
key[self.list_pos] = arg key[self.list_pos] = arg
self.cache.update(sequence, tuple(key), observer) cache.update(sequence, tuple(key), observer)
def invalidate(f, key): def invalidate(f, key):
self.cache.invalidate(key) cache.invalidate(key)
return f return f
observer.addErrback(invalidate, tuple(key)) observer.addErrback(invalidate, tuple(key))
@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
) )
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`. """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument Used to do batch lookups for an already created cache. A single argument
@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
""" """
return lambda orig: CacheListDescriptor( return lambda orig: CacheListDescriptor(
orig, orig,
cache=cache, cached_method_name=cached_method_name,
list_name=list_name, list_name=list_name,
num_args=num_args, num_args=num_args,
inlineCallbacks=inlineCallbacks, inlineCallbacks=inlineCallbacks,

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,57 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from synapse.replication.slave.storage.events import SlavedEventStore
from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver
from synapse.replication.resource import ReplicationResource
class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
"blue",
http_client=None,
replication_layer=Mock(),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
)
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
self.replication = ReplicationResource(self.hs)
self.master_store = self.hs.get_datastore()
self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs)
self.event_id = 0
@defer.inlineCallbacks
def replicate(self):
streams = self.slaved_store.stream_positions()
result = yield self.replication.replicate(streams, 100)
yield self.slaved_store.process_replication(result)
@defer.inlineCallbacks
def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args)
slaved_result = yield getattr(self.slaved_store, method)(*args)
self.assertEqual(master_result, slaved_result)
if expected_result is not None:
self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result)

View file

@ -0,0 +1,169 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStoreTestCase
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.storage.roommember import RoomsForUser
from twisted.internet import defer
USER_ID = "@feeling:blue"
USER_ID_2 = "@bright:blue"
OUTLIER = {"outlier": True}
ROOM_ID = "!room:blue"
class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
@defer.inlineCallbacks
def test_room_name_and_aliases(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
yield self.persist(type="m.room.name", key="", name="name1")
yield self.persist(
type="m.room.aliases", key="blue", aliases=["#1:blue"]
)
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"])
)
# Set the room name.
yield self.persist(type="m.room.name", key="", name="name2")
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"])
)
# Set the room aliases.
yield self.persist(
type="m.room.aliases", key="blue", aliases=["#2:blue"]
)
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"])
)
# Leave and join the room clobbering the state.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check(
"get_room_name_and_aliases", (ROOM_ID,), (None, [])
)
@defer.inlineCallbacks
def test_room_members(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Join the room.
join = yield self.persist(type="m.room.member", key=USER_ID, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
# Leave the room.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Add some other user to the room.
join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
# Join the room clobbering the state.
# This should remove any evidence of the other user being in the room.
yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
yield self.check("get_rooms_for_user", (USER_ID_2,), [])
event_id = 0
@defer.inlineCallbacks
def persist(
self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
state=None, reset_state=False, backfill=False,
depth=None, prev_events=[], auth_events=[], prev_state=[],
**content
):
"""
Returns:
synapse.events.FrozenEvent: The event that was persisted.
"""
if depth is None:
depth = self.event_id
event_dict = {
"sender": sender,
"type": type,
"content": content,
"event_id": "$%d:blue" % (self.event_id,),
"room_id": room_id,
"depth": depth,
"origin_server_ts": self.event_id,
"prev_events": prev_events,
"auth_events": auth_events,
}
if key is not None:
event_dict["state_key"] = key
event_dict["prev_state"] = prev_state
event = FrozenEvent(event_dict, internal_metadata_dict=internal)
self.event_id += 1
context = EventContext(current_state=state)
ordering = None
if backfill:
yield self.master_store.persist_events(
[(event, context)], backfilled=True
)
else:
ordering, _ = yield self.master_store.persist_event(
event, context, current_state=reset_state
)
if ordering:
event.internal_metadata.stream_ordering = ordering
defer.returnValue(event)

View file

@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase):
# set [invite/join/left] of self, set [invite/join/left] of other, # set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 404s because room doesn't exist on any server # expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]: for usr in [self.user_id, self.rmcreator_id]:
yield self.join(room=room, user=usr, expect_code=404) yield self.join(room=room, user=usr, expect_code=403)
yield self.leave(room=room, user=usr, expect_code=404) yield self.leave(room=room, user=usr, expect_code=403)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_membership_private_room_perms(self): def test_membership_private_room_perms(self):

View file

@ -53,7 +53,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
"test", "test",
db_pool=self.db_pool, db_pool=self.db_pool,
config=config, config=config,
database_engine=create_engine(config), database_engine=create_engine(config.database_config),
) )
self.datastore = SQLBaseStore(hs) self.datastore = SQLBaseStore(hs)

View file

@ -64,7 +64,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer( hs = HomeServer(
name, db_pool=db_pool, config=config, name, db_pool=db_pool, config=config,
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=create_engine(config), database_engine=create_engine(config.database_config),
get_db_conn=db_pool.get_db_conn, get_db_conn=db_pool.get_db_conn,
**kargs **kargs
) )
@ -73,7 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer( hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config, name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=create_engine(config), database_engine=create_engine(config.database_config),
**kargs **kargs
) )
@ -298,7 +298,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
return conn return conn
def create_engine(self): def create_engine(self):
return create_engine(self.config) return create_engine(self.config.database_config)
class MemoryDataStore(object): class MemoryDataStore(object):