Refactor filter_events_for_server (#15240)

* Tweak docstring and type hint

* Flip logic and provide better name

* Separate decision from action

* Track a set of strings, not EventBases

* Require explicit boolean options from callers

* Add explicit option for partial state rooms

* Changelog

* Rename param
This commit is contained in:
David Robertson 2023-03-10 15:31:25 +00:00 committed by GitHub
parent e157c63f68
commit 4bb26c95a9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 109 additions and 30 deletions

1
changelog.d/15240.misc Normal file
View file

@ -0,0 +1 @@
Refactor `filter_events_for_server`.

View file

@ -547,6 +547,8 @@ class PerDestinationQueue:
self._server_name, self._server_name,
new_pdus, new_pdus,
redact=False, redact=False,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
) )
# If we've filtered out all the extremities, fall back to # If we've filtered out all the extremities, fall back to

View file

@ -392,7 +392,7 @@ class FederationHandler:
get_prev_content=False, get_prev_content=False,
) )
# We set `check_history_visibility_only` as we might otherwise get false # We unset `filter_out_erased_senders` as we might otherwise get false
# positives from users having been erased. # positives from users having been erased.
filtered_extremities = await filter_events_for_server( filtered_extremities = await filter_events_for_server(
self._storage_controllers, self._storage_controllers,
@ -400,7 +400,8 @@ class FederationHandler:
self.server_name, self.server_name,
events_to_check, events_to_check,
redact=False, redact=False,
check_history_visibility_only=True, filter_out_erased_senders=False,
filter_out_remote_partial_state_events=False,
) )
if filtered_extremities: if filtered_extremities:
extremities_to_request.append(bp.event_id) extremities_to_request.append(bp.event_id)
@ -1331,7 +1332,13 @@ class FederationHandler:
) )
events = await filter_events_for_server( events = await filter_events_for_server(
self._storage_controllers, origin, self.server_name, events self._storage_controllers,
origin,
self.server_name,
events,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
) )
return events return events
@ -1362,7 +1369,13 @@ class FederationHandler:
await self._event_auth_handler.assert_host_in_room(event.room_id, origin) await self._event_auth_handler.assert_host_in_room(event.room_id, origin)
events = await filter_events_for_server( events = await filter_events_for_server(
self._storage_controllers, origin, self.server_name, [event] self._storage_controllers,
origin,
self.server_name,
[event],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
) )
event = events[0] event = events[0]
return event return event
@ -1390,7 +1403,13 @@ class FederationHandler:
) )
missing_events = await filter_events_for_server( missing_events = await filter_events_for_server(
self._storage_controllers, origin, self.server_name, missing_events self._storage_controllers,
origin,
self.server_name,
missing_events,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
) )
return missing_events return missing_events

View file

