From 935b98c474f030f92bdd28cd69fcf20f3d6045fd Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 27 Jun 2024 00:48:17 -0500 Subject: [PATCH] All `get_current_state_delta_membership_changes_for_user(...)` tests passing --- synapse/storage/databases/main/stream.py | 80 ++++++++++++++++-------- tests/storage/test_stream.py | 39 ++++++------ 2 files changed, 75 insertions(+), 44 deletions(-) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index e222f36bab..9ae1fe6c15 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -63,7 +63,7 @@ from typing_extensions import Literal from twisted.internet import defer -from synapse.api.constants import Direction, EventTypes +from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -125,12 +125,12 @@ class CurrentStateDeltaMembership: sender: The person who sent the membership event """ - event_id: str + event_id: Optional[str] event_pos: PersistedEventPosition prev_event_id: Optional[str] room_id: str membership: str - sender: str + sender: Optional[str] def generate_pagination_where_clause( @@ -819,22 +819,32 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # longer in the room or a state reset happened and it was unset. # `stream_ordering` is unique across the Synapse instance so this should # work fine. + # + # We `COALESCE` the `instance_name` and `stream_ordering` because we prefer + # the source of truth from the events table. This gives slightly more + # accurate results when available since `current_state_delta_stream` only + # tracks that the current state is at this stream position (not what stream + # position the state event was added) and batches events at the same + # `stream_id` in certain cases. + # + # TODO: We need to add indexes for `current_state_delta_stream.event_id` and + # `current_state_delta_stream.state_key`/`current_state_delta_stream.type` + # for this to be efficient. sql = """ SELECT e.event_id, s.prev_event_id, s.room_id, - e.instance_name, - e.stream_ordering, + COALESCE(e.instance_name, s.instance_name), + COALESCE(e.stream_ordering, s.stream_id), e.topological_ordering, m.membership, e.sender FROM current_state_delta_stream AS s - INNER JOIN events AS e ON e.stream_ordering = s.stream_id - INNER JOIN room_memberships AS m ON m.event_stream_ordering = s.stream_id + LEFT JOIN events AS e ON e.event_id = s.event_id + LEFT JOIN room_memberships AS m ON m.event_id = s.event_id WHERE s.stream_id > ? AND s.stream_id <= ? - AND m.user_id = ? - AND s.state_key = m.user_id + AND s.state_key = ? AND s.type = ? ORDER BY s.stream_id ASC """ @@ -842,6 +852,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn.execute(sql, args) membership_changes: List[CurrentStateDeltaMembership] = [] + membership_change_map: Dict[str, CurrentStateDeltaMembership] = {} for ( event_id, prev_event_id, @@ -852,36 +863,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): membership, sender, ) in txn: - assert event_id is not None - # `prev_event_id` can be `None` assert room_id is not None assert instance_name is not None assert stream_ordering is not None - assert topological_ordering is not None - assert membership is not None - assert sender is not None if _filter_results( from_key, to_key, instance_name, + # TODO: This isn't always filled now topological_ordering, stream_ordering, ): - membership_changes.append( - CurrentStateDeltaMembership( - event_id=event_id, - event_pos=PersistedEventPosition( - instance_name=instance_name, - stream=stream_ordering, - ), - prev_event_id=prev_event_id, - room_id=room_id, - membership=membership, - sender=sender, + # When the server leaves a room, it will insert new rows with + # `event_id = null` for all current state. This means we might + # already have a row for the leave event and then another for the + # same leave where the `event_id=null` but the `prev_event_id` is + # pointing back at the earlier leave event. Since we're assuming the + # `event_id = null` row is a `leave` and we don't want duplicate + # membership changes in our results, let's get rid of those + # (deduplicate) (see `test_server_left_after_us_room`). + if event_id is None: + already_tracked_membership_change = membership_change_map.get( + prev_event_id ) + if ( + already_tracked_membership_change is not None + and already_tracked_membership_change.membership + == Membership.LEAVE + ): + continue + + membership_change = CurrentStateDeltaMembership( + event_id=event_id, + event_pos=PersistedEventPosition( + instance_name=instance_name, + stream=stream_ordering, + ), + prev_event_id=prev_event_id, + room_id=room_id, + membership=( + membership if membership is not None else Membership.LEAVE + ), + sender=sender, ) + membership_changes.append(membership_change) + if event_id: + membership_change_map[event_id] = membership_change + return membership_changes membership_changes = await self.db_pool.runInteraction( diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 5b30d7106f..ffa763bff2 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -46,7 +46,7 @@ from synapse.types import ( from synapse.util import Clock from tests.test_utils.event_injection import create_event -from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase, skip_unless +from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase logger = logging.getLogger(__name__) @@ -829,17 +829,16 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase): sender=user1_id, ), CurrentStateDeltaMembership( - event_id=leave_response1["event_id"], + event_id=None, # leave_response1["event_id"], event_pos=leave_pos1, prev_event_id=join_response1["event_id"], room_id=room_id1, membership="leave", - sender=user1_id, + sender=None, # user1_id, ), ], ) - @skip_unless(False, "We don't support this yet") def test_membership_persisted_in_same_batch(self) -> None: """ Test batch of membership events being processed at once. This will result in all @@ -954,7 +953,6 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase): ], ) - @skip_unless(False, "We don't support this yet") def test_state_reset(self) -> None: """ Test a state reset scenario where the user gets removed from the room (when @@ -970,7 +968,7 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase): before_reset_token = self.event_sources.get_current_token() - # Send another state event which we will cause the reset at + # Send another state event to make a position for the state reset to happen at dummy_state_response = self.helper.send_state( room_id1, event_type="foobarbaz", @@ -1011,6 +1009,12 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase): ) ) + # Manually bust the cache since we we're just manually messing with the database + # and not causing an actual state reset. + self.store._membership_stream_cache.entity_has_changed( + user1_id, dummy_state_pos.stream + ) + after_reset_token = self.event_sources.get_current_token() membership_changes = self.get_success( @@ -1025,19 +1029,16 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase): self.maxDiff = None self.assertEqual( membership_changes, - # TODO: Uncomment the expected membership. We just have a `False` value - # here so the test expectation fails and you look here. - False, - # [ - # CurrentStateDeltaMembership( - # event_id=TODO, - # event_pos=TODO, - # prev_event_id=None, - # room_id=room_id1, - # membership="leave", - # sender=user1_id, - # ), - # ], + [ + CurrentStateDeltaMembership( + event_id=None, + event_pos=dummy_state_pos, + prev_event_id=None, + room_id=room_id1, + membership="leave", + sender=None, # user1_id, + ), + ], ) def test_excluded_room_ids(self) -> None: