Merge branch 'develop' into sh-cas-auth-via-homeserver

This commit is contained in:
Steven Hammerton 2015-11-17 10:55:41 +00:00
commit f5e25c5f35
39 changed files with 1266 additions and 290 deletions

View file

@ -1,3 +1,19 @@
Changes in synapse v0.11.0-rc1 (2015-11-11)
===========================================
* Add Search API (PR #307, #324, #327, #336, #350, #359)
* Add 'archived' state to v2 /sync API (PR #316)
* Add ability to reject invites (PR #317)
* Add config option to disable password login (PR #322)
* Add the login fallback API (PR #330)
* Add room context API (PR #334)
* Add room tagging support (PR #335)
* Update v2 /sync API to match spec (PR #305, #316, #321, #332, #337, #341)
* Change retry schedule for application services (PR #320)
* Change retry schedule for remote servers (PR #340)
* Fix bug where we hosted static content in the incorrect place (PR #329)
* Fix bug where we didn't increment retry interval for remote servers (PR #343)
Changes in synapse v0.10.1-rc1 (2015-10-15) Changes in synapse v0.10.1-rc1 (2015-10-15)
=========================================== ===========================================

View file

@ -20,4 +20,6 @@ recursive-include synapse/static *.gif
recursive-include synapse/static *.html recursive-include synapse/static *.html
recursive-include synapse/static *.js recursive-include synapse/static *.js
exclude jenkins.sh
prune demo/etc prune demo/etc

View file

@ -20,8 +20,8 @@ The overall architecture is::
https://somewhere.org/_matrix https://elsewhere.net/_matrix https://somewhere.org/_matrix https://elsewhere.net/_matrix
``#matrix:matrix.org`` is the official support room for Matrix, and can be ``#matrix:matrix.org`` is the official support room for Matrix, and can be
accessed by the web client at http://matrix.org/beta or via an IRC bridge at accessed by any client from https://matrix.org/blog/try-matrix-now or via IRC
irc://irc.freenode.net/matrix. bridge at irc://irc.freenode.net/matrix.
Synapse is currently in rapid development, but as of version 0.5 we believe it Synapse is currently in rapid development, but as of version 0.5 we believe it
is sufficiently stable to be run as an internet-facing service for real usage! is sufficiently stable to be run as an internet-facing service for real usage!
@ -77,14 +77,14 @@ Meanwhile, iOS and Android SDKs and clients are available from:
- https://github.com/matrix-org/matrix-android-sdk - https://github.com/matrix-org/matrix-android-sdk
We'd like to invite you to join #matrix:matrix.org (via We'd like to invite you to join #matrix:matrix.org (via
https://matrix.org/beta), run a homeserver, take a look at the Matrix spec at https://matrix.org/blog/try-matrix-now), run a homeserver, take a look at the
https://matrix.org/docs/spec and API docs at https://matrix.org/docs/api, Matrix spec at https://matrix.org/docs/spec and API docs at
experiment with the APIs and the demo clients, and report any bugs via https://matrix.org/docs/api, experiment with the APIs and the demo clients, and
https://matrix.org/jira. report any bugs via https://matrix.org/jira.
Thanks for using Matrix! Thanks for using Matrix!
[1] End-to-end encryption is currently in development [1] End-to-end encryption is currently in development - see https://matrix.org/git/olm
Synapse Installation Synapse Installation
==================== ====================

4
jenkins.sh Executable file
View file

@ -0,0 +1,4 @@
#!/bin/bash -eu
export PYTHONDONTWRITEBYTECODE=yep
TOXSUFFIX="--reporter=subunit | subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml" tox

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.10.1-rc1" __version__ = "0.11.0-rc1"

View file

@ -68,6 +68,7 @@ class EventTypes(object):
RoomHistoryVisibility = "m.room.history_visibility" RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias" CanonicalAlias = "m.room.canonical_alias"
RoomAvatar = "m.room.avatar" RoomAvatar = "m.room.avatar"
GuestAccess = "m.room.guest_access"
# These are used for validation # These are used for validation
Message = "m.room.message" Message = "m.room.message"

View file

@ -439,6 +439,7 @@ def setup(config_options):
hs.get_pusherpool().start() hs.get_pusherpool().start()
hs.get_state_handler().start_caching() hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling() hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache() hs.get_replication_layer().start_get_pdu_cache()
return hs return hs

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import errno
import os import os
import yaml import yaml
import sys import sys
@ -91,8 +92,11 @@ class Config(object):
@classmethod @classmethod
def ensure_directory(cls, dir_path): def ensure_directory(cls, dir_path):
dir_path = cls.abspath(dir_path) dir_path = cls.abspath(dir_path)
if not os.path.exists(dir_path): try:
os.makedirs(dir_path) os.makedirs(dir_path)
except OSError, e:
if e.errno != errno.EEXIST:
raise
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):
raise ConfigError( raise ConfigError(
"%s is not a directory" % (dir_path,) "%s is not a directory" % (dir_path,)

View file

@ -357,7 +357,8 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth) defer.returnValue(signed_auth)
@defer.inlineCallbacks @defer.inlineCallbacks
def make_membership_event(self, destinations, room_id, user_id, membership): def make_membership_event(self, destinations, room_id, user_id, membership,
content={},):
""" """
Creates an m.room.member event, with context, without participating in the room. Creates an m.room.member event, with context, without participating in the room.
@ -398,6 +399,14 @@ class FederationClient(FederationBase):
logger.debug("Got response to make_%s: %s", membership, pdu_dict) logger.debug("Got response to make_%s: %s", membership, pdu_dict)
pdu_dict["content"].update(content)
# The protoevent received over the JSON wire may not have all
# the required fields. Lets just gloss over that because
# there's some we never care about
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
defer.returnValue( defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict)) (destination, self.event_from_pdu_json(pdu_dict))
) )

View file

@ -29,6 +29,12 @@ logger = logging.getLogger(__name__)
class BaseHandler(object): class BaseHandler(object):
"""
Common base class for the event handlers.
:type store: synapse.storage.events.StateStore
:type state_handler: synapse.state.StateHandler
"""
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -175,6 +181,8 @@ class BaseHandler(object):
if not suppress_auth: if not suppress_auth:
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values())
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least) # Check the alias is acually valid (at this time at least)
room_alias_str = event.content.get("alias", None) room_alias_str = event.content.get("alias", None)
@ -282,3 +290,58 @@ class BaseHandler(object):
federation_handler.handle_new_event( federation_handler.handle_new_event(
event, destinations=destinations, event, destinations=destinations,
) )
@defer.inlineCallbacks
def maybe_kick_guest_users(self, event, current_state):
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
yield self.kick_guest_users(current_state)
@defer.inlineCallbacks
def kick_guest_users(self, current_state):
for member_event in current_state:
try:
if member_event.type != EventTypes.Member:
continue
if not self.hs.is_mine(UserID.from_string(member_event.state_key)):
continue
if member_event.content["membership"] not in {
Membership.JOIN,
Membership.INVITE
}:
continue
if (
"kind" not in member_event.content
or member_event.content["kind"] != "guest"
):
continue
# We make the user choose to leave, rather than have the
# event-sender kick them. This is partially because we don't
# need to worry about power levels, and partially because guest
# users are a concept which doesn't hugely work over federation,
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
message_handler = self.hs.get_handlers().message_handler
yield message_handler.create_and_send_event(
{
"type": EventTypes.Member,
"state_key": member_event.state_key,
"content": {
"membership": Membership.LEAVE,
"kind": "guest"
},
"room_id": member_event.room_id,
"sender": member_event.state_key
},
ratelimit=False,
)
except Exception as e:
logger.warn("Error kicking guest user: %s" % (e,))

View file

@ -564,7 +564,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee): def do_invite_join(self, target_hosts, room_id, joinee, content):
""" Attempts to join the `joinee` to the room `room_id` via the """ Attempts to join the `joinee` to the room `room_id` via the
server `target_host`. server `target_host`.
@ -584,7 +584,8 @@ class FederationHandler(BaseHandler):
target_hosts, target_hosts,
room_id, room_id,
joinee, joinee,
"join" "join",
content,
) )
self.room_queues[room_id] = [] self.room_queues[room_id] = []
@ -840,12 +841,14 @@ class FederationHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership): def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
content={},):
origin, pdu = yield self.replication_layer.make_membership_event( origin, pdu = yield self.replication_layer.make_membership_event(
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
membership membership,
content,
) )
logger.debug("Got response to make_%s: %s", membership, pdu) logger.debug("Got response to make_%s: %s", membership, pdu)
@ -1097,8 +1100,6 @@ class FederationHandler(BaseHandler):
context = yield self._prep_event( context = yield self._prep_event(
origin, event, origin, event,
state=state, state=state,
backfilled=backfilled,
current_state=current_state,
auth_events=auth_events, auth_events=auth_events,
) )
@ -1121,7 +1122,6 @@ class FederationHandler(BaseHandler):
origin, origin,
ev_info["event"], ev_info["event"],
state=ev_info.get("state"), state=ev_info.get("state"),
backfilled=backfilled,
auth_events=ev_info.get("auth_events"), auth_events=ev_info.get("auth_events"),
) )
for ev_info in event_infos for ev_info in event_infos
@ -1208,8 +1208,7 @@ class FederationHandler(BaseHandler):
defer.returnValue((event_stream_id, max_stream_id)) defer.returnValue((event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False, def _prep_event(self, origin, event, state=None, auth_events=None):
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier() outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context( context = yield self.state_handler.compute_event_context(
@ -1242,6 +1241,10 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
if event.type == EventTypes.GuestAccess:
full_context = yield self.store.get_current_state(room_id=event.room_id)
yield self.maybe_kick_guest_users(event, full_context)
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -167,7 +167,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None): token_id=None, txn_id=None, is_guest=False):
""" Given a dict from a client, create and handle a new event. """ Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events, Creates an FrozenEvent object, filling out auth_events, prev_events,
@ -213,7 +213,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context) yield member_handler.change_membership(event, context, is_guest=is_guest)
else: else:
yield self.handle_new_client_event( yield self.handle_new_client_event(
event=event, event=event,
@ -258,20 +258,30 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id, is_guest): def _check_in_room_or_world_readable(self, room_id, user_id, is_guest):
if is_guest: try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
return
except AuthError, auth_error:
visibility = yield self.state_handler.get_current_state( visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, "" room_id, EventTypes.RoomHistoryVisibility, ""
) )
if visibility.content["history_visibility"] == "world_readable": if (
visibility and
visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None)) defer.returnValue((Membership.JOIN, None))
return return
else: if not is_guest:
raise AuthError( raise auth_error
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN raise AuthError(
) 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
else: )
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events(self, user_id, room_id, is_guest=False): def get_state_events(self, user_id, room_id, is_guest=False):
@ -456,7 +466,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None): def room_initial_sync(self, user_id, room_id, pagin_config=None, is_guest=False):
"""Capture the a snapshot of a room. If user is currently a member of """Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left. the room this will be what was in the room when they left.
@ -473,15 +483,19 @@ class MessageHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room. A JSON serialisable dict with the snapshot of the room.
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id,
user_id,
is_guest
)
if member_event.membership == Membership.JOIN: if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined( result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, member_event user_id, room_id, pagin_config, membership, is_guest
) )
elif member_event.membership == Membership.LEAVE: elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted( result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, member_event user_id, room_id, pagin_config, membership, member_event_id, is_guest
) )
private_user_data = [] private_user_data = []
@ -497,19 +511,19 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config, def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
member_event): membership, member_event_id, is_guest):
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event.event_id], None [member_event_id], None
) )
room_state = room_state[member_event.event_id] room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None limit = pagin_config.limit if pagin_config else None
if limit is None: if limit is None:
limit = 10 limit = 10
stream_token = yield self.store.get_stream_token_for_event( stream_token = yield self.store.get_stream_token_for_event(
member_event.event_id member_event_id
) )
messages, token = yield self.store.get_recent_events_for_room( messages, token = yield self.store.get_recent_events_for_room(
@ -519,7 +533,7 @@ class MessageHandler(BaseHandler):
) )
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, messages user_id, messages, is_guest=is_guest
) )
start_token = StreamToken(token[0], 0, 0, 0, 0) start_token = StreamToken(token[0], 0, 0, 0, 0)
@ -528,7 +542,7 @@ class MessageHandler(BaseHandler):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
defer.returnValue({ defer.returnValue({
"membership": member_event.membership, "membership": membership,
"room_id": room_id, "room_id": room_id,
"messages": { "messages": {
"chunk": [serialize_event(m, time_now) for m in messages], "chunk": [serialize_event(m, time_now) for m in messages],
@ -542,7 +556,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config, def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
member_event): membership, is_guest):
current_state = yield self.state.get_current_state( current_state = yield self.state.get_current_state(
room_id=room_id, room_id=room_id,
) )
@ -574,12 +588,14 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
states = yield presence_handler.get_states( states = {}
target_users=[UserID.from_string(m.user_id) for m in room_members], if not is_guest:
auth_user=auth_user, states = yield presence_handler.get_states(
as_event=True, target_users=[UserID.from_string(m.user_id) for m in room_members],
check_auth=False, auth_user=auth_user,
) as_event=True,
check_auth=False,
)
defer.returnValue(states.values()) defer.returnValue(states.values())
@ -599,7 +615,7 @@ class MessageHandler(BaseHandler):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, messages user_id, messages, is_guest=is_guest, require_all_visible_for_guests=False
) )
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
@ -607,8 +623,7 @@ class MessageHandler(BaseHandler):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
defer.returnValue({ ret = {
"membership": member_event.membership,
"room_id": room_id, "room_id": room_id,
"messages": { "messages": {
"chunk": [serialize_event(m, time_now) for m in messages], "chunk": [serialize_event(m, time_now) for m in messages],
@ -618,4 +633,8 @@ class MessageHandler(BaseHandler):
"state": state, "state": state,
"presence": presence, "presence": presence,
"receipts": receipts, "receipts": receipts,
}) }
if not is_guest:
ret["membership"] = membership
defer.returnValue(ret)

View file

@ -950,7 +950,8 @@ class PresenceHandler(BaseHandler):
) )
while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS: while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS:
self._remote_offline_serials.pop() # remove the oldest self._remote_offline_serials.pop() # remove the oldest
del self._user_cachemap[user] if user in self._user_cachemap:
del self._user_cachemap[user]
else: else:
# Remove the user from remote_offline_serials now that they're # Remove the user from remote_offline_serials now that they're
# no longer offline # no longer offline

View file

@ -369,7 +369,7 @@ class RoomMemberHandler(BaseHandler):
remotedomains.add(member.domain) remotedomains.add(member.domain)
@defer.inlineCallbacks @defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True): def change_membership(self, event, context, do_auth=True, is_guest=False):
""" Change the membership status of a user in a room. """ Change the membership status of a user in a room.
Args: Args:
@ -390,6 +390,20 @@ class RoomMemberHandler(BaseHandler):
# if this HS is not currently in the room, i.e. we have to do the # if this HS is not currently in the room, i.e. we have to do the
# invite/join dance. # invite/join dance.
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if is_guest:
guest_access = context.current_state.get(
(EventTypes.GuestAccess, ""),
None
)
is_guest_access_allowed = (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
if not is_guest_access_allowed:
raise AuthError(403, "Guest access not allowed")
yield self._do_join(event, context, do_auth=do_auth) yield self._do_join(event, context, do_auth=do_auth)
else: else:
if event.membership == Membership.LEAVE: if event.membership == Membership.LEAVE:
@ -490,7 +504,8 @@ class RoomMemberHandler(BaseHandler):
yield handler.do_invite_join( yield handler.do_invite_join(
room_hosts, room_hosts,
room_id, room_id,
event.user_id event.user_id,
event.content,
) )
else: else:
logger.debug("Doing normal join") logger.debug("Doing normal join")
@ -582,7 +597,6 @@ class RoomMemberHandler(BaseHandler):
medium, medium,
address, address,
id_server, id_server,
display_name,
token_id, token_id,
txn_id txn_id
): ):
@ -609,7 +623,6 @@ class RoomMemberHandler(BaseHandler):
else: else:
yield self._make_and_store_3pid_invite( yield self._make_and_store_3pid_invite(
id_server, id_server,
display_name,
medium, medium,
address, address,
room_id, room_id,
@ -673,7 +686,6 @@ class RoomMemberHandler(BaseHandler):
def _make_and_store_3pid_invite( def _make_and_store_3pid_invite(
self, self,
id_server, id_server,
display_name,
medium, medium,
address, address,
room_id, room_id,
@ -681,7 +693,7 @@ class RoomMemberHandler(BaseHandler):
token_id, token_id,
txn_id txn_id
): ):
token, public_key, key_validity_url = ( token, public_key, key_validity_url, display_name = (
yield self._ask_id_server_for_third_party_invite( yield self._ask_id_server_for_third_party_invite(
id_server, id_server,
medium, medium,
@ -725,10 +737,11 @@ class RoomMemberHandler(BaseHandler):
# TODO: Check for success # TODO: Check for success
token = data["token"] token = data["token"]
public_key = data["public_key"] public_key = data["public_key"]
display_name = data["display_name"]
key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server, id_server_scheme, id_server,
) )
defer.returnValue((token, public_key, key_validity_url)) defer.returnValue((token, public_key, key_validity_url, display_name))
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):
@ -753,7 +766,7 @@ class RoomListHandler(BaseHandler):
class RoomContextHandler(BaseHandler): class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit): def get_event_context(self, user, room_id, event_id, limit, is_guest):
"""Retrieves events, pagination tokens and state around a given event """Retrieves events, pagination tokens and state around a given event
in a room. in a room.
@ -777,11 +790,17 @@ class RoomContextHandler(BaseHandler):
) )
results["events_before"] = yield self._filter_events_for_client( results["events_before"] = yield self._filter_events_for_client(
user.to_string(), results["events_before"] user.to_string(),
results["events_before"],
is_guest=is_guest,
require_all_visible_for_guests=False
) )
results["events_after"] = yield self._filter_events_for_client( results["events_after"] = yield self._filter_events_for_client(
user.to_string(), results["events_after"] user.to_string(),
results["events_after"],
is_guest=is_guest,
require_all_visible_for_guests=False
) )
if results["events_after"]: if results["events_after"]:

