All get_current_state_delta_membership_changes_for_user(...) tests passing

This commit is contained in:
Eric Eastwood 2024-06-27 00:48:17 -05:00
parent 7eb1806ee3
commit 935b98c474
2 changed files with 75 additions and 44 deletions

View file

@ -63,7 +63,7 @@ from typing_extensions import Literal
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Direction, EventTypes from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -125,12 +125,12 @@ class CurrentStateDeltaMembership:
sender: The person who sent the membership event sender: The person who sent the membership event
""" """
event_id: str event_id: Optional[str]
event_pos: PersistedEventPosition event_pos: PersistedEventPosition
prev_event_id: Optional[str] prev_event_id: Optional[str]
room_id: str room_id: str
membership: str membership: str
sender: str sender: Optional[str]
def generate_pagination_where_clause( def generate_pagination_where_clause(
@ -819,22 +819,32 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# longer in the room or a state reset happened and it was unset. # longer in the room or a state reset happened and it was unset.
# `stream_ordering` is unique across the Synapse instance so this should # `stream_ordering` is unique across the Synapse instance so this should
# work fine. # work fine.
#
# We `COALESCE` the `instance_name` and `stream_ordering` because we prefer
# the source of truth from the events table. This gives slightly more
# accurate results when available since `current_state_delta_stream` only
# tracks that the current state is at this stream position (not what stream
# position the state event was added) and batches events at the same
# `stream_id` in certain cases.
#
# TODO: We need to add indexes for `current_state_delta_stream.event_id` and
# `current_state_delta_stream.state_key`/`current_state_delta_stream.type`
# for this to be efficient.
sql = """ sql = """
SELECT SELECT
e.event_id, e.event_id,
s.prev_event_id, s.prev_event_id,
s.room_id, s.room_id,
e.instance_name, COALESCE(e.instance_name, s.instance_name),
e.stream_ordering, COALESCE(e.stream_ordering, s.stream_id),
e.topological_ordering, e.topological_ordering,
m.membership, m.membership,
e.sender e.sender
FROM current_state_delta_stream AS s FROM current_state_delta_stream AS s
INNER JOIN events AS e ON e.stream_ordering = s.stream_id LEFT JOIN events AS e ON e.event_id = s.event_id
INNER JOIN room_memberships AS m ON m.event_stream_ordering = s.stream_id LEFT JOIN room_memberships AS m ON m.event_id = s.event_id
WHERE s.stream_id > ? AND s.stream_id <= ? WHERE s.stream_id > ? AND s.stream_id <= ?
AND m.user_id = ? AND s.state_key = ?
AND s.state_key = m.user_id
AND s.type = ? AND s.type = ?
ORDER BY s.stream_id ASC ORDER BY s.stream_id ASC
""" """
@ -842,6 +852,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
membership_changes: List[CurrentStateDeltaMembership] = [] membership_changes: List[CurrentStateDeltaMembership] = []
membership_change_map: Dict[str, CurrentStateDeltaMembership] = {}
for ( for (
event_id, event_id,
prev_event_id, prev_event_id,
@ -852,36 +863,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
membership, membership,
sender, sender,
) in txn: ) in txn:
assert event_id is not None
# `prev_event_id` can be `None`
assert room_id is not None assert room_id is not None
assert instance_name is not None assert instance_name is not None
assert stream_ordering is not None assert stream_ordering is not None
assert topological_ordering is not None
assert membership is not None
assert sender is not None
if _filter_results( if _filter_results(
from_key, from_key,
to_key, to_key,
instance_name, instance_name,
# TODO: This isn't always filled now
topological_ordering, topological_ordering,
stream_ordering, stream_ordering,
): ):
membership_changes.append( # When the server leaves a room, it will insert new rows with
CurrentStateDeltaMembership( # `event_id = null` for all current state. This means we might
event_id=event_id, # already have a row for the leave event and then another for the
event_pos=PersistedEventPosition( # same leave where the `event_id=null` but the `prev_event_id` is
instance_name=instance_name, # pointing back at the earlier leave event. Since we're assuming the
stream=stream_ordering, # `event_id = null` row is a `leave` and we don't want duplicate
), # membership changes in our results, let's get rid of those
prev_event_id=prev_event_id, # (deduplicate) (see `test_server_left_after_us_room`).
room_id=room_id, if event_id is None:
membership=membership, already_tracked_membership_change = membership_change_map.get(
sender=sender, prev_event_id
) )
if (
already_tracked_membership_change is not None
and already_tracked_membership_change.membership
== Membership.LEAVE
):
continue
membership_change = CurrentStateDeltaMembership(
event_id=event_id,
event_pos=PersistedEventPosition(
instance_name=instance_name,
stream=stream_ordering,
),
prev_event_id=prev_event_id,
room_id=room_id,
membership=(
membership if membership is not None else Membership.LEAVE
),
sender=sender,
) )
membership_changes.append(membership_change)
if event_id:
membership_change_map[event_id] = membership_change
return membership_changes return membership_changes
membership_changes = await self.db_pool.runInteraction( membership_changes = await self.db_pool.runInteraction(

View file

@ -46,7 +46,7 @@ from synapse.types import (
from synapse.util import Clock from synapse.util import Clock
from tests.test_utils.event_injection import create_event from tests.test_utils.event_injection import create_event
from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase, skip_unless from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -829,17 +829,16 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
sender=user1_id, sender=user1_id,
), ),
CurrentStateDeltaMembership( CurrentStateDeltaMembership(
event_id=leave_response1["event_id"], event_id=None, # leave_response1["event_id"],
event_pos=leave_pos1, event_pos=leave_pos1,
prev_event_id=join_response1["event_id"], prev_event_id=join_response1["event_id"],
room_id=room_id1, room_id=room_id1,
membership="leave", membership="leave",
sender=user1_id, sender=None, # user1_id,
), ),
], ],
) )
@skip_unless(False, "We don't support this yet")
def test_membership_persisted_in_same_batch(self) -> None: def test_membership_persisted_in_same_batch(self) -> None:
""" """
Test batch of membership events being processed at once. This will result in all Test batch of membership events being processed at once. This will result in all
@ -954,7 +953,6 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
], ],
) )
@skip_unless(False, "We don't support this yet")
def test_state_reset(self) -> None: def test_state_reset(self) -> None:
""" """
Test a state reset scenario where the user gets removed from the room (when Test a state reset scenario where the user gets removed from the room (when
@ -970,7 +968,7 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
before_reset_token = self.event_sources.get_current_token() before_reset_token = self.event_sources.get_current_token()
# Send another state event which we will cause the reset at # Send another state event to make a position for the state reset to happen at
dummy_state_response = self.helper.send_state( dummy_state_response = self.helper.send_state(
room_id1, room_id1,
event_type="foobarbaz", event_type="foobarbaz",
@ -1011,6 +1009,12 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
) )
) )
# Manually bust the cache since we we're just manually messing with the database
# and not causing an actual state reset.
self.store._membership_stream_cache.entity_has_changed(
user1_id, dummy_state_pos.stream
)
after_reset_token = self.event_sources.get_current_token() after_reset_token = self.event_sources.get_current_token()
membership_changes = self.get_success( membership_changes = self.get_success(
@ -1025,19 +1029,16 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
self.maxDiff = None self.maxDiff = None
self.assertEqual( self.assertEqual(
membership_changes, membership_changes,
# TODO: Uncomment the expected membership. We just have a `False` value [
# here so the test expectation fails and you look here. CurrentStateDeltaMembership(
False, event_id=None,
# [ event_pos=dummy_state_pos,
# CurrentStateDeltaMembership( prev_event_id=None,
# event_id=TODO, room_id=room_id1,
# event_pos=TODO, membership="leave",
# prev_event_id=None, sender=None, # user1_id,
# room_id=room_id1, ),
# membership="leave", ],
# sender=user1_id,
# ),
# ],
) )
def test_excluded_room_ids(self) -> None: def test_excluded_room_ids(self) -> None: