diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 17c75f33c9..00c7dbec88 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -120,28 +120,7 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) - states = yield self.store.get_state_for_events( - room_id, [e.event_id for e in events], - ) - - events_and_states = zip(events, states) - - def allowed(event_and_state): - _, state = event_and_state - - membership = state.get((EventTypes.Member, user_id), None) - if membership and membership.membership == Membership.JOIN: - return True - - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history and history.content["visibility"] == "after_join": - return False - - events_and_states = filter(allowed, events_and_states) - events = [ - ev - for ev, _ in events_and_states - ] + events = yield self._filter_events_for_client(user_id, room_id, events) time_now = self.clock.time_msec() @@ -156,6 +135,36 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) + @defer.inlineCallbacks + def _filter_events_for_client(self, user_id, room_id, events): + states = yield self.store.get_state_for_events( + room_id, [e.event_id for e in events], + ) + + events_and_states = zip(events, states) + + def allowed(event_and_state): + event, state = event_and_state + + if event.type == EventTypes.RoomHistoryVisibility: + return True + + membership = state.get((EventTypes.Member, user_id), None) + if membership and membership.membership == Membership.JOIN: + return True + + history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + if history and history.content.get("visibility", None) == "after_join": + return False + + return True + + events_and_states = filter(allowed, events_and_states) + defer.returnValue([ + ev + for ev, _ in events_and_states + ]) + @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, client=None, txn_id=None): @@ -347,6 +356,10 @@ class MessageHandler(BaseHandler): ] ).addErrback(unwrapFirstError) + messages = yield self._filter_events_for_client( + user_id, event.room_id, messages + ) + start_token = now_token.copy_and_replace("room_key", token[0]) end_token = now_token.copy_and_replace("room_key", token[1]) time_now = self.clock.time_msec() @@ -448,6 +461,10 @@ class MessageHandler(BaseHandler): consumeErrors=True, ).addErrback(unwrapFirstError) + messages = yield self._filter_events_for_client( + user_id, room_id, messages + ) + start_token = now_token.copy_and_replace("room_key", token[0]) end_token = now_token.copy_and_replace("room_key", token[1])