Fix filtering

This commit is contained in:
Eric Eastwood 2024-06-13 14:53:23 -05:00
parent 8244b25be8
commit 35808b38db
2 changed files with 35 additions and 16 deletions

View file

@ -18,7 +18,7 @@
#
#
import logging
from typing import TYPE_CHECKING, AbstractSet, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from immutabledict import immutabledict
@ -210,14 +210,16 @@ class SlidingSyncHandler:
for list_key, list_config in sync_config.lists.items():
# Apply filters
filtered_room_ids = room_id_set
filtered_sync_room_map = sync_room_map
if list_config.filters is not None:
filtered_room_ids = await self.filter_rooms(
sync_config.user, room_id_set, list_config.filters, to_token
filtered_sync_room_map = await self.filter_rooms(
sync_config.user, sync_room_map, list_config.filters, to_token
)
# TODO: Apply sorts
sorted_room_info = await self.sort_rooms(filtered_room_map, to_token)
sorted_room_info = await self.sort_rooms(
filtered_sync_room_map, to_token
)
ops: List[SlidingSyncResult.SlidingWindowList.Operation] = []
if list_config.ranges:
@ -509,18 +511,23 @@ class SlidingSyncHandler:
async def filter_rooms(
self,
user: UserID,
room_id_set: AbstractSet[str],
sync_room_map: Dict[str, RoomsForUser],
filters: SlidingSyncConfig.SlidingSyncList.Filters,
to_token: StreamToken,
) -> AbstractSet[str]:
) -> Dict[str, RoomsForUser]:
"""
Filter rooms based on the sync request.
Args:
user: User to filter rooms for
room_id_set: Set of room IDs to filter down
sync_room_map: Dictionary of room IDs to sort along with membership
information in the room at the time of `to_token`.
filters: Filters to apply
to_token: We filter based on the state of the room at this token
Returns:
A filtered dictionary of room IDs along with membership information in the
room at the time of `to_token`.
"""
user_id = user.to_string()
@ -529,7 +536,7 @@ class SlidingSyncHandler:
# TODO: Exclude partially stated rooms unless the `required_state` has
# `["m.room.member", "$LAZY"]`
filtered_room_id_set = set(room_id_set)
filtered_room_id_set = set(sync_room_map.keys())
# Filter for Direct-Message (DM) rooms
if filters.is_dm is not None:
@ -585,7 +592,8 @@ class SlidingSyncHandler:
if filters.not_tags:
raise NotImplementedError()
return filtered_room_id_set
# Assemble a new sync room map but only with the `filtered_room_id_set`
return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set}
async def sort_rooms(
self,

View file

@ -1216,11 +1216,20 @@ class FilterRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
sync_room_map = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=None,
to_token=after_rooms_token,
)
)
# Try with `is_dm=True`
truthy_filtered_room_ids = self.get_success(
truthy_filtered_room_map = self.get_success(
self.sliding_sync_handler.filter_rooms(
UserID.from_string(user1_id),
{room_id, dm_room_id},
sync_room_map,
SlidingSyncConfig.SlidingSyncList.Filters(
is_dm=True,
),
@ -1228,13 +1237,13 @@ class FilterRoomsTestCase(HomeserverTestCase):
)
)
self.assertEqual(truthy_filtered_room_ids, {dm_room_id})
self.assertEqual(truthy_filtered_room_map.keys(), {dm_room_id})
# Try with `is_dm=False`
falsy_filtered_room_ids = self.get_success(
falsy_filtered_room_map = self.get_success(
self.sliding_sync_handler.filter_rooms(
UserID.from_string(user1_id),
{room_id, dm_room_id},
sync_room_map,
SlidingSyncConfig.SlidingSyncList.Filters(
is_dm=False,
),
@ -1242,7 +1251,7 @@ class FilterRoomsTestCase(HomeserverTestCase):
)
)
self.assertEqual(falsy_filtered_room_ids, {room_id})
self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
class SortRoomsTestCase(HomeserverTestCase):
@ -1287,6 +1296,7 @@ class SortRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
sync_room_map = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
UserID.from_string(user1_id),
@ -1295,6 +1305,7 @@ class SortRoomsTestCase(HomeserverTestCase):
)
)
# Sort the rooms (what we're testing)
sorted_room_info = self.get_success(
self.sliding_sync_handler.sort_rooms(
sync_room_map=sync_room_map,