From 2bf39231ede3a9bcad65ad3f1321e788acfdcd15 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 26 Jun 2024 18:40:36 -0500 Subject: [PATCH] Add some tests for `get_current_state_delta_membership_changes_for_user(...)` --- synapse/storage/databases/main/stream.py | 14 +- tests/storage/test_stream.py | 515 +++++++++++++++++++++++ 2 files changed, 523 insertions(+), 6 deletions(-) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index f6be97698e..e222f36bab 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -63,7 +63,7 @@ from typing_extensions import Literal from twisted.internet import defer -from synapse.api.constants import Direction +from synapse.api.constants import Direction, EventTypes from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -807,7 +807,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): min_from_id = from_key.stream max_to_id = to_key.get_max_stream_pos() - args: List[Any] = [user_id, min_from_id, max_to_id] + args: List[Any] = [min_from_id, max_to_id, user_id, EventTypes.Member] # TODO: It would be good to assert that the `from_token`/`to_token` is >= # the first row in `current_state_delta_stream` for the rooms we're @@ -824,16 +824,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): e.event_id, s.prev_event_id, s.room_id, - s.instance_name, - s.stream_id, + e.instance_name, + e.stream_ordering, e.topological_ordering, m.membership, e.sender FROM current_state_delta_stream AS s INNER JOIN events AS e ON e.stream_ordering = s.stream_id INNER JOIN room_memberships AS m ON m.event_stream_ordering = s.stream_id - WHERE m.user_id = ? - AND s.stream_id > ? AND s.stream_id <= ? + WHERE s.stream_id > ? AND s.stream_id <= ? + AND m.user_id = ? + AND s.state_key = m.user_id + AND s.type = ? 
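+                -- i.e. only match this user's own m.room.member state deltas
+                -- within the given token range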
ORDER BY s.stream_id ASC """ diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index fe1e873e15..64f123987a 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -28,9 +28,12 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.filtering import Filter +from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer +from synapse.storage.databases.main.stream import CurrentStateDeltaMembership from synapse.types import JsonDict, PersistedEventPosition, RoomStreamToken from synapse.util import Clock @@ -543,3 +546,515 @@ class GetLastEventInRoomBeforeStreamOrderingTestCase(HomeserverTestCase): } ), ) + + +class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase): + """ + Test `get_current_state_delta_membership_changes_for_user(...)` + """ + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.event_sources = hs.get_event_sources() + self.state_handler = self.hs.get_state_handler() + persistence = hs.get_storage_controllers().persistence + assert persistence is not None + self.persistence = persistence + + def test_returns_membership_events(self) -> None: + """ + A basic test that a membership event in the token range is returned for the user. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + before_room1_token = self.event_sources.get_current_token() + + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) + join_pos = self.get_success( + self.store.get_position_for_event(join_response["event_id"]) + ) + + after_room1_token = self.event_sources.get_current_token() + + membership_changes = self.get_success( + self.store.get_current_state_delta_membership_changes_for_user( + user1_id, + from_key=before_room1_token.room_key, + to_key=after_room1_token.room_key, + ) + ) + + # Let the whole diff show on failure + self.maxDiff = None + self.assertEqual( + membership_changes, + [ + CurrentStateDeltaMembership( + event_id=join_response["event_id"], + event_pos=join_pos, + prev_event_id=None, + room_id=room_id1, + membership="join", + sender=user1_id, + ) + ], + ) + + def test_server_left_after_us_room(self) -> None: + """ + Test that when probing over part of the DAG where the server left the room *after + us*, we still see the join and leave changes. + + This is to make sure we play nicely with this behavior: When the server leaves a + room, it will insert new rows with `event_id = null` into the + `current_state_delta_stream` table for all current state. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + before_room1_token = self.event_sources.get_current_token() + + room_id1 = self.helper.create_room_as( + user2_id, + tok=user2_tok, + extra_content={ + "power_level_content_override": { + "users": { + user2_id: 100, + # Allow user1 to send state in the room + user1_id: 100, + } + } + }, + ) + join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) + join_pos1 = self.get_success( + self.store.get_position_for_event(join_response1["event_id"]) + ) + # Make sure random other non-member state that happens to have a state_key + # matching the user ID doesn't mess with things. + self.helper.send_state( + room_id1, + event_type="foobarbazdummy", + state_key=user1_id, + body={"foo": "bar"}, + tok=user1_tok, + ) + # User1 should leave the room first + leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) + leave_pos1 = self.get_success( + self.store.get_position_for_event(leave_response1["event_id"]) + ) + + # User2 should also leave the room (everyone has left the room which means the + # server is no longer in the room). + self.helper.leave(room_id1, user2_id, tok=user2_tok) + + after_room1_token = self.event_sources.get_current_token() + + membership_changes = self.get_success( + self.store.get_current_state_delta_membership_changes_for_user( + user1_id, + from_key=before_room1_token.room_key, + to_key=after_room1_token.room_key, + ) + ) + + # Let the whole diff show on failure + self.maxDiff = None + self.assertEqual( + membership_changes, + [ + CurrentStateDeltaMembership( + event_id=join_response1["event_id"], + event_pos=join_pos1, + prev_event_id=None, + room_id=room_id1, + membership="join", + sender=user1_id, + ), + CurrentStateDeltaMembership( + event_id=leave_response1["event_id"], + event_pos=leave_pos1, + prev_event_id=join_response1["event_id"], + room_id=room_id1, + membership="leave", + sender=user1_id, + ), + ], + ) + + def test_server_left_room(self) -> None: + """ + Test that when probing over part of the DAG where we leave the room causing the + server to leave the room (because we were the last local user in the room), we + still see the join and leave changes. + + This is to make sure we play nicely with this behavior: When the server leaves a + room, it will insert new rows with `event_id = null` into the + `current_state_delta_stream` table for all current state. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + before_room1_token = self.event_sources.get_current_token() + + room_id1 = self.helper.create_room_as( + user2_id, + tok=user2_tok, + extra_content={ + "power_level_content_override": { + "users": { + user2_id: 100, + # Allow user1 to send state in the room + user1_id: 100, + } + } + }, + ) + join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) + join_pos1 = self.get_success( + self.store.get_position_for_event(join_response1["event_id"]) + ) + # Make sure random other non-member state that happens to have a state_key + # matching the user ID doesn't mess with things. + self.helper.send_state( + room_id1, + event_type="foobarbazdummy", + state_key=user1_id, + body={"foo": "bar"}, + tok=user1_tok, + ) + + # User2 should leave the room first. 
+ self.helper.leave(room_id1, user2_id, tok=user2_tok) + + # User1 (the person we're testing with) should also leave the room (everyone has + # left the room which means the server is no longer in the room). + leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) + leave_pos1 = self.get_success( + self.store.get_position_for_event(leave_response1["event_id"]) + ) + + after_room1_token = self.event_sources.get_current_token() + + membership_changes = self.get_success( + self.store.get_current_state_delta_membership_changes_for_user( + user1_id, + from_key=before_room1_token.room_key, + to_key=after_room1_token.room_key, + ) + ) + + # Let the whole diff show on failure + self.maxDiff = None + self.assertEqual( + membership_changes, + [ + CurrentStateDeltaMembership( + event_id=join_response1["event_id"], + event_pos=join_pos1, + prev_event_id=None, + room_id=room_id1, + membership="join", + sender=user1_id, + ), + CurrentStateDeltaMembership( + event_id=leave_response1["event_id"], + event_pos=leave_pos1, + prev_event_id=join_response1["event_id"], + room_id=room_id1, + membership="leave", + sender=user1_id, + ), + ], + ) + + def test_membership_persisted_in_same_batch(self) -> None: + """ + Test batch of membership events being processed at once. This will result in all + of the memberships being stored in the `current_state_delta_stream` table with + the same `stream_ordering` even though the individual events have different + `stream_ordering`s. + """ + user1_id = self.register_user("user1", "pass") + _user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + user3_id = self.register_user("user3", "pass") + _user3_tok = self.login(user3_id, "pass") + user4_id = self.register_user("user4", "pass") + _user4_tok = self.login(user4_id, "pass") + + before_room1_token = self.event_sources.get_current_token() + + # User2 is just the designated person to create the room (we do this across the + # tests to be consistent) + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + + # Persist the user1, user3, and user4 join events in the same batch so they all + # end up in the `current_state_delta_stream` table with the same + # stream_ordering. 
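+        #
+        # We create the events by hand and persist them with a single
+        # `persist_events(...)` call (rather than going through the normal event
+        # creation/send flow) so that all three joins are persisted in the same batch.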
+        join_event1 = make_event_from_dict(
+            {
+                "sender": user1_id,
+                "type": EventTypes.Member,
+                "state_key": user1_id,
+                "content": {"membership": "join"},
+                "room_id": room_id1,
+                "depth": 0,
+                "origin_server_ts": 0,
+                "prev_events": [],
+                "auth_events": [],
+            },
+            room_version=RoomVersions.V10,
+        )
+        join_event_context1 = self.get_success(
+            self.state_handler.compute_event_context(join_event1)
+        )
+        join_event3 = make_event_from_dict(
+            {
+                "sender": user3_id,
+                "type": EventTypes.Member,
+                "state_key": user3_id,
+                "content": {"membership": "join"},
+                "room_id": room_id1,
+                "depth": 1,
+                "origin_server_ts": 1,
+                "prev_events": [],
+                "auth_events": [],
+            },
+            room_version=RoomVersions.V10,
+        )
+        join_event_context3 = self.get_success(
+            self.state_handler.compute_event_context(join_event3)
+        )
+        join_event4 = make_event_from_dict(
+            {
+                "sender": user4_id,
+                "type": EventTypes.Member,
+                "state_key": user4_id,
+                "content": {"membership": "join"},
+                "room_id": room_id1,
+                "depth": 2,
+                "origin_server_ts": 2,
+                "prev_events": [],
+                "auth_events": [],
+            },
+            room_version=RoomVersions.V10,
+        )
+        join_event_context4 = self.get_success(
+            self.state_handler.compute_event_context(join_event4)
+        )
+        self.get_success(
+            self.persistence.persist_events(
+                [
+                    (join_event1, join_event_context1),
+                    (join_event3, join_event_context3),
+                    (join_event4, join_event_context4),
+                ]
+            )
+        )
+
+        after_room1_token = self.event_sources.get_current_token()
+
+        # Let's get membership changes from user3's perspective because it was in the
+        # middle of the batch. This way, if rows in `current_state_delta_stream` are
+        # stored with the first or last event's `stream_ordering`, we will still catch
+        # bugs.
+        membership_changes = self.get_success(
+            self.store.get_current_state_delta_membership_changes_for_user(
+                user3_id,
+                from_key=before_room1_token.room_key,
+                to_key=after_room1_token.room_key,
+            )
+        )
+
+        join_pos3 = self.get_success(
+            self.store.get_position_for_event(join_event3.event_id)
+        )
+
+        # Let the whole diff show on failure
+        self.maxDiff = None
+        self.assertEqual(
+            membership_changes,
+            [
+                CurrentStateDeltaMembership(
+                    event_id=join_event3.event_id,
+                    event_pos=join_pos3,
+                    prev_event_id=None,
+                    room_id=room_id1,
+                    membership="join",
+                    sender=user3_id,
+                ),
+            ],
+        )
+
+    # TODO: Test remote join where the first rows will just be the state when you joined
+
+    # TODO: Test state reset where the user gets removed from the room (when there is no
+    # corresponding leave event)
+
+    def test_excluded_room_ids(self) -> None:
+        """
+        Test that the `excluded_room_ids` option excludes changes from the specified rooms.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        before_room1_token = self.event_sources.get_current_token()
+
+        room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+        join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+        join_pos1 = self.get_success(
+            self.store.get_position_for_event(join_response1["event_id"])
+        )
+
+        room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+        join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
+        join_pos2 = self.get_success(
+            self.store.get_position_for_event(join_response2["event_id"])
+        )
+
+        after_room1_token = self.event_sources.get_current_token()
+
+        # First test that both rooms are returned without the `excluded_room_ids` option
+        membership_changes = self.get_success(
+            self.store.get_current_state_delta_membership_changes_for_user(
+                user1_id,
+                from_key=before_room1_token.room_key,
+                to_key=after_room1_token.room_key,
+            )
+        )
+
+        # Let the whole diff show on failure
+        self.maxDiff = None
+        self.assertEqual(
+            membership_changes,
+            [
+                CurrentStateDeltaMembership(
+                    event_id=join_response1["event_id"],
+                    event_pos=join_pos1,
+                    prev_event_id=None,
+                    room_id=room_id1,
+                    membership="join",
+                    sender=user1_id,
+                ),
+                CurrentStateDeltaMembership(
+                    event_id=join_response2["event_id"],
+                    event_pos=join_pos2,
+                    prev_event_id=None,
+                    room_id=room_id2,
+                    membership="join",
+                    sender=user1_id,
+                ),
+            ],
+        )
+
+        # Then test that `excluded_room_ids` excludes room2 as expected
+        membership_changes = self.get_success(
+            self.store.get_current_state_delta_membership_changes_for_user(
+                user1_id,
+                from_key=before_room1_token.room_key,
+                to_key=after_room1_token.room_key,
+                excluded_room_ids=[room_id2],
+            )
+        )
+
+        # Let the whole diff show on failure
+        self.maxDiff = None
+        self.assertEqual(
+            membership_changes,
+            [
+                CurrentStateDeltaMembership(
+                    event_id=join_response1["event_id"],
+                    event_pos=join_pos1,
+                    prev_event_id=None,
+                    room_id=room_id1,
+                    membership="join",
+                    sender=user1_id,
+                )
+            ],
+        )
+
+
+# class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase(BaseMultiWorkerStreamTestCase):
+#     """
+#     TODO
+#     """
+
+#     servlets = [
+#         admin.register_servlets_for_client_rest_resource,
+#         room.register_servlets,
+#         login.register_servlets,
+#     ]
+
+#     def default_config(self) -> dict:
+#         conf = super().default_config()
+#         conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
+#         return conf
+
+#     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+#         self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
+#         self.store = self.hs.get_datastores().main
+#         self.event_sources = hs.get_event_sources()
+
+
+#     def test_sharded_event_persisters(self) -> None:
+#         """
+#         TODO
+#         """
+#         user1_id = self.register_user("user1", "pass")
+#         user1_tok = self.login(user1_id, "pass")
+#         user2_id = self.register_user("user2", "pass")
+#         user2_tok = self.login(user2_id, "pass")
+
+#         remote_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+#         channel = make_request(
+#             self.reactor,
+#             self._hs_to_site[hs],
+#             "GET",
+#             f"/_matrix/media/r0/download/{target}/{media_id}",
+#             shorthand=False,
+#             access_token=self.access_token,
+#             await_result=False,
+#         )
+
+#         remote_hs
+
+#         worker_store2 = worker_hs2.get_datastores().main
+#         assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
+#         actx = worker_store2._stream_id_gen.get_next()
+
+#         self.assertEqual(
+#             room_id_results.keys(),
+#             {
+#                 room_id1,
+#                 # room_id2 shouldn't show up because we left before the from/to range
+#                 # and the join event during the range happened while worker2 was stuck.
+#                 # This means that from the perspective of the master, where the
+#                 # `stuck_activity_token` is generated, the stream position for worker2
+#                 # wasn't advanced to the join yet. Looking at the `instance_map`, the
+#                 # join technically comes after `stuck_activity_token`.
+#                 #
+#                 # room_id2,
+#                 room_id3,
+#             },
+#         )