@ -14,7 +14,17 @@
# limitations under the License. # limitations under the License.
import logging import logging
from enum import Enum, auto from enum import Enum, auto
from typing import Collection, Dict, FrozenSet, List, Optional, Tuple from typing import (
Collection,
Dict,
FrozenSet,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)
import attr import attr
from typing_extensions import Final from typing_extensions import Final
@ -565,29 +575,43 @@ async def filter_events_for_server(
storage: StorageControllers, storage: StorageControllers,
target_server_name: str, target_server_name: str,
local_server_name: str, local_server_name: str,
events: List[EventBase], events: Sequence[EventBase],
redact: bool = True, *,
check_history_visibility_only: bool = False, redact: bool,
filter_out_erased_senders: bool,
filter_out_remote_partial_state_events: bool,
) -> List[EventBase]: ) -> List[EventBase]:
"""Filter a list of events based on whether given server is allowed to """Filter a list of events based on whether the target server is allowed to
see them. see them.
For a fully stated room, the target server is allowed to see an event E if:
- the state at E has world readable or shared history vis, OR
- the state at E says that the target server is in the room.
For a partially stated room, the target server is allowed to see E if:
- E was created by this homeserver, AND:
- the partial state at E has world readable or shared history vis, OR
- the partial state at E says that the target server is in the room.
TODO: state before or state after?
Args: Args:
storage storage
server_name target_server_name
local_server_name
events events
redact: Whether to return a redacted version of the event, or redact: Controls what to do with events which have been filtered out.
to filter them out entirely. If True, include their redacted forms; if False, omit them entirely.
check_history_visibility_only: Whether to only check the filter_out_erased_senders: If true, also filter out events whose sender has been
history visibility, rather than things like if the sender has been
erased. This is used e.g. during pagination to decide whether to erased. This is used e.g. during pagination to decide whether to
backfill or not. backfill or not.
filter_out_remote_partial_state_events: If True, also filter out events which
were created by other homeservers and are in partial state rooms.
Returns Returns
The filtered events. The filtered events.
""" """
def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool: def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool:
if erased_senders and erased_senders[event.sender]: if erased_senders and erased_senders[event.sender]:
logger.info("Sender of %s has been erased, redacting", event.event_id) logger.info("Sender of %s has been erased, redacting", event.event_id)
return True return True
@ -616,7 +640,7 @@ async def filter_events_for_server(
# server has no users in the room: redact # server has no users in the room: redact
return False return False
if not check_history_visibility_only: if filter_out_erased_senders:
erased_senders = await storage.main.are_users_erased(e.sender for e in events) erased_senders = await storage.main.are_users_erased(e.sender for e in events)
else: else:
# We don't want to check whether users are erased, which is equivalent # We don't want to check whether users are erased, which is equivalent
@ -631,15 +655,15 @@ async def filter_events_for_server(
# otherwise a room could be fully joined after we retrieve those, which would then bypass # otherwise a room could be fully joined after we retrieve those, which would then bypass
# this check but would base the filtering on an outdated view of the membership events. # this check but would base the filtering on an outdated view of the membership events.
partial_state_invisible_events = set() partial_state_invisible_event_ids: Set[str] = set()
if not check_history_visibility_only: if filter_out_remote_partial_state_events:
for e in events: for e in events:
sender_domain = get_domain_from_id(e.sender) sender_domain = get_domain_from_id(e.sender)
if ( if (
sender_domain != local_server_name sender_domain != local_server_name
and await storage.main.is_partial_state_room(e.room_id) and await storage.main.is_partial_state_room(e.room_id)
): ):
partial_state_invisible_events.add(e) partial_state_invisible_event_ids.add(e.event_id)
# Let's check to see if all the events have a history visibility # Let's check to see if all the events have a history visibility
# of "shared" or "world_readable". If that's the case then we don't # of "shared" or "world_readable". If that's the case then we don't
@ -658,17 +682,20 @@ async def filter_events_for_server(
target_server_name, target_server_name,
) )
to_return = [] def include_event_in_output(e: EventBase) -> bool:
for e in events:
erased = is_sender_erased(e, erased_senders) erased = is_sender_erased(e, erased_senders)
visible = check_event_is_visible( visible = check_event_is_visible(
event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {}) event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
) )
if e in partial_state_invisible_events: if e.event_id in partial_state_invisible_event_ids:
visible = False visible = False
if visible and not erased: return visible and not erased
to_return = []
for e in events:
if include_event_in_output(e):
to_return.append(e) to_return.append(e)
elif redact: elif redact:
to_return.append(prune_event(e)) to_return.append(prune_event(e))

View file

@ -63,7 +63,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
filtered = self.get_success( filtered = self.get_success(
filter_events_for_server( filter_events_for_server(
self._storage_controllers, "test_server", "hs", events_to_filter self._storage_controllers,
"test_server",
"hs",
events_to_filter,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
) )
) )
@ -85,7 +91,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.assertEqual( self.assertEqual(
self.get_success( self.get_success(
filter_events_for_server( filter_events_for_server(
self._storage_controllers, "remote_hs", "hs", [outlier] self._storage_controllers,
"remote_hs",
"hs",
[outlier],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
) )
), ),
[outlier], [outlier],
@ -96,7 +108,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
filtered = self.get_success( filtered = self.get_success(
filter_events_for_server( filter_events_for_server(
self._storage_controllers, "remote_hs", "local_hs", [outlier, evt] self._storage_controllers,
"remote_hs",
"local_hs",
[outlier, evt],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
) )
) )
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
@ -108,7 +126,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# be redacted) # be redacted)
filtered = self.get_success( filtered = self.get_success(
filter_events_for_server( filter_events_for_server(
self._storage_controllers, "other_server", "local_hs", [outlier, evt] self._storage_controllers,
"other_server",
"local_hs",
[outlier, evt],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
) )
) )
self.assertEqual(filtered[0], outlier) self.assertEqual(filtered[0], outlier)
@ -143,7 +167,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens. # ... and the filtering happens.
filtered = self.get_success( filtered = self.get_success(
filter_events_for_server( filter_events_for_server(
self._storage_controllers, "test_server", "local_hs", events_to_filter self._storage_controllers,
"test_server",
"local_hs",
events_to_filter,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
) )
) )