View file

@ -47,9 +47,9 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [
class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id", "room_id", # str
"timeline", "timeline", # TimelineBatch
"state", "state", # dict[(str, str), FrozenEvent]
"ephemeral", "ephemeral",
"private_user_data", "private_user_data",
])): ])):
@ -68,9 +68,9 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id", "room_id", # str
"timeline", "timeline", # TimelineBatch
"state", "state", # dict[(str, str), FrozenEvent]
"private_user_data", "private_user_data",
])): ])):
__slots__ = [] __slots__ = []
@ -87,8 +87,8 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
"room_id", "room_id", # str
"invite", "invite", # FrozenEvent: the invite event
])): ])):
__slots__ = [] __slots__ = []
@ -254,15 +254,12 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token room_id, sync_config, now_token, since_token=timeline_since_token
) )
current_state = yield self.state_handler.get_current_state( current_state = yield self.get_state_at(room_id, now_token)
room_id
)
current_state_events = current_state.values()
defer.returnValue(JoinedSyncResult( defer.returnValue(JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=current_state_events, state=current_state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room( private_user_data=self.private_user_data_for_room(
room_id, tags_by_room room_id, tags_by_room
@ -272,7 +269,7 @@ class SyncHandler(BaseHandler):
def private_user_data_for_room(self, room_id, tags_by_room): def private_user_data_for_room(self, room_id, tags_by_room):
private_user_data = [] private_user_data = []
tags = tags_by_room.get(room_id) tags = tags_by_room.get(room_id)
if tags: if tags is not None:
private_user_data.append({ private_user_data.append({
"type": "m.tag", "type": "m.tag",
"content": {"tags": tags}, "content": {"tags": tags},
@ -311,8 +308,13 @@ class SyncHandler(BaseHandler):
ephemeral_by_room = {} ephemeral_by_room = {}
for event in typing: for event in typing:
room_id = event.pop("room_id") # we want to exclude the room_id from the event, but modifying the
ephemeral_by_room.setdefault(room_id, []).append(event) # result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
event_copy = {k: v for (k, v) in event.iteritems()
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0" receipt_key = since_token.receipt_key if since_token else "0"
@ -328,8 +330,11 @@ class SyncHandler(BaseHandler):
now_token = now_token.copy_and_replace("receipt_key", receipt_key) now_token = now_token.copy_and_replace("receipt_key", receipt_key)
for event in receipts: for event in receipts:
room_id = event.pop("room_id") room_id = event["room_id"]
ephemeral_by_room.setdefault(room_id, []).append(event) # exclude room id, as above
event_copy = {k: v for (k, v) in event.iteritems()
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
defer.returnValue((now_token, ephemeral_by_room)) defer.returnValue((now_token, ephemeral_by_room))
@ -346,14 +351,12 @@ class SyncHandler(BaseHandler):
room_id, sync_config, leave_token, since_token=timeline_since_token room_id, sync_config, leave_token, since_token=timeline_since_token
) )
leave_state = yield self.store.get_state_for_events( leave_state = yield self.store.get_state_for_event(leave_event_id)
[leave_event_id], None
)
defer.returnValue(ArchivedSyncResult( defer.returnValue(ArchivedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=leave_state[leave_event_id].values(), state=leave_state,
private_user_data=self.private_user_data_for_room( private_user_data=self.private_user_data_for_room(
room_id, tags_by_room room_id, tags_by_room
), ),
@ -417,6 +420,9 @@ class SyncHandler(BaseHandler):
if len(room_events) <= timeline_limit: if len(room_events) <= timeline_limit:
# There is no gap in any of the rooms. Therefore we can just # There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them. # partition the new events by room and return them.
logger.debug("Got %i events for incremental sync - not limited",
len(room_events))
invite_events = [] invite_events = []
leave_events = [] leave_events = []
events_by_room_id = {} events_by_room_id = {}
@ -432,7 +438,12 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids: for room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, []) recents = events_by_room_id.get(room_id, [])
state = [event for event in recents if event.is_state()] logger.debug("Events for room %s: %r", room_id, recents)
state = {
(event.type, event.state_key): event
for event in recents if event.is_state()}
limited = False
if recents: if recents:
prev_batch = now_token.copy_and_replace( prev_batch = now_token.copy_and_replace(
"room_key", recents[0].internal_metadata.before "room_key", recents[0].internal_metadata.before
@ -440,9 +451,13 @@ class SyncHandler(BaseHandler):
else: else:
prev_batch = now_token prev_batch = now_token
state, limited = yield self.check_joined_room( just_joined = yield self.check_joined_room(sync_config, state)
sync_config, room_id, state if just_joined:
) logger.debug("User has just joined %s: needs full state",
room_id)
state = yield self.get_state_at(room_id, now_token)
# the timeline is inherently limited if we've just joined
limited = True
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
@ -457,10 +472,15 @@ class SyncHandler(BaseHandler):
room_id, tags_by_room room_id, tags_by_room
), ),
) )
logger.debug("Result for room %s: %r", room_id, room_sync)
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
else: else:
logger.debug("Got %i events for incremental sync - hit limit",
len(room_events))
invite_events = yield self.store.get_invites_for_user( invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string() sync_config.user.to_string()
) )
@ -499,6 +519,9 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token, def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None): since_token=None):
"""
:returns a Deferred TimelineBatch
"""
limited = True limited = True
recents = [] recents = []
filtering_factor = 2 filtering_factor = 2
@ -550,6 +573,8 @@ class SyncHandler(BaseHandler):
Returns: Returns:
A Deferred JoinedSyncResult A Deferred JoinedSyncResult
""" """
logger.debug("Doing incremental sync for room %s between %s and %s",
room_id, since_token, now_token)
# TODO(mjark): Check for redactions we might have missed. # TODO(mjark): Check for redactions we might have missed.
@ -559,31 +584,26 @@ class SyncHandler(BaseHandler):
logging.debug("Recents %r", batch) logging.debug("Recents %r", batch)
# TODO(mjark): This seems racy since this isn't being passed a current_state = yield self.get_state_at(room_id, now_token)
# token to indicate what point in the stream this is
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state_events = current_state.values()
state_at_previous_sync = yield self.get_state_at_previous_sync( state_at_previous_sync = yield self.get_state_at(
room_id, since_token=since_token room_id, stream_position=since_token
) )
state_events_delta = yield self.compute_state_delta( state = yield self.compute_state_delta(
since_token=since_token, since_token=since_token,
previous_state=state_at_previous_sync, previous_state=state_at_previous_sync,
current_state=current_state_events, current_state=current_state,
) )
state_events_delta, _ = yield self.check_joined_room( just_joined = yield self.check_joined_room(sync_config, state)
sync_config, room_id, state_events_delta if just_joined:
) state = yield self.get_state_at(room_id, now_token)
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room( private_user_data=self.private_user_data_for_room(
room_id, tags_by_room room_id, tags_by_room
@ -615,16 +635,12 @@ class SyncHandler(BaseHandler):
logging.debug("Recents %r", batch) logging.debug("Recents %r", batch)
# TODO(mjark): This seems racy since this isn't being passed a state_events_at_leave = yield self.store.get_state_for_event(
# token to indicate what point in the stream this is leave_event.event_id
leave_state = yield self.store.get_state_for_events(
[leave_event.event_id], None
) )
state_events_at_leave = leave_state[leave_event.event_id].values() state_at_previous_sync = yield self.get_state_at(
leave_event.room_id, stream_position=since_token
state_at_previous_sync = yield self.get_state_at_previous_sync(
leave_event.room_id, since_token=since_token
) )
state_events_delta = yield self.compute_state_delta( state_events_delta = yield self.compute_state_delta(
@ -647,60 +663,77 @@ class SyncHandler(BaseHandler):
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at_previous_sync(self, room_id, since_token): def get_state_after_event(self, event):
""" Get the room state at the previous sync the client made. """
Returns: Get the room state after the given event
A Deferred list of Events.
:param synapse.events.EventBase event: event of interest
:return: A Deferred map from ((type, state_key)->Event)
"""
state = yield self.store.get_state_for_event(event.event_id)
if event.is_state():
state = state.copy()
state[(event.type, event.state_key)] = event
defer.returnValue(state)
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
""" Get the room state at a particular stream position
:param str room_id: room for which to get state
:param StreamToken stream_position: point at which to get state
:returns: A Deferred map from ((type, state_key)->Event)
""" """
last_events, token = yield self.store.get_recent_events_for_room( last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=since_token.room_key, limit=1, room_id, end_token=stream_position.room_key, limit=1,
) )
if last_events: if last_events:
last_event = last_events[0] last_event = last_events[-1]
last_context = yield self.state_handler.compute_event_context( state = yield self.get_state_after_event(last_event)
last_event
)
if last_event.is_state():
state = [last_event] + last_context.current_state.values()
else:
state = last_context.current_state.values()
else: else:
state = () # no events in this room - so presumably no state
state = {}
defer.returnValue(state) defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state): def compute_state_delta(self, since_token, previous_state, current_state):
""" Works out the differnce in state between the current state and the """ Works out the differnce in state between the current state and the
state the client got when it last performed a sync. state the client got when it last performed a sync.
Returns:
A list of events. :param str since_token: the point we are comparing against
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
state to compare to
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the
new state
:returns A new event dictionary
""" """
# TODO(mjark) Check if the state events were received by the server # TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state # after the previous sync, since we need to include those state
# updates even if they occured logically before the previous event. # updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events. # TODO(mjark) Check for new redactions in the state events.
previous_dict = {event.event_id: event for event in previous_state}
state_delta = [] state_delta = {}
for event in current_state: for key, event in current_state.iteritems():
if event.event_id not in previous_dict: if (key not in previous_state or
state_delta.append(event) previous_state[key].event_id != event.event_id):
state_delta[key] = event
return state_delta return state_delta
@defer.inlineCallbacks def check_joined_room(self, sync_config, state_delta):
def check_joined_room(self, sync_config, room_id, state_delta): """
joined = False Check if the user has just joined the given room (so should
limited = False be given the full state)
for event in state_delta:
if (
event.type == EventTypes.Member
and event.state_key == sync_config.user.to_string()
):
if event.content["membership"] == Membership.JOIN:
joined = True
if joined: :param sync_config:
res = yield self.state_handler.get_current_state(room_id) :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the
state_delta = res.values() difference in state since the last sync
limited = True
defer.returnValue((state_delta, limited)) :returns A deferred Tuple (state_delta, limited)
"""
join_event = state_delta.get((
EventTypes.Member, sync_config.user.to_string()), None)
if join_event is not None:
if join_event.content["membership"] == Membership.JOIN:
return True
return False

View file

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, ObservableDeferred from synapse.util.async import run_on_reactor, ObservableDeferred
@ -346,9 +348,9 @@ class Notifier(object):
room_ids = [] room_ids = []
if is_guest: if is_guest:
# TODO(daniel): Deal with non-room events too
only_room_events = True
if guest_room_id: if guest_room_id:
if not self._is_world_readable(guest_room_id):
raise AuthError(403, "Guest access not allowed")
room_ids = [guest_room_id] room_ids = [guest_room_id]
else: else:
rooms = yield self.store.get_rooms_for_user(user.to_string()) rooms = yield self.store.get_rooms_for_user(user.to_string())
@ -361,6 +363,7 @@ class Notifier(object):
events = [] events = []
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name keyname = "%s_key" % name
before_id = getattr(before_token, keyname) before_id = getattr(before_token, keyname)
@ -377,7 +380,7 @@ class Notifier(object):
room_ids=room_ids, room_ids=room_ids,
) )
if is_guest: if name == "room":
room_member_handler = self.hs.get_handlers().room_member_handler room_member_handler = self.hs.get_handlers().room_member_handler
new_events = yield room_member_handler._filter_events_for_client( new_events = yield room_member_handler._filter_events_for_client(
user.to_string(), user.to_string(),
@ -403,6 +406,17 @@ class Notifier(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def _is_world_readable(self, room_id):
state = yield self.hs.get_state_handler().get_current_state(
room_id,
EventTypes.RoomHistoryVisibility
)
if state and "history_visibility" in state.content:
defer.returnValue(state.content["history_visibility"] == "world_readable")
else:
defer.returnValue(False)
@log_function @log_function
def remove_expired_streams(self): def remove_expired_streams(self):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()

View file

@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:

View file

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_pattern
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None): def on_POST(self, request, room_id, event_type, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -220,7 +220,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None): def on_POST(self, request, room_identifier, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request) user, token_id, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
# the identifier could be a room alias or a room id. Try one then the # the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid # other if it fails to parse, without swallowing other valid
@ -242,16 +245,20 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
defer.returnValue((200, ret_dict)) defer.returnValue((200, ret_dict))
else: # room id else: # room id
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
content = {"membership": Membership.JOIN}
if is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event( yield msg_handler.create_and_send_event(
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"content": {"membership": Membership.JOIN}, "content": content,
"room_id": identifier.to_string(), "room_id": identifier.to_string(),
"sender": user.to_string(), "sender": user.to_string(),
"state_key": user.to_string(), "state_key": user.to_string(),
}, },
token_id=token_id, token_id=token_id,
txn_id=txn_id, txn_id=txn_id,
is_guest=is_guest,
) )
defer.returnValue((200, {"room_id": identifier.to_string()})) defer.returnValue((200, {"room_id": identifier.to_string()}))
@ -319,7 +326,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
})) }))
# TODO: Needs unit testing # TODO: Needs better unit testing
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$") PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$")
@ -365,12 +372,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync( content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
is_guest=is_guest,
) )
defer.returnValue((200, content)) defer.returnValue((200, content))
@ -410,12 +418,12 @@ class RoomEventContext(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0]) limit = int(request.args.get("limit", [10])[0])
results = yield self.handlers.room_context_handler.get_event_context( results = yield self.handlers.room_context_handler.get_event_context(
user, room_id, event_id, limit, user, room_id, event_id, limit, is_guest
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -445,7 +453,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None): def on_POST(self, request, room_id, membership_action, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request) user, token_id, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}:
raise AuthError(403, "Guest access not allowed")
content = _parse_json(request) content = _parse_json(request)
@ -459,7 +473,6 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content["medium"], content["medium"],
content["address"], content["address"],
content["id_server"], content["id_server"],
content["display_name"],
token_id, token_id,
txn_id txn_id
) )
@ -479,22 +492,27 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
content = {"membership": unicode(membership_action)}
if is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event( yield msg_handler.create_and_send_event(
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"content": {"membership": unicode(membership_action)}, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": user.to_string(),
"state_key": state_key, "state_key": state_key,
}, },
token_id=token_id, token_id=token_id,
txn_id=txn_id, txn_id=txn_id,
is_guest=is_guest,
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
def _has_3pid_invite_keys(self, content): def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address", "display_name"}: for key in {"id_server", "medium", "address"}:
if key not in content: if key not in content:
return False return False
return True return True

