Allow non-guests to peek on rooms using /events

This commit is contained in:
Daniel Wagner-Hall 2016-01-20 15:34:07 +00:00 committed by review.rocks
parent ea5eea2424
commit da417aa56d
8 changed files with 107 additions and 84 deletions

View file

@ -84,7 +84,7 @@ class BaseHandler(object):
row["event_id"] for rows in forgotten for row in rows row["event_id"] for rows in forgotten for row in rows
) )
def allowed(event, user_id, is_guest): def allowed(event, user_id, is_peeking):
state = event_id_to_state[event.event_id] state = event_id_to_state[event.event_id]
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
@ -96,7 +96,7 @@ class BaseHandler(object):
if visibility == "world_readable": if visibility == "world_readable":
return True return True
if is_guest: if is_peeking:
return False return False
membership_event = state.get((EventTypes.Member, user_id), None) membership_event = state.get((EventTypes.Member, user_id), None)
@ -112,7 +112,7 @@ class BaseHandler(object):
return True return True
if event.type == EventTypes.RoomHistoryVisibility: if event.type == EventTypes.RoomHistoryVisibility:
return not is_guest return not is_peeking
if visibility == "shared": if visibility == "shared":
return True return True
@ -127,15 +127,15 @@ class BaseHandler(object):
user_id: [ user_id: [
event event
for event in events for event in events
if allowed(event, user_id, is_guest) if allowed(event, user_id, is_peeking)
] ]
for user_id, is_guest in user_tuples for user_id, is_peeking in user_tuples
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_guest=False): def _filter_events_for_client(self, user_id, events, is_peeking=False):
# Assumes that user has at some point joined the room if not is_guest. # Assumes that user has at some point joined the room if not is_guest.
res = yield self._filter_events_for_clients([(user_id, is_guest)], events) res = yield self._filter_events_for_clients([(user_id, is_peeking)], events)
defer.returnValue(res.get(user_id, [])) defer.returnValue(res.get(user_id, []))
def ratelimit(self, user_id): def ratelimit(self, user_id):

View file

