diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 133183a257..71f7ab3d22 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -416,8 +416,6 @@ class RoomMemberHandler(BaseHandler): effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" - elif action == "forget": - effective_membership_state = "leave" if third_party_signed is not None: replication = self.hs.get_replication_layer() @@ -473,9 +471,6 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts=remote_room_hosts, ) - if action == "forget": - yield self.forget(requester.user, room_id) - @defer.inlineCallbacks def send_membership_event( self, @@ -935,8 +930,24 @@ class RoomMemberHandler(BaseHandler): display_name = data["display_name"] defer.returnValue((token, public_keys, fallback_public_key, display_name)) + @defer.inlineCallbacks def forget(self, user, room_id): - return self.store.forget(user.to_string(), room_id) + user_id = user.to_string() + + member = yield self.state_handler.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership is not None and membership != Membership.LEAVE: + raise SynapseError(400, "User %s in room %s" % ( + user_id, room_id + )) + + if membership: + yield self.store.forget(user_id, room_id) class RoomListHandler(BaseHandler): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index a1fa7daf79..b223fb7e5f 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -405,6 +405,42 @@ class RoomEventContext(ClientV1RestServlet): defer.returnValue((200, results)) +class RoomForgetRestServlet(ClientV1RestServlet): + def register(self, http_server): + PATTERNS = ("/rooms/(?P[^/]*)/forget") + register_txn_path(self, PATTERNS, http_server) + + @defer.inlineCallbacks + def on_POST(self, request, room_id, txn_id=None): + requester = yield self.auth.get_user_by_req( + request, + allow_guest=False, + ) + + yield self.handlers.room_member_handler.forget( + user=requester.user, + room_id=room_id, + ) + + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_PUT(self, request, room_id, txn_id): + try: + defer.returnValue( + self.txns.get_client_transaction(request, txn_id) + ) + except KeyError: + pass + + response = yield self.on_POST( + request, room_id, txn_id + ) + + self.txns.store_client_transaction(request, txn_id, response) + defer.returnValue(response) + + # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): @@ -624,6 +660,7 @@ def register_servlets(hs, http_server): RoomMemberListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) + RoomForgetRestServlet(hs).register(http_server) RoomMembershipRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server)