diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 88445fe999..dfbbc5a1cd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -49,6 +49,7 @@ class Auth(object): self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self._KNOWN_CAVEAT_PREFIXES = set([ "gen = ", + "guest = ", "type = ", "time < ", "user_id = ", @@ -183,15 +184,11 @@ class Auth(object): defer.returnValue(member) @defer.inlineCallbacks - def check_user_was_in_room(self, room_id, user_id, current_state=None): + def check_user_was_in_room(self, room_id, user_id): """Check if the user was in the room at some point. Args: room_id(str): The room to check. user_id(str): The user to check. - current_state(dict): Optional map of the current state of the room. - If provided then that map is used to check whether they are a - member of the room. Otherwise the current membership is - loaded from the database. Raises: AuthError if the user was never in the room. Returns: @@ -199,17 +196,11 @@ class Auth(object): room. This will be the join event if they are currently joined to the room. This will be the leave event if they have left the room. """ - if current_state: - member = current_state.get( - (EventTypes.Member, user_id), - None - ) - else: - member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id - ) + member = yield self.state.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) membership = member.membership if member else None if membership not in (Membership.JOIN, Membership.LEAVE): @@ -497,7 +488,7 @@ class Auth(object): return default @defer.inlineCallbacks - def get_user_by_req(self, request): + def get_user_by_req(self, request, allow_guest=False): """ Get a registered user's ID. Args: @@ -535,7 +526,7 @@ class Auth(object): request.authenticated_entity = user_id - defer.returnValue((UserID.from_string(user_id), "")) + defer.returnValue((UserID.from_string(user_id), "", False)) return except KeyError: pass # normal users won't have the user_id query parameter set. @@ -543,6 +534,7 @@ class Auth(object): user_info = yield self._get_user_by_access_token(access_token) user = user_info["user"] token_id = user_info["token_id"] + is_guest = user_info["is_guest"] ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( @@ -557,9 +549,14 @@ class Auth(object): user_agent=user_agent ) + if is_guest and not allow_guest: + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) + request.authenticated_entity = user.to_string() - defer.returnValue((user, token_id,)) + defer.returnValue((user, token_id, is_guest,)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -592,31 +589,45 @@ class Auth(object): self._validate_macaroon(macaroon) user_prefix = "user_id = " + user = None + guest = False for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) - # This codepath exists so that we can actually return a - # token ID, because we use token IDs in place of device - # identifiers throughout the codebase. - # TODO(daniel): Remove this fallback when device IDs are - # properly implemented. - ret = yield self._look_up_user_by_access_token(macaroon_str) - if ret["user"] != user: - logger.error( - "Macaroon user (%s) != DB user (%s)", - user, - ret["user"] - ) - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "User mismatch in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) - defer.returnValue(ret) - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) + elif caveat.caveat_id == "guest = true": + guest = True + + if user is None: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + + if guest: + ret = { + "user": user, + "is_guest": True, + "token_id": None, + } + else: + # This codepath exists so that we can actually return a + # token ID, because we use token IDs in place of device + # identifiers throughout the codebase. + # TODO(daniel): Remove this fallback when device IDs are + # properly implemented. + ret = yield self._look_up_user_by_access_token(macaroon_str) + if ret["user"] != user: + logger.error( + "Macaroon user (%s) != DB user (%s)", + user, + ret["user"] + ) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "User mismatch in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + defer.returnValue(ret) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", @@ -629,6 +640,7 @@ class Auth(object): v.satisfy_exact("type = access") v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(self._verify_expiry) + v.satisfy_exact("guest = true") v.verify(macaroon, self.hs.config.macaroon_secret_key) v = pymacaroons.Verifier() @@ -666,6 +678,7 @@ class Auth(object): user_info = { "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), + "is_guest": False, } defer.returnValue(user_info) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b3fea27d0e..d4037b3d55 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -33,6 +33,7 @@ class Codes(object): NOT_FOUND = "M_NOT_FOUND" MISSING_TOKEN = "M_MISSING_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" + GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index e4e3d1c59d..aaa2433cae 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -50,11 +50,11 @@ class Filtering(object): # many definitions. top_level_definitions = [ - "public_user_data", "private_user_data", "server_data" + "presence" ] room_level_definitions = [ - "state", "timeline", "ephemeral" + "state", "timeline", "ephemeral", "private_user_data" ] for key in top_level_definitions: @@ -114,22 +114,6 @@ class Filtering(object): if not isinstance(event_type, basestring): raise SynapseError(400, "Event type should be a string") - if "format" in definition: - event_format = definition["format"] - if event_format not in ["federation", "events"]: - raise SynapseError(400, "Invalid format: %s" % (event_format,)) - - if "select" in definition: - event_select_list = definition["select"] - for select_key in event_select_list: - if select_key not in ["event_id", "origin_server_ts", - "thread_id", "content", "content.body"]: - raise SynapseError(400, "Bad select: %s" % (select_key,)) - - if ("bundle_updates" in definition and - type(definition["bundle_updates"]) != bool): - raise SynapseError(400, "Bad bundle_updates: expected bool.") - class FilterCollection(object): def __init__(self, filter_json): diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f5ef36a9f4..dca391f7af 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -34,6 +34,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.macaroon_secret_key = config.get("macaroon_secret_key") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) + self.allow_guest_access = config.get("allow_guest_access", False) def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -54,6 +55,11 @@ class RegistrationConfig(Config): # Larger numbers increase the work factor needed to generate the hash. # The default number of rounds is 12. bcrypt_rounds: 12 + + # Allows users to register as guests without a password/email/etc, and + # participate in rooms hosted on this server which have been made + # accessible to anonymous users. + allow_guest_access: False """ % locals() def add_arguments(self, parser): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6a26cb1879..a9e43052b7 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -47,37 +47,24 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, events): - event_id_to_state = yield self.store.get_state_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) - ) + def _filter_events_for_client(self, user_id, events, is_guest=False, + require_all_visible_for_guests=True): + # Assumes that user has at some point joined the room if not is_guest. - def allowed(event, state): - if event.type == EventTypes.RoomHistoryVisibility: + def allowed(event, membership, visibility): + if visibility == "world_readable": return True - membership_ev = state.get((EventTypes.Member, user_id), None) - if membership_ev: - membership = membership_ev.membership - else: - membership = Membership.LEAVE + if is_guest: + return False if membership == Membership.JOIN: return True - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history: - visibility = history.content.get("history_visibility", "shared") - else: - visibility = "shared" + if event.type == EventTypes.RoomHistoryVisibility: + return not is_guest - if visibility == "public": - return True - elif visibility == "shared": + if visibility == "shared": return True elif visibility == "joined": return membership == Membership.JOIN @@ -86,11 +73,46 @@ class BaseHandler(object): return True - defer.returnValue([ - event - for event in events - if allowed(event, event_id_to_state[event.event_id]) - ]) + event_id_to_state = yield self.store.get_state_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) + ) + + events_to_return = [] + for event in events: + state = event_id_to_state[event.event_id] + + membership_event = state.get((EventTypes.Member, user_id), None) + if membership_event: + membership = membership_event.membership + else: + membership = None + + visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) + if visibility_event: + visibility = visibility_event.content.get("history_visibility", "shared") + else: + visibility = "shared" + + should_include = allowed(event, membership, visibility) + if should_include: + events_to_return.append(event) + + if (require_all_visible_for_guests + and is_guest + and len(events_to_return) < len(events)): + # This indicates that some events in the requested range were not + # visible to guest users. To be safe, we reject the entire request, + # so that we don't have to worry about interpreting visibility + # boundaries. + raise AuthError(403, "User %s does not have permission" % ( + user_id + )) + + defer.returnValue(events_to_return) def ratelimit(self, user_id): time_now = self.clock.time() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 055d395b20..1b11dbdffd 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -372,12 +372,15 @@ class AuthHandler(BaseHandler): yield self.store.add_refresh_token_to_user(user_id, refresh_token) defer.returnValue(refresh_token) - def generate_access_token(self, user_id): + def generate_access_token(self, user_id, extra_caveats=None): + extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() expiry = now + (60 * 60 * 1000) macaroon.add_first_party_caveat("time < %d" % (expiry,)) + for caveat in extra_caveats: + macaroon.add_first_party_caveat(caveat) return macaroon.serialize() def generate_refresh_token(self, user_id): diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 92afa35d57..0e4c0d4d06 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -100,7 +100,7 @@ class EventStreamHandler(BaseHandler): @log_function def get_stream(self, auth_user_id, pagin_config, timeout=0, as_client_event=True, affect_presence=True, - only_room_events=False): + only_room_events=False, room_id=None, is_guest=False): """Fetches the events stream for a given user. If `only_room_events` is `True` only room events will be returned. @@ -111,17 +111,6 @@ class EventStreamHandler(BaseHandler): if affect_presence: yield self.started_stream(auth_user) - rm_handler = self.hs.get_handlers().room_member_handler - - app_service = yield self.store.get_app_service_by_user_id( - auth_user.to_string() - ) - if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - room_ids = set(r.room_id for r in rooms) - else: - room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user) - if timeout: # If they've set a timeout set a minimum limit. timeout = max(timeout, 500) @@ -130,9 +119,15 @@ class EventStreamHandler(BaseHandler): # thundering herds on restart. timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) + if is_guest: + yield self.distributor.fire( + "user_joined_room", user=auth_user, room_id=room_id + ) + events, tokens = yield self.notifier.get_events_for( - auth_user, room_ids, pagin_config, timeout, - only_room_events=only_room_events + auth_user, pagin_config, timeout, + only_room_events=only_room_events, + is_guest=is_guest, guest_room_id=room_id ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ae9d227586..b2395b28d1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -72,8 +72,6 @@ class FederationHandler(BaseHandler): self.server_name = hs.hostname self.keyring = hs.get_keyring() - self.lock_manager = hs.get_room_lock_manager() - self.replication_layer.set_handler(self) # When joining a room we need to queue any events for that room up diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0f947993d1..654ecd2b37 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, AuthError, Codes from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -71,20 +71,20 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, user_id=None, room_id=None, pagin_config=None, - as_client_event=True): + as_client_event=True, is_guest=False): """Get messages in a room. Args: user_id (str): The user requesting messages. room_id (str): The room they want messages from. pagin_config (synapse.api.streams.PaginationConfig): The pagination - config rules to apply, if any. + config rules to apply, if any. as_client_event (bool): True to get events in client-server format. + is_guest (bool): Whether the requesting user is a guest (as opposed + to a fully registered user). Returns: dict: Pagination API results """ - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) - data_source = self.hs.get_event_sources().sources["room"] if pagin_config.from_token: @@ -107,23 +107,27 @@ class MessageHandler(BaseHandler): source_config = pagin_config.get_source_config("room") - if member_event.membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room - leave_token = yield self.store.get_topological_token_for_event( - member_event.event_id - ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < room_token.topological: - source_config.from_key = str(leave_token) + if not is_guest: + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + if member_event.membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room. + # If they're a guest, we'll just 403 them if they're asking for + # events they can't see. + leave_token = yield self.store.get_topological_token_for_event( + member_event.event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < room_token.topological: + source_config.from_key = str(leave_token) - if source_config.direction == "f": - if source_config.to_key is None: - source_config.to_key = str(leave_token) - else: - to_token = RoomStreamToken.parse(source_config.to_key) - if leave_token.topological < to_token.topological: + if source_config.direction == "f": + if source_config.to_key is None: source_config.to_key = str(leave_token) + else: + to_token = RoomStreamToken.parse(source_config.to_key) + if leave_token.topological < to_token.topological: + source_config.to_key = str(leave_token) yield self.hs.get_handlers().federation_handler.maybe_backfill( room_id, room_token.topological @@ -146,7 +150,7 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) - events = yield self._filter_events_for_client(user_id, events) + events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest) time_now = self.clock.time_msec() @@ -225,7 +229,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_room_data(self, user_id=None, room_id=None, - event_type=None, state_key=""): + event_type=None, state_key="", is_guest=False): """ Get data from a room. Args: @@ -235,23 +239,42 @@ class MessageHandler(BaseHandler): Raises: SynapseError if something went wrong. """ - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + membership, membership_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id, is_guest + ) - if member_event.membership == Membership.JOIN: + if membership == Membership.JOIN: data = yield self.state_handler.get_current_state( room_id, event_type, state_key ) - elif member_event.membership == Membership.LEAVE: + elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( - [member_event.event_id], [key] + [membership_event_id], [key] ) - data = room_state[member_event.event_id].get(key) + data = room_state[membership_event_id].get(key) defer.returnValue(data) @defer.inlineCallbacks - def get_state_events(self, user_id, room_id): + def _check_in_room_or_world_readable(self, room_id, user_id, is_guest): + if is_guest: + visibility = yield self.state_handler.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) + if visibility.content["history_visibility"] == "world_readable": + defer.returnValue((Membership.JOIN, None)) + return + else: + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) + else: + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + defer.returnValue((member_event.membership, member_event.event_id)) + + @defer.inlineCallbacks + def get_state_events(self, user_id, room_id, is_guest=False): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has left the room return the state events from when they left. @@ -262,15 +285,17 @@ class MessageHandler(BaseHandler): Returns: A list of dicts representing state events. [{}, {}, {}] """ - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + membership, membership_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id, is_guest + ) - if member_event.membership == Membership.JOIN: + if membership == Membership.JOIN: room_state = yield self.state_handler.get_current_state(room_id) - elif member_event.membership == Membership.LEAVE: + elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [member_event.event_id], None + [membership_event_id], None ) - room_state = room_state[member_event.event_id] + room_state = room_state[membership_event_id] now = self.clock.time_msec() defer.returnValue( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index ce60642127..0b780cd528 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1142,8 +1142,9 @@ class PresenceEventSource(object): @defer.inlineCallbacks @log_function - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, user, from_key, room_ids=None, **kwargs): from_key = int(from_key) + room_ids = room_ids or [] presence = self.hs.get_handlers().presence_handler cachemap = presence._user_cachemap @@ -1161,7 +1162,6 @@ class PresenceEventSource(object): user_ids_to_check |= set( UserID.from_string(p["observed_user_id"]) for p in presence_list ) - room_ids = yield presence.get_joined_rooms_for_user(user) for room_id in set(room_ids) & set(presence._room_serials): if presence._room_serials[room_id] > from_key: joined = yield presence.get_joined_users_for_room_id(room_id) diff --git a/synapse/handlers/private_user_data.py b/synapse/handlers/private_user_data.py index 1778c71325..1abe45ed7b 100644 --- a/synapse/handlers/private_user_data.py +++ b/synapse/handlers/private_user_data.py @@ -24,7 +24,7 @@ class PrivateUserDataEventSource(object): return self.store.get_max_private_user_data_stream_id() @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a47ae3df42..973f4d5cae 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -164,17 +164,15 @@ class ReceiptEventSource(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, from_key, room_ids, **kwargs): from_key = int(from_key) to_key = yield self.get_current_key() if from_key == to_key: defer.returnValue(([], to_key)) - rooms = yield self.store.get_rooms_for_user(user.to_string()) - rooms = [room.room_id for room in rooms] events = yield self.store.get_linearized_receipts_for_rooms( - rooms, + room_ids, from_key=from_key, to_key=to_key, ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ef4081e3fe..493a087031 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -64,7 +64,7 @@ class RegistrationHandler(BaseHandler): ) @defer.inlineCallbacks - def register(self, localpart=None, password=None): + def register(self, localpart=None, password=None, generate_token=True): """Registers a new client on the server. Args: @@ -89,7 +89,9 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_access_token(user_id) + token = None + if generate_token: + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -102,14 +104,14 @@ class RegistrationHandler(BaseHandler): attempts = 0 user_id = None token = None - while not user_id and not token: + while not user_id: try: localpart = self._generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() yield self.check_user_id_is_valid(user_id) - - token = self.auth_handler().generate_access_token(user_id) + if generate_token: + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 36878a6c20..736ffe9066 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -807,7 +807,14 @@ class RoomEventSource(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events( + self, + user, + from_key, + limit, + room_ids, + is_guest, + ): # We just ignore the key for now. to_key = yield self.get_current_key() @@ -827,8 +834,9 @@ class RoomEventSource(object): user_id=user.to_string(), from_key=from_key, to_key=to_key, - room_id=None, limit=limit, + room_ids=room_ids, + is_guest=is_guest, ) defer.returnValue((events, end_key)) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d6527c1ae8..5294d96466 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -143,21 +143,8 @@ class SyncHandler(BaseHandler): def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) - rm_handler = self.hs.get_handlers().room_member_handler - - app_service = yield self.store.get_app_service_by_user_id( - sync_config.user.to_string() - ) - if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - room_ids = set(r.room_id for r in rooms) - else: - room_ids = yield rm_handler.get_joined_rooms_for_user( - sync_config.user - ) - result = yield self.notifier.wait_for_events( - sync_config.user, room_ids, timeout, current_sync_callback, + sync_config.user, timeout, current_sync_callback, from_token=since_token ) defer.returnValue(result) @@ -308,11 +295,16 @@ class SyncHandler(BaseHandler): typing_key = since_token.typing_key if since_token else "0" + rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) + room_ids = [room.room_id for room in rooms] + typing_source = self.event_sources.sources["typing"] - typing, typing_key = yield typing_source.get_new_events_for_user( + typing, typing_key = yield typing_source.get_new_events( user=sync_config.user, from_key=typing_key, limit=sync_config.filter.ephemeral_limit(), + room_ids=room_ids, + is_guest=False, ) now_token = now_token.copy_and_replace("typing_key", typing_key) @@ -325,10 +317,13 @@ class SyncHandler(BaseHandler): receipt_key = since_token.receipt_key if since_token else "0" receipt_source = self.event_sources.sources["receipt"] - receipts, receipt_key = yield receipt_source.get_new_events_for_user( + receipts, receipt_key = yield receipt_source.get_new_events( user=sync_config.user, from_key=receipt_key, limit=sync_config.filter.ephemeral_limit(), + room_ids=room_ids, + # /sync doesn't support guest access, they can't get to this point in code + is_guest=False, ) now_token = now_token.copy_and_replace("receipt_key", receipt_key) @@ -373,11 +368,17 @@ class SyncHandler(BaseHandler): """ now_token = yield self.event_sources.get_current_token() + rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) + room_ids = [room.room_id for room in rooms] + presence_source = self.event_sources.sources["presence"] - presence, presence_key = yield presence_source.get_new_events_for_user( + presence, presence_key = yield presence_source.get_new_events( user=sync_config.user, from_key=since_token.presence_key, limit=sync_config.filter.presence_limit(), + room_ids=room_ids, + # /sync doesn't support guest access, they can't get to this point in code + is_guest=False, ) now_token = now_token.copy_and_replace("presence_key", presence_key) @@ -403,7 +404,6 @@ class SyncHandler(BaseHandler): sync_config.user.to_string(), from_key=since_token.room_key, to_key=now_token.room_key, - room_id=None, limit=timeline_limit + 1, ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d7096aab8c..2846f3e6e8 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -246,17 +246,12 @@ class TypingNotificationEventSource(object): }, } - @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, from_key, room_ids, **kwargs): from_key = int(from_key) handler = self.handler() - joined_room_ids = ( - yield self.room_member_handler().get_joined_rooms_for_user(user) - ) - events = [] - for room_id in joined_room_ids: + for room_id in room_ids: if room_id not in handler._room_serials: continue if handler._room_serials[room_id] <= from_key: @@ -264,7 +259,7 @@ class TypingNotificationEventSource(object): events.append(self._make_event_for(room_id)) - defer.returnValue((events, handler._latest_room_serial)) + return events, handler._latest_room_serial def get_current_key(self): return self.handler()._latest_room_serial diff --git a/synapse/notifier.py b/synapse/notifier.py index a78ee3c1e7..56c4c863b5 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -269,7 +269,7 @@ class Notifier(object): logger.exception("Failed to notify listener") @defer.inlineCallbacks - def wait_for_events(self, user, rooms, timeout, callback, + def wait_for_events(self, user, timeout, callback, room_ids=None, from_token=StreamToken("s0", "0", "0", "0", "0")): """Wait until the callback returns a non empty response or the timeout fires. @@ -279,11 +279,12 @@ class Notifier(object): if user_stream is None: appservice = yield self.store.get_app_service_by_user_id(user) current_token = yield self.event_sources.get_current_token() - rooms = yield self.store.get_rooms_for_user(user) - rooms = [room.room_id for room in rooms] + if room_ids is None: + rooms = yield self.store.get_rooms_for_user(user) + room_ids = [room.room_id for room in rooms] user_stream = _NotifierUserStream( user=user, - rooms=rooms, + rooms=room_ids, appservice=appservice, current_token=current_token, time_now_ms=self.clock.time_msec(), @@ -328,8 +329,9 @@ class Notifier(object): defer.returnValue(result) @defer.inlineCallbacks - def get_events_for(self, user, rooms, pagination_config, timeout, - only_room_events=False): + def get_events_for(self, user, pagination_config, timeout, + only_room_events=False, + is_guest=False, guest_room_id=None): """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. @@ -342,6 +344,16 @@ class Notifier(object): limit = pagination_config.limit + room_ids = [] + if is_guest: + # TODO(daniel): Deal with non-room events too + only_room_events = True + if guest_room_id: + room_ids = [guest_room_id] + else: + rooms = yield self.store.get_rooms_for_user(user.to_string()) + room_ids = [room.room_id for room in rooms] + @defer.inlineCallbacks def check_for_updates(before_token, after_token): if not after_token.is_after(before_token): @@ -357,9 +369,23 @@ class Notifier(object): continue if only_room_events and name != "room": continue - new_events, new_key = yield source.get_new_events_for_user( - user, getattr(from_token, keyname), limit, + new_events, new_key = yield source.get_new_events( + user=user, + from_key=getattr(from_token, keyname), + limit=limit, + is_guest=is_guest, + room_ids=room_ids, ) + + if is_guest: + room_member_handler = self.hs.get_handlers().room_member_handler + new_events = yield room_member_handler._filter_events_for_client( + user.to_string(), + new_events, + is_guest=is_guest, + require_all_visible_for_guests=False + ) + events.extend(new_events) end_token = end_token.copy_and_replace(keyname, new_key) @@ -369,7 +395,7 @@ class Notifier(object): defer.returnValue(None) result = yield self.wait_for_events( - user, rooms, timeout, check_for_updates, from_token=from_token + user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token ) if result is None: diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 504b63eab4..bdde43864c 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(auth_user) if not is_admin and target_user != auth_user: diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 4dcda57c1b..240eedac75 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet): try: # try to auth as a user - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) try: user_id = user.to_string() yield dir_handler.create_association( @@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet): # fallback to default user behaviour if they aren't an AS pass - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(user) if not is_admin: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 582148b659..3e1750d1a1 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,15 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, is_guest = yield self.auth.get_user_by_req( + request, + allow_guest=True + ) + room_id = None + if is_guest: + if "room_id" not in request.args: + raise SynapseError(400, "Guest users must specify room_id param") + room_id = request.args["room_id"][0] try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -49,7 +57,8 @@ class EventStreamRestServlet(ClientV1RestServlet): chunk = yield handler.get_stream( auth_user.to_string(), pagin_config, timeout=timeout, - as_client_event=as_client_event + as_client_event=as_client_event, affect_presence=(not is_guest), + room_id=room_id, is_guest=is_guest ) except: logger.exception("Event stream failed") @@ -71,7 +80,7 @@ class EventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, event_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.event_handler event = yield handler.get_event(auth_user, event_id) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 52c7943400..856a70f297 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) handler = self.handlers.message_handler diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index a770efd841..6fe5d19a22 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = yield self.handlers.presence_handler.get_state( @@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = {} @@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): @@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index fdde88a60d..6b379e4e5f 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: @@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index bd759a2589..b0870db1ac 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet): except InvalidRuleException as e: raise SynapseError(400, e.message) - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) if '/' in spec['rule_id'] or '\\' in spec['rule_id']: raise SynapseError(400, "rule_id may not contain slashes") @@ -92,7 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet): def on_DELETE(self, request): spec = _rule_spec_from_path(request.postpath) - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) @@ -109,7 +109,7 @@ class PushRuleRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 3aabc93b8b..a110c0a4f0 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2dcaee86cd..afb802baec 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) room_config = self.get_room_config(request) info = yield self.make_room(room_config, auth_user, None) @@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user, _ = yield self.auth.get_user_by_req(request) + user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -133,6 +133,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): room_id=room_id, event_type=event_type, state_key=state_key, + is_guest=is_guest, ) if not data: @@ -143,7 +144,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -175,7 +176,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -220,7 +221,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) # the identifier could be a room alias or a room id. Try one then the # other if it fails to parse, without swallowing other valid @@ -289,7 +290,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler events = yield handler.get_state_events( room_id=room_id, @@ -325,7 +326,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request( request, default_limit=10, ) @@ -334,6 +335,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): msgs = yield handler.get_messages( room_id=room_id, user_id=user.to_string(), + is_guest=is_guest, pagin_config=pagination_config, as_client_event=as_client_event ) @@ -347,12 +349,13 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( room_id=room_id, user_id=user.to_string(), + is_guest=is_guest, ) defer.returnValue((200, events)) @@ -363,7 +366,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request(request) content = yield self.handlers.message_handler.room_initial_sync( room_id=room_id, @@ -443,7 +446,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -524,7 +527,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -564,7 +567,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) @@ -597,7 +600,7 @@ class SearchRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 0a863e1c61..eb7c57cade 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 4692ba413c..1970ad3458 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet): if LoginType.PASSWORD in result: # if using password, they should also be logged in - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if auth_user.to_string() != result[LoginType.PASSWORD]: raise LoginError(400, "", Codes.UNKNOWN) user_id = auth_user.to_string() @@ -102,7 +102,7 @@ class ThreepidRestServlet(RestServlet): def on_GET(self, request): yield run_on_reactor() - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) threepids = yield self.hs.get_datastore().user_get_threepids( auth_user.to_string() @@ -120,7 +120,7 @@ class ThreepidRestServlet(RestServlet): raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) threePidCreds = body['threePidCreds'] - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index f8f91b63f5..97956a4b91 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot get filters for other users") @@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot create filters for other users") diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index a1f4423101..820d33336f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, device_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() # TODO: Check that the device_id matches that in the authentication # or derive the device_id from the authentication instead. @@ -109,7 +109,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, device_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() result = yield self.store.count_e2e_one_time_keys(user_id, device_id) @@ -181,7 +181,7 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) auth_user_id = auth_user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index b107b7ce17..788acd4adb 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1ba2f29711..f899376311 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import LoginType -from synapse.api.errors import SynapseError, Codes +from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet from ._base import client_v2_pattern, parse_json_dict_from_request @@ -55,6 +55,19 @@ class RegisterRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() + kind = "user" + if "kind" in request.args: + kind = request.args["kind"][0] + + if kind == "guest": + ret = yield self._do_guest_registration() + defer.returnValue(ret) + return + elif kind != "user": + raise UnrecognizedRequestError( + "Do not understand membership kind: %s" % (kind,) + ) + if '/register/email/requestToken' in request.path: ret = yield self.onEmailTokenRequest(request) defer.returnValue(ret) @@ -236,6 +249,18 @@ class RegisterRestServlet(RestServlet): ret = yield self.identity_handler.requestEmailToken(**body) defer.returnValue((200, ret)) + @defer.inlineCallbacks + def _do_guest_registration(self): + if not self.hs.config.allow_guest_access: + defer.returnValue((403, "Guest access is disabled")) + user_id, _ = yield self.registration_handler.register(generate_token=False) + access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) + defer.returnValue((200, { + "user_id": user_id, + "access_token": access_token, + "home_server": self.hs.hostname, + })) + def register_servlets(hs, http_server): RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 32a1087c91..d24507effa 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -81,7 +81,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) timeout = parse_integer(request, "timeout", default=0) since = parse_string(request, "since") diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index dcfe6bd20e..35482ae6a6 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -42,7 +42,7 @@ class TagListServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, room_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if user_id != auth_user.to_string(): raise AuthError(403, "Cannot get tags for other users.") @@ -68,7 +68,7 @@ class TagServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id, room_id, tag): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") @@ -88,7 +88,7 @@ class TagServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, user_id, room_id, tag): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index c28dc86cd7..e4fa8c4647 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource): @defer.inlineCallbacks def map_request_to_name(self, request): # auth the user - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) # namespace all file uploads on the user prefix = base64.urlsafe_b64encode( diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 6abaf56b25..7d61596082 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource): @request_handler @defer.inlineCallbacks def _async_render_POST(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") diff --git a/synapse/server.py b/synapse/server.py index 8424798b1b..f75d5358b2 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -29,7 +29,6 @@ from synapse.state import StateHandler from synapse.storage import DataStore from synapse.util import Clock from synapse.util.distributor import Distributor -from synapse.util.lockutils import LockManager from synapse.streams.events import EventSources from synapse.api.ratelimiting import Ratelimiter from synapse.crypto.keyring import Keyring @@ -70,7 +69,6 @@ class BaseHomeServer(object): 'auth', 'rest_servlet_factory', 'state_handler', - 'room_lock_manager', 'notifier', 'distributor', 'resource_for_client', @@ -201,9 +199,6 @@ class HomeServer(BaseHomeServer): def build_state_handler(self): return StateHandler(self) - def build_room_lock_manager(self): - return LockManager() - def build_distributor(self): return Distributor() diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e6c1abfc27..59c9987202 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -311,6 +311,8 @@ class EventsStore(SQLBaseStore): self._store_room_message_txn(txn, event) elif event.type == EventTypes.Redaction: self._store_redaction(txn, event) + elif event.type == EventTypes.RoomHistoryVisibility: + self._store_history_visibility_txn(txn, event) self._store_room_members_txn( txn, diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index b454dd5b3a..2e5eddd259 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -102,13 +102,14 @@ class RegistrationStore(SQLBaseStore): 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) - # it's possible for this to get a conflict, but only for a single user - # since tokens are namespaced based on their user ID - txn.execute( - "INSERT INTO access_tokens(id, user_id, token)" - " VALUES (?,?,?)", - (next_id, user_id, token,) - ) + if token: + # it's possible for this to get a conflict, but only for a single user + # since tokens are namespaced based on their user ID + txn.execute( + "INSERT INTO access_tokens(id, user_id, token)" + " VALUES (?,?,?)", + (next_id, user_id, token,) + ) def get_user_by_id(self, user_id): return self._simple_select_one( diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 13441fcdce..1c79626736 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -202,6 +202,19 @@ class RoomStore(SQLBaseStore): txn, event, "content.body", event.content["body"] ) + def _store_history_visibility_txn(self, txn, event): + if hasattr(event, "content") and "history_visibility" in event.content: + sql = ( + "INSERT INTO history_visibility" + " (event_id, room_id, history_visibility)" + " VALUES (?, ?, ?)" + ) + txn.execute(sql, ( + event.event_id, + event.room_id, + event.content["history_visibility"] + )) + def _store_event_search_txn(self, txn, event, key, value): if isinstance(self.database_engine, PostgresEngine): sql = ( diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/schema/delta/25/history_visibility.sql new file mode 100644 index 0000000000..9f387ed69f --- /dev/null +++ b/synapse/storage/schema/delta/25/history_visibility.sql @@ -0,0 +1,26 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This is a manual index of history_visibility content of state events, + * so that we can join on them in SELECT statements. + */ +CREATE TABLE IF NOT EXISTS history_visibility( + id INTEGER PRIMARY KEY, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + history_visibility TEXT NOT NULL, + UNIQUE (event_id) +); diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 15d4c2bf68..be8ba76aae 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -158,14 +158,40 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) @log_function - def get_room_events_stream(self, user_id, from_key, to_key, room_id, - limit=0): - current_room_membership_sql = ( - "SELECT m.room_id FROM room_memberships as m " - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id AND c.state_key = m.user_id" - " WHERE m.user_id = ? AND m.membership = 'join'" - ) + def get_room_events_stream( + self, + user_id, + from_key, + to_key, + limit=0, + is_guest=False, + room_ids=None + ): + room_ids = room_ids or [] + room_ids = [r for r in room_ids] + if is_guest: + current_room_membership_sql = ( + "SELECT c.room_id FROM history_visibility AS h" + " INNER JOIN current_state_events AS c" + " ON h.event_id = c.event_id" + " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % ( + ",".join(map(lambda _: "?", room_ids)) + ) + ) + current_room_membership_args = room_ids + else: + current_room_membership_sql = ( + "SELECT m.room_id FROM room_memberships as m " + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id AND c.state_key = m.user_id" + " WHERE m.user_id = ? AND m.membership = 'join'" + ) + current_room_membership_args = [user_id] + if room_ids: + current_room_membership_sql += " AND m.room_id in (%s)" % ( + ",".join(map(lambda _: "?", room_ids)) + ) + current_room_membership_args = [user_id] + room_ids # We also want to get any membership events about that user, e.g. # invites or leave notifications. @@ -174,6 +200,7 @@ class StreamStore(SQLBaseStore): "INNER JOIN current_state_events as c ON m.event_id = c.event_id " "WHERE m.user_id = ? " ) + membership_args = [user_id] if limit: limit = max(limit, MAX_STREAM_SIZE) @@ -200,7 +227,9 @@ class StreamStore(SQLBaseStore): } def f(txn): - txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,)) + args = ([False] + current_room_membership_args + membership_args + + [from_id.stream, to_id.stream]) + txn.execute(sql, args) rows = self.cursor_to_dict(txn) diff --git a/synapse/util/lockutils.py b/synapse/util/lockutils.py deleted file mode 100644 index 33edc5c20e..0000000000 --- a/synapse/util/lockutils.py +++ /dev/null @@ -1,74 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer - -import logging - - -logger = logging.getLogger(__name__) - - -class Lock(object): - - def __init__(self, deferred, key): - self._deferred = deferred - self.released = False - self.key = key - - def release(self): - self.released = True - self._deferred.callback(None) - - def __del__(self): - if not self.released: - logger.critical("Lock was destructed but never released!") - self.release() - - def __enter__(self): - return self - - def __exit__(self, type, value, traceback): - logger.debug("Releasing lock for key=%r", self.key) - self.release() - - -class LockManager(object): - """ Utility class that allows us to lock based on a `key` """ - - def __init__(self): - self._lock_deferreds = {} - - @defer.inlineCallbacks - def lock(self, key): - """ Allows us to block until it is our turn. - Args: - key (str) - Returns: - Lock - """ - new_deferred = defer.Deferred() - old_deferred = self._lock_deferreds.get(key) - self._lock_deferreds[key] = new_deferred - - if old_deferred: - logger.debug("Queueing on lock for key=%r", key) - yield old_deferred - logger.debug("Obtained lock for key=%r", key) - else: - logger.debug("Entering uncontended lock for key=%r", key) - - defer.returnValue(Lock(new_deferred, key)) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index c96273480d..70d928defe 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _) = yield self.auth.get_user_by_req(request) + (user, _, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _) = yield self.auth.get_user_by_req(request) + (user, _, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_appservice_bad_token(self): @@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase): request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _) = yield self.auth.get_user_by_req(request) + (user, _, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), masquerading_user_id) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): @@ -158,6 +158,25 @@ class AuthTestCase(unittest.TestCase): user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) + @defer.inlineCallbacks + def test_get_guest_user_from_macaroon(self): + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("guest = true") + serialized = macaroon.serialize() + + user_info = yield self.auth._get_user_from_macaroon(serialized) + user = user_info["user"] + is_guest = user_info["is_guest"] + self.assertEqual(UserID.from_string(user_id), user) + self.assertTrue(is_guest) + @defer.inlineCallbacks def test_get_user_from_macaroon_user_db_mismatch(self): self.store.get_user_by_access_token = Mock( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 29372d488a..10d4482cce 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -650,9 +650,30 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): {"presence": ONLINE} ) + # Apple sees self-reflection even without room_id + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + ) + + self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEquals(events, + [ + {"type": "m.presence", + "content": { + "user_id": "@apple:test", + "presence": ONLINE, + "last_active_ago": 0, + }}, + ], + msg="Presence event should be visible to self-reflection" + ) + # Apple sees self-reflection - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -684,8 +705,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) # Banana sees it because of presence subscription - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_banana, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_banana, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -702,8 +725,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) # Elderberry sees it because of same room - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_elderberry, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_elderberry, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -720,8 +745,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) # Durian is not in the room, should not see this event - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_durian, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_durian, + from_key=0, + room_ids=[], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -767,8 +794,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): "accepted": True}, ], presence) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 1, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=1, ) self.assertEquals(self.event_source.get_current_key(), 2) @@ -858,8 +886,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) ) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -905,8 +935,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): self.assertEquals(self.event_source.get_current_key(), 1) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id,] ) self.assertEquals(events, [ @@ -932,8 +964,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): self.assertEquals(self.event_source.get_current_key(), 2) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id,] ) self.assertEquals(events, [ @@ -966,8 +1000,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): self.room_members.append(self.u_clementine) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, ) self.assertEquals(self.event_source.get_current_key(), 1) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 41bb08b7ca..2d7ba43561 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -187,7 +187,10 @@ class TypingNotificationsTestCase(unittest.TestCase): ]) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ @@ -250,7 +253,10 @@ class TypingNotificationsTestCase(unittest.TestCase): ]) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0 + ) self.assertEquals( events[0], [ @@ -306,7 +312,10 @@ class TypingNotificationsTestCase(unittest.TestCase): yield put_json.await_calls() self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ @@ -337,7 +346,10 @@ class TypingNotificationsTestCase(unittest.TestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ @@ -356,7 +368,10 @@ class TypingNotificationsTestCase(unittest.TestCase): ]) self.assertEquals(self.event_source.get_current_key(), 2) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=1, + ) self.assertEquals( events[0], [ @@ -383,7 +398,10 @@ class TypingNotificationsTestCase(unittest.TestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 0e3b922246..7f29d73d95 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -47,7 +47,14 @@ class NullSource(object): def __init__(self, hs): pass - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events( + self, + user, + from_key, + room_ids=None, + limit=None, + is_guest=None + ): return defer.succeed(([], from_key)) def get_current_key(self, direction='f'): @@ -86,10 +93,11 @@ class PresenceStateTestCase(unittest.TestCase): return defer.succeed([]) self.datastore.get_presence_list = get_presence_list - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(myid), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -173,10 +181,11 @@ class PresenceListTestCase(unittest.TestCase): ) self.datastore.has_presence_state = has_presence_state - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(myid), "token_id": 1, + "is_guest": False, } hs.handlers.room_member_handler = Mock( @@ -291,8 +300,8 @@ class PresenceEventStreamTestCase(unittest.TestCase): hs.get_clock().time_msec.return_value = 1000000 - def _get_user_by_req(req=None): - return (UserID.from_string(myid), "") + def _get_user_by_req(req=None, allow_guest=False): + return (UserID.from_string(myid), "", False) hs.get_v1auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 929e5e5dd4..adcc1d1969 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,8 +52,8 @@ class ProfileTestCase(unittest.TestCase): replication_layer=Mock(), ) - def _get_user_by_req(request=None): - return (UserID.from_string(myid), "") + def _get_user_by_req(request=None, allow_guest=False): + return (UserID.from_string(myid), "", False) hs.get_v1auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 93896dd076..b43563fa4b 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -54,10 +54,11 @@ class RoomPermissionsTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -439,10 +440,11 @@ class RoomsMemberListTestCase(RestTestCase): self.auth_user_id = self.user_id - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -517,10 +519,11 @@ class RoomsCreateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -608,10 +611,11 @@ class RoomTopicTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -713,10 +717,11 @@ class RoomMemberStateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -838,10 +843,11 @@ class RoomMessagesTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -933,10 +939,11 @@ class RoomInitialSyncTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 6395ce79db..61b9cc743b 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -61,10 +61,11 @@ class RoomTypingTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -115,7 +116,10 @@ class RoomTypingTestCase(RestTestCase): self.assertEquals(200, code) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.user, 0, None) + events = yield self.event_source.get_new_events( + from_key=0, + room_ids=[self.room_id], + ) self.assertEquals( events[0], [ diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index f45570a1c0..fa9e17ec4f 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -43,10 +43,11 @@ class V2AlphaRestTestCase(unittest.TestCase): resource_for_federation=self.mock_resource, ) - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.USER_ID), "token_id": 1, + "is_guest": False, } hs.get_auth()._get_user_by_access_token = _get_user_by_access_token diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index b57006fcb4..dbf9700e6a 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -120,7 +120,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -149,7 +148,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -199,7 +197,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -228,7 +225,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index a658a789aa..e5c2c5cc8e 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -68,7 +68,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -105,7 +104,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -147,7 +145,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) # We should not get the message, as it happened *after* bob left. @@ -175,7 +172,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) # We should not get the message, as it happened *after* bob left. diff --git a/tests/util/test_lock.py b/tests/util/test_lock.py deleted file mode 100644 index 6a1e521b1e..0000000000 --- a/tests/util/test_lock.py +++ /dev/null @@ -1,108 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer -from tests import unittest - -from synapse.util.lockutils import LockManager - - -class LockManagerTestCase(unittest.TestCase): - - def setUp(self): - self.lock_manager = LockManager() - - @defer.inlineCallbacks - def test_one_lock(self): - key = "test" - deferred_lock1 = self.lock_manager.lock(key) - - self.assertTrue(deferred_lock1.called) - - lock1 = yield deferred_lock1 - - self.assertFalse(lock1.released) - - lock1.release() - - self.assertTrue(lock1.released) - - @defer.inlineCallbacks - def test_concurrent_locks(self): - key = "test" - deferred_lock1 = self.lock_manager.lock(key) - deferred_lock2 = self.lock_manager.lock(key) - - self.assertTrue(deferred_lock1.called) - self.assertFalse(deferred_lock2.called) - - lock1 = yield deferred_lock1 - - self.assertFalse(lock1.released) - self.assertFalse(deferred_lock2.called) - - lock1.release() - - self.assertTrue(lock1.released) - self.assertTrue(deferred_lock2.called) - - lock2 = yield deferred_lock2 - - lock2.release() - - @defer.inlineCallbacks - def test_sequential_locks(self): - key = "test" - deferred_lock1 = self.lock_manager.lock(key) - - self.assertTrue(deferred_lock1.called) - - lock1 = yield deferred_lock1 - - self.assertFalse(lock1.released) - - lock1.release() - - self.assertTrue(lock1.released) - - deferred_lock2 = self.lock_manager.lock(key) - - self.assertTrue(deferred_lock2.called) - - lock2 = yield deferred_lock2 - - self.assertFalse(lock2.released) - - lock2.release() - - self.assertTrue(lock2.released) - - @defer.inlineCallbacks - def test_with_statement(self): - key = "test" - with (yield self.lock_manager.lock(key)) as lock: - self.assertFalse(lock.released) - - self.assertTrue(lock.released) - - @defer.inlineCallbacks - def test_two_with_statement(self): - key = "test" - with (yield self.lock_manager.lock(key)): - pass - - with (yield self.lock_manager.lock(key)): - pass diff --git a/tests/utils.py b/tests/utils.py index 4da51291a4..ca2c33cf8e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -335,7 +335,7 @@ class MemoryDataStore(object): ] def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, - room_id=None, limit=0, with_feedback=False): + limit=0, with_feedback=False): return ([], from_key) # TODO def get_joined_hosts_for_room(self, room_id):