View file

@ -20,6 +20,7 @@ from synapse.http.servlet import (
) )
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events import FrozenEvent
from synapse.events.utils import ( from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_event_id, serialize_event, format_event_for_client_v2_without_event_id,
) )
@ -165,6 +166,20 @@ class SyncRestServlet(RestServlet):
return {"events": filter.filter_presence(formatted)} return {"events": filter.filter_presence(formatted)}
def encode_joined(self, rooms, filter, time_now, token_id): def encode_joined(self, rooms, filter, time_now, token_id):
"""
Encode the joined rooms in a sync result
:param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:return: the joined rooms list, in our response format
:rtype: dict[str, dict[str, object]]
"""
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
@ -174,6 +189,20 @@ class SyncRestServlet(RestServlet):
return joined return joined
def encode_invited(self, rooms, filter, time_now, token_id): def encode_invited(self, rooms, filter, time_now, token_id):
"""
Encode the invited rooms in a sync result
:param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:return: the invited rooms list, in our response format
:rtype: dict[str, dict[str, object]]
"""
invited = {} invited = {}
for room in rooms: for room in rooms:
invite = serialize_event( invite = serialize_event(
@ -189,6 +218,20 @@ class SyncRestServlet(RestServlet):
return invited return invited
def encode_archived(self, rooms, filter, time_now, token_id): def encode_archived(self, rooms, filter, time_now, token_id):
"""
Encode the archived rooms in a sync result
:param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:return: the invited rooms list, in our response format
:rtype: dict[str, dict[str, object]]
"""
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
@ -199,8 +242,28 @@ class SyncRestServlet(RestServlet):
@staticmethod @staticmethod
def encode_room(room, filter, time_now, token_id, joined=True): def encode_room(room, filter, time_now, token_id, joined=True):
"""
:param JoinedSyncResult|ArchivedSyncResult room: sync result for a
single room
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:param joined: True if the user is joined to this room - will mean
we handle ephemeral events
:return: the room, encoded in our response format
:rtype: dict[str, object]
"""
event_map = {} event_map = {}
state_events = filter.filter_room_state(room.state) state_dict = room.state
timeline_events = filter.filter_room_timeline(room.timeline.events)
state_dict = SyncRestServlet._rollback_state_for_timeline(
state_dict, timeline_events)
state_events = filter.filter_room_state(state_dict.values())
state_event_ids = [] state_event_ids = []
for event in state_events: for event in state_events:
# TODO(mjark): Respect formatting requirements in the filter. # TODO(mjark): Respect formatting requirements in the filter.
@ -210,7 +273,6 @@ class SyncRestServlet(RestServlet):
) )
state_event_ids.append(event.event_id) state_event_ids.append(event.event_id)
timeline_events = filter.filter_room_timeline(room.timeline.events)
timeline_event_ids = [] timeline_event_ids = []
for event in timeline_events: for event in timeline_events:
# TODO(mjark): Respect formatting requirements in the filter. # TODO(mjark): Respect formatting requirements in the filter.
@ -241,6 +303,63 @@ class SyncRestServlet(RestServlet):
return result return result
@staticmethod
def _rollback_state_for_timeline(state, timeline):
"""
Wind the state dictionary backwards, so that it represents the
state at the start of the timeline, rather than at the end.
:param dict[(str, str), synapse.events.EventBase] state: the
state dictionary. Will be updated to the state before the timeline.
:param list[synapse.events.EventBase] timeline: the event timeline
:return: updated state dictionary
"""
logger.debug("Processing state dict %r; timeline %r", state,
[e.get_dict() for e in timeline])
result = state.copy()
for timeline_event in reversed(timeline):
if not timeline_event.is_state():
continue
event_key = (timeline_event.type, timeline_event.state_key)
logger.debug("Considering %s for removal", event_key)
state_event = result.get(event_key)
if (state_event is None or
state_event.event_id != timeline_event.event_id):
# the event in the timeline isn't present in the state
# dictionary.
#
# the most likely cause for this is that there was a fork in
# the event graph, and the state is no longer valid. Really,
# the event shouldn't be in the timeline. We're going to ignore
# it for now, however.
logger.warn("Found state event %r in timeline which doesn't "
"match state dictionary", timeline_event)
continue
prev_event_id = timeline_event.unsigned.get("replaces_state", None)
logger.debug("Replacing %s with %s in state dict",
timeline_event.event_id, prev_event_id)
if prev_event_id is None:
del result[event_key]
else:
result[event_key] = FrozenEvent({
"type": timeline_event.type,
"state_key": timeline_event.state_key,
"content": timeline_event.unsigned['prev_content'],
"sender": timeline_event.unsigned['prev_sender'],
"event_id": prev_event_id,
"room_id": timeline_event.room_id,
})
logger.debug("New value: %r", result.get(event_key))
return result
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server) SyncRestServlet(hs).register(http_server)