@ -135,7 +135,7 @@ class EventStreamHandler(BaseHandler):
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout, auth_user, pagin_config, timeout,
only_room_events=only_room_events, only_room_events=only_room_events,
is_guest=is_guest, guest_room_id=room_id is_guest=is_guest, explicit_room_id=room_id
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -78,21 +78,20 @@ class MessageHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True, is_guest=False): as_client_event=True):
"""Get messages in a room. """Get messages in a room.
Args: Args:
user_id (str): The user requesting messages. requester (Requester): The user requesting messages.
room_id (str): The room they want messages from. room_id (str): The room they want messages from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any. config rules to apply, if any.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
is_guest (bool): Whether the requesting user is a guest (as opposed
to a fully registered user).
Returns: Returns:
dict: Pagination API results dict: Pagination API results
""" """
user_id = requester.user.to_string()
data_source = self.hs.get_event_sources().sources["room"] data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token: if pagin_config.from_token:
@ -115,15 +114,14 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room") source_config = pagin_config.get_source_config("room")
if not is_guest: membership, member_event_id = yield self._check_in_room_or_world_readable(
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) room_id, user_id
if member_event.membership == Membership.LEAVE: )
if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before # If they have left the room then clamp the token to be before
# they left the room. # they left the room.
# If they're a guest, we'll just 403 them if they're asking for
# events they can't see.
leave_token = yield self.store.get_topological_token_for_event( leave_token = yield self.store.get_topological_token_for_event(
member_event.event_id member_event_id
) )
leave_token = RoomStreamToken.parse(leave_token) leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological: if leave_token.topological < room_token.topological:
@ -141,10 +139,8 @@ class MessageHandler(BaseHandler):
room_id, room_token.topological room_id, room_token.topological
) )
user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield data_source.get_pagination_rows(
user, source_config, room_id requester.user, source_config, room_id
) )
next_token = pagin_config.from_token.copy_and_replace( next_token = pagin_config.from_token.copy_and_replace(
@ -158,7 +154,11 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest) events = yield self._filter_events_for_client(
user_id,
events,
is_peeking=(member_event_id is None),
)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -289,7 +289,7 @@ class MessageHandler(BaseHandler):
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
membership, membership_event_id = yield self._check_in_room_or_world_readable( membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest room_id, user_id
) )
if membership == Membership.JOIN: if membership == Membership.JOIN:
@ -306,7 +306,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id, is_guest): def _check_in_room_or_world_readable(self, room_id, user_id):
try: try:
# check_user_was_in_room will return the most recent membership # check_user_was_in_room will return the most recent membership
# event for the user if: # event for the user if:
@ -316,7 +316,7 @@ class MessageHandler(BaseHandler):
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id)) defer.returnValue((member_event.membership, member_event.event_id))
return return
except AuthError, auth_error: except AuthError:
visibility = yield self.state_handler.get_current_state( visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, "" room_id, EventTypes.RoomHistoryVisibility, ""
) )
@ -326,8 +326,6 @@ class MessageHandler(BaseHandler):
): ):
defer.returnValue((Membership.JOIN, None)) defer.returnValue((Membership.JOIN, None))
return return
if not is_guest:
raise auth_error
raise AuthError( raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
) )
@ -345,7 +343,7 @@ class MessageHandler(BaseHandler):
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
""" """
membership, membership_event_id = yield self._check_in_room_or_world_readable( membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest room_id, user_id
) )
if membership == Membership.JOIN: if membership == Membership.JOIN:
@ -556,13 +554,13 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None, is_guest=False): def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of """Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left. the room this will be what was in the room when they left.
Args: Args:
user_id(str): The user to get a snapshot for. requester(Requester): The user to get a snapshot for.
room_id(str): The room to get a snapshot of. room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig): pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to The pagination config used to determine how many messages to
@ -573,19 +571,20 @@ class MessageHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room. A JSON serialisable dict with the snapshot of the room.
""" """
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable( membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, room_id, user_id,
user_id,
is_guest
) )
is_peeking = member_event_id is None
if membership == Membership.JOIN: if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined( result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_guest user_id, room_id, pagin_config, membership, is_peeking
) )
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted( result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_guest user_id, room_id, pagin_config, membership, member_event_id, is_peeking
) )
account_data_events = [] account_data_events = []
@ -609,7 +608,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config, def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_guest): membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event_id], None [member_event_id], None
) )
@ -631,7 +630,7 @@ class MessageHandler(BaseHandler):
) )
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, messages, is_guest=is_guest user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken(token[0], 0, 0, 0, 0) start_token = StreamToken(token[0], 0, 0, 0, 0)
@ -654,7 +653,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config, def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_guest): membership, is_peeking):
current_state = yield self.state.get_current_state( current_state = yield self.state.get_current_state(
room_id=room_id, room_id=room_id,
) )
@ -718,7 +717,7 @@ class MessageHandler(BaseHandler):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, messages, is_guest=is_guest, user_id, messages, is_peeking=is_peeking,
) )
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
@ -737,7 +736,7 @@ class MessageHandler(BaseHandler):
"presence": presence, "presence": presence,
"receipts": receipts, "receipts": receipts,
} }
if not is_guest: if not is_peeking:
ret["membership"] = membership ret["membership"] = membership
defer.returnValue(ret) defer.returnValue(ret)

View file

@ -936,7 +936,7 @@ class RoomContextHandler(BaseHandler):
return self._filter_events_for_client( return self._filter_events_for_client(
user.to_string(), user.to_string(),
events, events,
is_guest=is_guest) is_peeking=is_guest)
event = yield self.store.get_event(event_id, get_prev_content=True, event = yield self.store.get_event(event_id, get_prev_content=True,
allow_none=True) allow_none=True)

View file

@ -150,7 +150,7 @@ class SyncHandler(BaseHandler):
return self.current_sync_for_user(sync_config, since_token) return self.current_sync_for_user(sync_config, since_token)
result = yield self.notifier.wait_for_events( result = yield self.notifier.wait_for_events(
sync_config.user, timeout, current_sync_callback, sync_config.user.to_string(), timeout, current_sync_callback,
from_token=since_token from_token=since_token
) )
defer.returnValue(result) defer.returnValue(result)
@ -640,7 +640,7 @@ class SyncHandler(BaseHandler):
loaded_recents = yield self._filter_events_for_client( loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), sync_config.user.to_string(),
loaded_recents, loaded_recents,
is_guest=sync_config.is_guest, is_peeking=sync_config.is_guest,
) )
loaded_recents.extend(recents) loaded_recents.extend(recents)
recents = loaded_recents recents = loaded_recents

View file

@ -63,9 +63,9 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user, rooms, current_token, time_now_ms, def __init__(self, user_id, rooms, current_token, time_now_ms,
appservice=None): appservice=None):
self.user = str(user) self.user_id = user_id
self.appservice = appservice self.appservice = appservice
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
@ -98,7 +98,7 @@ class _NotifierUserStream(object):
lst = notifier.room_to_user_streams.get(room, set()) lst = notifier.room_to_user_streams.get(room, set())
lst.discard(self) lst.discard(self)
notifier.user_to_user_stream.pop(self.user) notifier.user_to_user_stream.pop(self.user_id)
if self.appservice: if self.appservice:
notifier.appservice_to_user_streams.get( notifier.appservice_to_user_streams.get(
@ -271,21 +271,20 @@ class Notifier(object):
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, timeout, callback, room_ids=None, def wait_for_events(self, user_id, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")): from_token=StreamToken("s0", "0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
user = str(user) user_stream = self.user_to_user_stream.get(user_id)
user_stream = self.user_to_user_stream.get(user)
if user_stream is None: if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user) appservice = yield self.store.get_app_service_by_user_id(user_id)
current_token = yield self.event_sources.get_current_token() current_token = yield self.event_sources.get_current_token()
if room_ids is None: if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user) rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [room.room_id for room in rooms] room_ids = [room.room_id for room in rooms]
user_stream = _NotifierUserStream( user_stream = _NotifierUserStream(
user=user, user_id=user_id,
rooms=room_ids, rooms=room_ids,
appservice=appservice, appservice=appservice,
current_token=current_token, current_token=current_token,
@ -333,12 +332,17 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events_for(self, user, pagination_config, timeout, def get_events_for(self, user, pagination_config, timeout,
only_room_events=False, only_room_events=False,
is_guest=False, guest_room_id=None): is_guest=False, explicit_room_id=None):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
If `only_room_events` is `True` only room events will be returned. If `only_room_events` is `True` only room events will be returned.
If explicit_room_id is not set, the user's joined rooms will be polled
for events.
If explicit_room_id is set, that room will be polled for events only if
it is world readable or the user has joined the room.
""" """
from_token = pagination_config.from_token from_token = pagination_config.from_token
if not from_token: if not from_token:
@ -346,15 +350,8 @@ class Notifier(object):
limit = pagination_config.limit limit = pagination_config.limit
room_ids = [] room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
if is_guest: is_peeking = not is_joined
if guest_room_id:
if not (yield self._is_world_readable(guest_room_id)):
raise AuthError(403, "Guest access not allowed")
room_ids = [guest_room_id]
else:
rooms = yield self.store.get_rooms_for_user(user.to_string())
room_ids = [room.room_id for room in rooms]
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
@ -376,7 +373,7 @@ class Notifier(object):
user=user, user=user,
from_key=getattr(from_token, keyname), from_key=getattr(from_token, keyname),
limit=limit, limit=limit,
is_guest=is_guest, is_guest=is_peeking,
room_ids=room_ids, room_ids=room_ids,
) )
@ -385,7 +382,7 @@ class Notifier(object):
new_events = yield room_member_handler._filter_events_for_client( new_events = yield room_member_handler._filter_events_for_client(
user.to_string(), user.to_string(),
new_events, new_events,
is_guest=is_guest, is_peeking=is_peeking,
) )
events.extend(new_events) events.extend(new_events)
@ -396,8 +393,24 @@ class Notifier(object):
else: else:
defer.returnValue(None) defer.returnValue(None)
user_id_for_stream = user.to_string()
if is_peeking:
# Internally, the notifier keeps an event stream per user_id.
# This is used by both /sync and /events.
# We want /events to be used for peeking independently of /sync,
# without polluting its contents. So we invent an illegal user ID
# (which thus cannot clash with any real users) for keying peeking
# over /events.
#
# I am sorry for what I have done.
user_id_for_stream = "_PEEKING_" + user_id_for_stream
result = yield self.wait_for_events( result = yield self.wait_for_events(
user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token user_id_for_stream,
timeout,
check_for_updates,
room_ids=room_ids,
from_token=from_token,
) )
if result is None: if result is None:
@ -405,6 +418,18 @@ class Notifier(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def _get_room_ids(self, user, explicit_room_id):
joined_rooms = yield self.store.get_rooms_for_user(user.to_string())
joined_room_ids = map(lambda r: r.room_id, joined_rooms)
if explicit_room_id:
if explicit_room_id in joined_room_ids:
defer.returnValue(([explicit_room_id], True))
if (yield self._is_world_readable(explicit_room_id)):
defer.returnValue(([explicit_room_id], False))
raise AuthError(403, "Non-joined access not allowed")
defer.returnValue((joined_room_ids, True))
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_world_readable(self, room_id): def _is_world_readable(self, room_id):
state = yield self.hs.get_state_handler().get_current_state( state = yield self.hs.get_state_handler().get_current_state(
@ -432,7 +457,7 @@ class Notifier(object):
@log_function @log_function
def _register_with_keys(self, user_stream): def _register_with_keys(self, user_stream):
self.user_to_user_stream[user_stream.user] = user_stream self.user_to_user_stream[user_stream.user_id] = user_stream
for room in user_stream.rooms: for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set()) s = self.room_to_user_streams.setdefault(room, set())

View file

@ -43,6 +43,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
if is_guest: if is_guest:
if "room_id" not in request.args: if "room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param") raise SynapseError(400, "Guest users must specify room_id param")
if "room_id" in request.args:
room_id = request.args["room_id"][0] room_id = request.args["room_id"][0]
try: try:
handler = self.handlers.event_stream_handler handler = self.handlers.event_stream_handler

View file

@ -348,8 +348,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
handler = self.handlers.message_handler handler = self.handlers.message_handler
msgs = yield handler.get_messages( msgs = yield handler.get_messages(
room_id=room_id, room_id=room_id,
user_id=requester.user.to_string(), requester=requester,
is_guest=requester.is_guest,
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event as_client_event=as_client_event
) )
@ -384,9 +383,8 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync( content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id, room_id=room_id,
user_id=requester.user.to_string(), requester=requester,
pagin_config=pagination_config, pagin_config=pagination_config,
is_guest=requester.is_guest,
) )
defer.returnValue((200, content)) defer.returnValue((200, content))