View file

@ -71,7 +71,7 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key=""):
""" Returns the current state for the room as a list. This is done by """ Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts. event graph and then resolving any of the state conflicts.
@ -80,6 +80,8 @@ class StateHandler(object):
If `event_type` is specified, then the method returns only the one If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`. event (or None) with that `event_type` and `state_key`.
:returns map from (type, state_key) to event
""" """
event_ids = yield self.store.get_latest_event_ids_in_room(room_id) event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
@ -177,9 +179,10 @@ class StateHandler(object):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
Return format is a tuple: (`state_group`, `state_events`), where the :returns a Deferred tuple of (`state_group`, `state`, `prev_state`).
first is the name of a state group if one and only one is involved, `state_group` is the name of a state group if one and only one is
otherwise `None`. involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids.
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
@ -255,6 +258,11 @@ class StateHandler(object):
return self._resolve_events(state_sets) return self._resolve_events(state_sets)
def _resolve_events(self, state_sets, event_type=None, state_key=""): def _resolve_events(self, state_sets, event_type=None, state_key=""):
"""
:returns a tuple (new_state, prev_states). new_state is a map
from (type, state_key) to event. prev_states is a list of event_ids.
:rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str])
"""
state = {} state = {}
for st in state_sets: for st in state_sets:
for e in st: for e in st:
@ -307,19 +315,23 @@ class StateHandler(object):
We resolve conflicts in the following order: We resolve conflicts in the following order:
1. power levels 1. power levels
2. memberships 2. join rules
3. other events. 3. memberships
4. other events.
""" """
resolved_state = {} resolved_state = {}
power_key = (EventTypes.PowerLevels, "") power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state.items(): if power_key in conflicted_state:
power_levels = conflicted_state[power_key] events = conflicted_state[power_key]
resolved_state[power_key] = self._resolve_auth_events(power_levels) logger.debug("Resolving conflicted power levels %r", events)
resolved_state[power_key] = self._resolve_auth_events(
events, auth_events)
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules: if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = self._resolve_auth_events( resolved_state[key] = self._resolve_auth_events(
events, events,
auth_events auth_events
@ -329,6 +341,7 @@ class StateHandler(object):
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key[0] == EventTypes.Member: if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = self._resolve_auth_events( resolved_state[key] = self._resolve_auth_events(
events, events,
auth_events auth_events
@ -338,6 +351,7 @@ class StateHandler(object):
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key not in resolved_state: if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = self._resolve_normal_events( resolved_state[key] = self._resolve_normal_events(
events, auth_events events, auth_events
) )

View file

@ -0,0 +1,256 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer
import ujson as json
import logging
logger = logging.getLogger(__name__)
class BackgroundUpdatePerformance(object):
"""Tracks the how long a background update is taking to update its items"""
def __init__(self, name):
self.name = name
self.total_item_count = 0
self.total_duration_ms = 0
self.avg_item_count = 0
self.avg_duration_ms = 0
def update(self, item_count, duration_ms):
"""Update the stats after doing an update"""
self.total_item_count += item_count
self.total_duration_ms += duration_ms
# Exponential moving averages for the number of items updated and
# the duration.
self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
def average_items_per_ms(self):
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
"""
if self.total_item_count == 0:
return None
else:
# Use the exponential moving average so that we can adapt to
# changes in how long the update process takes.
return float(self.avg_item_count) / float(self.avg_duration_ms)
def total_items_per_ms(self):
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
"""
if self.total_item_count == 0:
return None
else:
return float(self.total_item_count) / float(self.total_duration_ms)
class BackgroundUpdateStore(SQLBaseStore):
""" Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
process and autotuning the batch size.
"""
MINIMUM_BACKGROUND_BATCH_SIZE = 100
DEFAULT_BACKGROUND_BATCH_SIZE = 100
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs):
super(BackgroundUpdateStore, self).__init__(hs)
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
self._background_update_timer = None
@defer.inlineCallbacks
def start_doing_background_updates(self):
while True:
if self._background_update_timer is not None:
return
sleep = defer.Deferred()
self._background_update_timer = self._clock.call_later(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
)
try:
yield sleep
finally:
self._background_update_timer = None
try:
result = yield self.do_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except:
logger.exception("Error doing update")
if result is None:
logger.info(
"No more background updates to do."
" Unscheduling background update task."
)
return
@defer.inlineCallbacks
def do_background_update(self, desired_duration_ms):
"""Does some amount of work on a background update
Args:
desired_duration_ms(float): How long we want to spend
updating.
Returns:
A deferred that completes once some amount of work is done.
The deferred will have a value of None if there is currently
no more work to do.
"""
if not self._background_update_queue:
updates = yield self._simple_select_list(
"background_updates",
keyvalues=None,
retcols=("update_name",),
)
for update in updates:
self._background_update_queue.append(update['update_name'])
if not self._background_update_queue:
defer.returnValue(None)
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
update_handler = self._background_update_handlers[update_name]
performance = self._background_update_performance.get(update_name)
if performance is None:
performance = BackgroundUpdatePerformance(update_name)
self._background_update_performance[update_name] = performance
items_per_ms = performance.average_items_per_ms()
if items_per_ms is not None:
batch_size = int(desired_duration_ms * items_per_ms)
# Clamp the batch size so that we always make progress
batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
progress_json = yield self._simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json"
)
progress = json.loads(progress_json)
time_start = self._clock.time_msec()
items_updated = yield update_handler(progress, batch_size)
time_stop = self._clock.time_msec()
duration_ms = time_stop - time_start
logger.info(
"Updating %r. Updated %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)",
update_name, items_updated, duration_ms,
performance.total_items_per_ms(),
performance.average_items_per_ms(),
performance.total_item_count,
)
performance.update(items_updated, duration_ms)
defer.returnValue(len(self._background_update_performance))
def register_background_update_handler(self, update_name, update_handler):
"""Register a handler for doing a background update.
The handler should take two arguments:
* A dict of the current progress
* An integer count of the number of items to update in this batch.
The handler should return a deferred integer count of items updated.
The hander is responsible for updating the progress of the update.
Args:
update_name(str): The name of the update that this code handles.
update_handler(function): The function that does the update.
"""
self._background_update_handlers[update_name] = update_handler
def start_background_update(self, update_name, progress):
"""Starts a background update running.
Args:
update_name: The update to set running.
progress: The initial state of the progress of the update.
Returns:
A deferred that completes once the task has been added to the
queue.
"""
# Clear the background update queue so that we will pick up the new
# task on the next iteration of do_background_update.
self._background_update_queue = []
progress_json = json.dumps(progress)
return self._simple_insert(
"background_updates",
{"update_name": update_name, "progress_json": progress_json}
)
def _end_background_update(self, update_name):
"""Removes a completed background update task from the queue.
Args:
update_name(str): The name of the completed task to remove
Returns:
A deferred that completes once the task is removed.
"""
self._background_update_queue = [
name for name in self._background_update_queue if name != update_name
]
return self._simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
def _background_update_progress_txn(self, txn, update_name, progress):
"""Update the progress of a background update
Args:
txn(cursor): The transaction.
update_name(str): The name of the background update task
progress(dict): The progress of the update.
"""
progress_json = json.dumps(progress)
self._simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},
updatevalues={"progress_json": progress_json},
)

View file

@ -313,6 +313,8 @@ class EventsStore(SQLBaseStore):
self._store_redaction(txn, event) self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility: elif event.type == EventTypes.RoomHistoryVisibility:
self._store_history_visibility_txn(txn, event) self._store_history_visibility_txn(txn, event)
elif event.type == EventTypes.GuestAccess:
self._store_guest_access_txn(txn, event)
self._store_room_members_txn( self._store_room_members_txn(
txn, txn,
@ -829,7 +831,8 @@ class EventsStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
if prev: if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.content
ev.unsigned["prev_sender"] = prev.sender
self._get_event_cache.prefill( self._get_event_cache.prefill(
(ev.event_id, check_redacted, get_prev_content), ev (ev.event_id, check_redacted, get_prev_content), ev
@ -886,7 +889,8 @@ class EventsStore(SQLBaseStore):
get_prev_content=False, get_prev_content=False,
) )
if prev: if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.content
ev.unsigned["prev_sender"] = prev.sender
self._get_event_cache.prefill( self._get_event_cache.prefill(
(ev.event_id, check_redacted, get_prev_content), ev (ev.event_id, check_redacted, get_prev_content), ev

View file

@ -99,34 +99,39 @@ class RoomStore(SQLBaseStore):
""" """
def f(txn): def f(txn):
topic_subquery = ( def subquery(table_name, column_name=None):
"SELECT topics.event_id as event_id, " column_name = column_name or table_name
"topics.room_id as room_id, topic " return (
"FROM topics " "SELECT %(table_name)s.event_id as event_id, "
"INNER JOIN current_state_events as c " "%(table_name)s.room_id as room_id, %(column_name)s "
"ON c.event_id = topics.event_id " "FROM %(table_name)s "
) "INNER JOIN current_state_events as c "
"ON c.event_id = %(table_name)s.event_id " % {
"column_name": column_name,
"table_name": table_name,
}
)
name_subquery = (
"SELECT room_names.event_id as event_id, "
"room_names.room_id as room_id, name "
"FROM room_names "
"INNER JOIN current_state_events as c "
"ON c.event_id = room_names.event_id "
)
# We use non printing ascii character US (\x1F) as a separator
sql = ( sql = (
"SELECT r.room_id, max(n.name), max(t.topic)" "SELECT"
" r.room_id,"
" max(n.name),"
" max(t.topic),"
" max(v.history_visibility),"
" max(g.guest_access)"
" FROM rooms AS r" " FROM rooms AS r"
" LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id" " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
" LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id" " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
" LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id"
" LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id"
" WHERE r.is_public = ?" " WHERE r.is_public = ?"
" GROUP BY r.room_id" " GROUP BY r.room_id" % {
) % { "topic": subquery("topics", "topic"),
"topic": topic_subquery, "name": subquery("room_names", "name"),
"name": name_subquery, "history_visibility": subquery("history_visibility"),
} "guest_access": subquery("guest_access"),
}
)
txn.execute(sql, (is_public,)) txn.execute(sql, (is_public,))
@ -156,10 +161,12 @@ class RoomStore(SQLBaseStore):
"room_id": r[0], "room_id": r[0],
"name": r[1], "name": r[1],
"topic": r[2], "topic": r[2],
"aliases": r[3], "world_readable": r[3] == "world_readable",
"guest_can_join": r[4] == "can_join",
"aliases": r[5],
} }
for r in rows for r in rows
if r[3] # We only return rooms that have at least one alias. if r[5] # We only return rooms that have at least one alias.
] ]
defer.returnValue(ret) defer.returnValue(ret)
@ -203,16 +210,22 @@ class RoomStore(SQLBaseStore):
) )
def _store_history_visibility_txn(self, txn, event): def _store_history_visibility_txn(self, txn, event):
if hasattr(event, "content") and "history_visibility" in event.content: self._store_content_index_txn(txn, event, "history_visibility")
def _store_guest_access_txn(self, txn, event):
self._store_content_index_txn(txn, event, "guest_access")
def _store_content_index_txn(self, txn, event, key):
if hasattr(event, "content") and key in event.content:
sql = ( sql = (
"INSERT INTO history_visibility" "INSERT INTO %(key)s"
" (event_id, room_id, history_visibility)" " (event_id, room_id, %(key)s)"
" VALUES (?, ?, ?)" " VALUES (?, ?, ?)" % {"key": key}
) )
txn.execute(sql, ( txn.execute(sql, (
event.event_id, event.event_id,
event.room_id, event.room_id,
event.content["history_visibility"] event.content[key]
)) ))
def _store_event_search_txn(self, txn, event, key, value): def _store_event_search_txn(self, txn, event, key, value):

View file

@ -0,0 +1,21 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS background_updates(
update_name TEXT NOT NULL, -- The name of the background update.
progress_json TEXT NOT NULL, -- The current progress of the update as JSON.
CONSTRAINT background_updates_uniqueness UNIQUE (update_name)
);

View file

@ -22,7 +22,7 @@ import ujson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
POSTGRES_SQL = """ POSTGRES_TABLE = """
CREATE TABLE IF NOT EXISTS event_search ( CREATE TABLE IF NOT EXISTS event_search (
event_id TEXT, event_id TEXT,
room_id TEXT, room_id TEXT,
@ -31,22 +31,6 @@ CREATE TABLE IF NOT EXISTS event_search (
vector tsvector vector tsvector
); );
INSERT INTO event_search SELECT
event_id, room_id, json::json->>'sender', 'content.body',
to_tsvector('english', json::json->'content'->>'body')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.message';
INSERT INTO event_search SELECT
event_id, room_id, json::json->>'sender', 'content.name',
to_tsvector('english', json::json->'content'->>'name')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.name';
INSERT INTO event_search SELECT
event_id, room_id, json::json->>'sender', 'content.topic',
to_tsvector('english', json::json->'content'->>'topic')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic';
CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); CREATE INDEX event_search_fts_idx ON event_search USING gin(vector);
CREATE INDEX event_search_ev_idx ON event_search(event_id); CREATE INDEX event_search_ev_idx ON event_search(event_id);
CREATE INDEX event_search_ev_ridx ON event_search(room_id); CREATE INDEX event_search_ev_ridx ON event_search(room_id);
@ -61,67 +45,34 @@ SQLITE_TABLE = (
def run_upgrade(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine): if isinstance(database_engine, PostgresEngine):
run_postgres_upgrade(cur) for statement in get_statements(POSTGRES_TABLE.splitlines()):
return cur.execute(statement)
elif isinstance(database_engine, Sqlite3Engine):
if isinstance(database_engine, Sqlite3Engine):
run_sqlite_upgrade(cur)
return
def run_postgres_upgrade(cur):
for statement in get_statements(POSTGRES_SQL.splitlines()):
cur.execute(statement)
def run_sqlite_upgrade(cur):
cur.execute(SQLITE_TABLE) cur.execute(SQLITE_TABLE)
else:
raise Exception("Unrecognized database engine")
rowid = -1 cur.execute("SELECT MIN(stream_ordering) FROM events")
while True: rows = cur.fetchall()
cur.execute( min_stream_id = rows[0][0]
"SELECT rowid, json FROM event_json"
" WHERE rowid > ?"
" ORDER BY rowid ASC LIMIT 100",
(rowid,)
)
res = cur.fetchall() cur.execute("SELECT MAX(stream_ordering) FROM events")
rows = cur.fetchall()
max_stream_id = rows[0][0]
if not res: if min_stream_id is not None and max_stream_id is not None:
break progress = {
"target_min_stream_id_inclusive": min_stream_id,
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
progress_json = ujson.dumps(progress)
events = [ sql = (
ujson.loads(js) "INSERT into background_updates (update_name, progress_json)"
for _, js in res " VALUES (?, ?)"
] )
rowid = max(rid for rid, _ in res) sql = database_engine.convert_param_style(sql)
rows = [] cur.execute(sql, ("event_search", progress_json))
for ev in events:
content = ev.get("content", {})
body = content.get("body", None)
name = content.get("name", None)
topic = content.get("topic", None)
sender = ev.get("sender", None)
if ev["type"] == "m.room.message" and body:
rows.append((
ev["event_id"], ev["room_id"], sender, "content.body", body
))
if ev["type"] == "m.room.name" and name:
rows.append((
ev["event_id"], ev["room_id"], sender, "content.name", name
))
if ev["type"] == "m.room.topic" and topic:
rows.append((
ev["event_id"], ev["room_id"], sender, "content.topic", topic
))
if rows:
logger.info(rows)
cur.executemany(
"INSERT INTO event_search (event_id, room_id, sender, key, value)"
" VALUES (?,?,?,?,?)",
rows
)

View file

@ -0,0 +1,25 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is a manual index of guest_access content of state events,
* so that we can join on them in SELECT statements.
*/
CREATE TABLE IF NOT EXISTS guest_access(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
guest_access TEXT NOT NULL,
UNIQUE (event_id)
);

View file

@ -18,7 +18,6 @@
* so that we can join on them in SELECT statements. * so that we can join on them in SELECT statements.
*/ */
CREATE TABLE IF NOT EXISTS history_visibility( CREATE TABLE IF NOT EXISTS history_visibility(
id INTEGER PRIMARY KEY,
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
history_visibility TEXT NOT NULL, history_visibility TEXT NOT NULL,

View file

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from _base import SQLBaseStore from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@ -25,7 +25,106 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SearchStore(SQLBaseStore): class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
def __init__(self, hs):
super(SearchStore, self).__init__(hs)
self.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
min_stream_id = rows[-1][0]
event_ids = [row[1] for row in rows]
events = self._get_events_txn(txn, event_ids)
event_search_rows = []
for event in events:
try:
event_id = event.event_id
room_id = event.room_id
content = event.content
if event.type == "m.room.message":
key = "content.body"
value = content["body"]
elif event.type == "m.room.topic":
key = "content.topic"
value = content["topic"]
elif event.type == "m.room.name":
key = "content.name"
value = content["name"]
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
event_search_rows.append((event_id, room_id, key, value))
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, vector)"
" VALUES (?,?,?,to_tsvector('english', ?))"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(event_search_rows)
}
self._background_update_progress_txn(
txn, self.EVENT_SEARCH_UPDATE_NAME, progress
)
return len(event_search_rows)
result = yield self.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
if not result:
yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys): def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.
@ -153,12 +252,23 @@ class SearchStore(SQLBaseStore):
" WHERE vector @@ query AND room_id = ?" " WHERE vector @@ query AND room_id = ?"
) )
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
# We want to use the full text search index on event_search to
# extract all possible matches first, then lookup those matches
# in the events table to get the topological ordering. We need
# to use the indexes in this order because sqlite refuses to
# MATCH unless it uses the full text search index
sql = ( sql = (
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" "SELECT rank(matchinfo) as rank, room_id, event_id,"
" topological_ordering, stream_ordering" " topological_ordering, stream_ordering"
" FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
" FROM event_search" " FROM event_search"
" NATURAL JOIN events" " WHERE value MATCH ?"
" WHERE value MATCH ? AND room_id = ?" " )"
" CROSS JOIN events USING (event_id)"
" WHERE room_id = ?"
) )
else: else:
# This should be unreachable. # This should be unreachable.

View file

@ -237,6 +237,20 @@ class StateStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
def get_state_for_event(self, event_id, types=None):
"""
Get the state dict corresponding to a particular event
:param str event_id: event whose state should be returned
:param list[(str, str)]|None types: List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
:return: a deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@cached(num_args=2, lru=True, max_entries=10000) @cached(num_args=2, lru=True, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(

View file

@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
if room_ids: if room_ids:
tags_by_room = yield self.get_tags_for_user(user_id) tags_by_room = yield self.get_tags_for_user(user_id)
for room_id in room_ids: for room_id in room_ids:
results[room_id] = tags_by_room[room_id] results[room_id] = tags_by_room.get(room_id, {})
defer.returnValue(results) defer.returnValue(results)

View file

@ -59,7 +59,7 @@ class TransactionStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
if result and result.response_code: if result and result["response_code"]:
return result["response_code"], result["response_json"] return result["response_code"], result["response_json"]
else: else:
return None return None

View file

@ -53,6 +53,14 @@ class Clock(object):
loop.stop() loop.stop()
def call_later(self, delay, callback, *args, **kwargs): def call_later(self, delay, callback, *args, **kwargs):
"""Call something later
Args:
delay(float): How long to wait in seconds.
callback(function): Function to call
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
def wrapped_callback(*args, **kwargs): def wrapped_callback(*args, **kwargs):

View file

@ -321,6 +321,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.handlers.room_member_handler.get_room_members = ( hs.handlers.room_member_handler.get_room_members = (
lambda r: self.room_members if r == "a-room" else [] lambda r: self.room_members if r == "a-room" else []
) )
hs.handlers.room_member_handler._filter_events_for_client = (
lambda user_id, events, **kwargs: events
)
self.mock_datastore = hs.get_datastore() self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None) self.mock_datastore.get_app_service_by_token = Mock(return_value=None)

View file

@ -994,3 +994,59 @@ class RoomInitialSyncTestCase(RestTestCase):
} }
self.assertTrue(self.user_id in presence_by_user) self.assertTrue(self.user_id in presence_by_user)
self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) self.assertEquals("m.presence", presence_by_user[self.user_id]["type"])
class RoomMessageListTestCase(RestTestCase):
""" Tests /rooms/$room_id/messages REST events. """
user_id = "@sid1:red"
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id
hs = yield setup_test_homeserver(
"red",
http_client=None,
replication_layer=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
"is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
self.room_id = yield self.create_room_as(self.user_id)
@defer.inlineCallbacks
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
self.assertEquals(200, code)
self.assertTrue("start" in response)
self.assertEquals(token, response['start'])
self.assertTrue("chunk" in response)
self.assertTrue("end" in response)
@defer.inlineCallbacks
def test_stream_token_is_rejected(self):
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=s0_0_0_0" %
self.room_id)
self.assertEquals(400, code)

View file

@ -0,0 +1,76 @@
from tests import unittest
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.types import UserID, RoomID, RoomAlias
from tests.utils import setup_test_homeserver
from mock import Mock
class BackgroundUpdateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.update_handler = Mock()
yield self.store.register_background_update_handler(
"test_update", self.update_handler
)
@defer.inlineCallbacks
def test_do_background_update(self):
desired_count = 1000;
duration_ms = 42;
@defer.inlineCallbacks
def update(progress, count):
self.clock.advance_time_msec(count * duration_ms)
progress = {"my_key": progress["my_key"] + 1}
yield self.store.runInteraction(
"update_progress",
self.store._background_update_progress_txn,
"test_update",
progress,
)
defer.returnValue(count)
self.update_handler.side_effect = update
yield self.store.start_background_update("test_update", {"my_key": 1})
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
self.update_handler.assert_called_once_with(
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
)
@defer.inlineCallbacks
def update(progress, count):
yield self.store._end_background_update("test_update")
defer.returnValue(count)
self.update_handler.side_effect = update
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
self.update_handler.assert_called_once_with(
{"my_key": 2}, desired_count
)
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
duration_ms * desired_count
)
self.assertIsNone(result)
self.assertFalse(self.update_handler.called)

View file

@ -73,6 +73,8 @@ class RoomStoreTestCase(unittest.TestCase):
"room_id": self.room.to_string(), "room_id": self.room.to_string(),
"topic": None, "topic": None,
"aliases": [self.alias.to_string()], "aliases": [self.alias.to_string()],
"world_readable": False,
"guest_can_join": False,
}, rooms[0]) }, rooms[0])

View file

@ -317,6 +317,99 @@ class StateTestCase(unittest.TestCase):
{e.event_id for e in context_store["E"].current_state.values()} {e.event_id for e in context_store["E"].current_state.values()}
) )
@defer.inlineCallbacks
def test_branch_have_perms_conflict(self):
userid1 = "@user_id:example.com"
userid2 = "@user_id2:example.com"
nodes = {
"A1": DictObj(
type=EventTypes.Create,
state_key="",
content={"creator": userid1},
depth=1,
),
"A2": DictObj(
type=EventTypes.Member,
state_key=userid1,
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
),
"A3": DictObj(
type=EventTypes.Member,
state_key=userid2,
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
),
"A4": DictObj(
type=EventTypes.PowerLevels,
state_key="",
content={
"events": {"m.room.name": 50},
"users": {userid1: 100,
userid2: 60},
},
),
"A5": DictObj(
type=EventTypes.Name,
state_key="",
),
"B": DictObj(
type=EventTypes.PowerLevels,
state_key="",
content={
"events": {"m.room.name": 50},
"users": {userid2: 30},
},
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
sender=userid2,
),
"D": DictObj(
type=EventTypes.Message,
),
}
edges = {
"A2": ["A1"],
"A3": ["A2"],
"A4": ["A3"],
"A5": ["A4"],
"B": ["A5"],
"C": ["A5"],
"D": ["B", "C"]
}
self._add_depths(nodes, edges)
graph = Graph(nodes, edges)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"},
{e.event_id for e in context_store["D"].current_state.values()}
)
def _add_depths(self, nodes, edges):
def _get_depth(ev):
node = nodes[ev]
if 'depth' not in node:
prevs = edges[ev]
depth = max(_get_depth(prev) for prev in prevs) + 1
node['depth'] = depth
return node['depth']
for n in nodes:
_get_depth(n)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_message(self): def test_annotate_with_old_message(self):
event = create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")

View file

@ -243,6 +243,9 @@ class MockClock(object):
else: else:
self.timers.append(t) self.timers.append(t)
def advance_time_msec(self, ms):
self.advance_time(ms / 1000.)
class SQLiteMemoryDbPool(ConnectionPool, object): class SQLiteMemoryDbPool(ConnectionPool, object):
def __init__(self): def __init__(self):

View file

@ -6,10 +6,12 @@ deps =
coverage coverage
Twisted>=15.1 Twisted>=15.1
mock mock
python-subunit
junitxml
setenv = setenv =
PYTHONDONTWRITEBYTECODE = no_byte_code PYTHONDONTWRITEBYTECODE = no_byte_code
commands = commands =
coverage run --source=synapse {envbindir}/trial {posargs:tests} /bin/bash -c "coverage run --source=synapse {envbindir}/trial {posargs:tests} {env:TOXSUFFIX:}"
coverage report -m coverage report -m
[testenv:packaging] [testenv:packaging]