From b83e82255630f2ba7ea959b68e5a82b4d938f000 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 13 Oct 2021 10:38:22 +0100 Subject: [PATCH] Stop user directory from failing if it encounters users not in the `users` table. (#11053) The following scenarios would halt the user directory updater: - user joins room - user leaves room - user present in room which switches from private to public, or vice versa. for two classes of users: - appservice senders - users missing from the user table. If this happened, the user directory would be stuck, unable to make forward progress. Exclude both cases from the user directory, so that we ignore them. Co-authored-by: Eric Eastwood Co-authored-by: reivilibre Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> Co-authored-by: Brendan Abolivier --- changelog.d/10825.misc | 1 + changelog.d/10970.misc | 1 + changelog.d/10996.misc | 1 + changelog.d/11053.bugfix | 2 + synapse/logging/opentracing.py | 8 + synapse/replication/http/_base.py | 148 ++--- synapse/storage/databases/main/client_ips.py | 13 +- .../storage/databases/main/user_directory.py | 24 +- synapse/storage/state.py | 172 +++++- tests/handlers/test_user_directory.py | 65 ++- tests/storage/test_client_ips.py | 43 ++ tests/storage/test_state.py | 513 +++++++++++++++++- tests/storage/test_user_directory.py | 17 +- 13 files changed, 918 insertions(+), 90 deletions(-) create mode 100644 changelog.d/10825.misc create mode 100644 changelog.d/10970.misc create mode 100644 changelog.d/10996.misc create mode 100644 changelog.d/11053.bugfix diff --git a/changelog.d/10825.misc b/changelog.d/10825.misc new file mode 100644 index 0000000000..f9786164d7 --- /dev/null +++ b/changelog.d/10825.misc @@ -0,0 +1 @@ +Add an 'approximate difference' method to `StateFilter`. diff --git a/changelog.d/10970.misc b/changelog.d/10970.misc new file mode 100644 index 0000000000..bb75ea79a6 --- /dev/null +++ b/changelog.d/10970.misc @@ -0,0 +1 @@ +Fix inconsistent behavior of `get_last_client_by_ip` when reporting data that has not been stored in the database yet. diff --git a/changelog.d/10996.misc b/changelog.d/10996.misc new file mode 100644 index 0000000000..c830d7ec2c --- /dev/null +++ b/changelog.d/10996.misc @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.21.0 that causes opentracing and Prometheus metrics for replication requests to be measured incorrectly. diff --git a/changelog.d/11053.bugfix b/changelog.d/11053.bugfix new file mode 100644 index 0000000000..a59cfac931 --- /dev/null +++ b/changelog.d/11053.bugfix @@ -0,0 +1,2 @@ +Fix a bug introduced in Synapse 1.45.0rc1 where the user directory would stop updating if it processed an event from a +user not in the `users` table. diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 5276c4bfcc..20d23a4260 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -807,6 +807,14 @@ def trace(func=None, opname=None): result.addCallbacks(call_back, err_back) else: + if inspect.isawaitable(result): + logger.error( + "@trace may not have wrapped %s correctly! " + "The function is not async but returned a %s.", + func.__qualname__, + type(result).__name__, + ) + scope.__exit__(None, None, None) return result diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index f1b78d09f9..e047ec74d8 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -182,85 +182,87 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): ) @trace(opname="outgoing_replication_request") - @outgoing_gauge.track_inprogress() async def send_request(*, instance_name="master", **kwargs): - if instance_name == local_instance_name: - raise Exception("Trying to send HTTP request to self") - if instance_name == "master": - host = master_host - port = master_port - elif instance_name in instance_map: - host = instance_map[instance_name].host - port = instance_map[instance_name].port - else: - raise Exception( - "Instance %r not in 'instance_map' config" % (instance_name,) + with outgoing_gauge.track_inprogress(): + if instance_name == local_instance_name: + raise Exception("Trying to send HTTP request to self") + if instance_name == "master": + host = master_host + port = master_port + elif instance_name in instance_map: + host = instance_map[instance_name].host + port = instance_map[instance_name].port + else: + raise Exception( + "Instance %r not in 'instance_map' config" % (instance_name,) + ) + + data = await cls._serialize_payload(**kwargs) + + url_args = [ + urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS + ] + + if cls.CACHE: + txn_id = random_string(10) + url_args.append(txn_id) + + if cls.METHOD == "POST": + request_func = client.post_json_get_json + elif cls.METHOD == "PUT": + request_func = client.put_json + elif cls.METHOD == "GET": + request_func = client.get_json + else: + # We have already asserted in the constructor that a + # compatible was picked, but lets be paranoid. + raise Exception( + "Unknown METHOD on %s replication endpoint" % (cls.NAME,) + ) + + uri = "http://%s:%s/_synapse/replication/%s/%s" % ( + host, + port, + cls.NAME, + "/".join(url_args), ) - data = await cls._serialize_payload(**kwargs) + try: + # We keep retrying the same request for timeouts. This is so that we + # have a good idea that the request has either succeeded or failed + # on the master, and so whether we should clean up or not. + while True: + headers: Dict[bytes, List[bytes]] = {} + # Add an authorization header, if configured. + if replication_secret: + headers[b"Authorization"] = [ + b"Bearer " + replication_secret + ] + opentracing.inject_header_dict(headers, check_destination=False) + try: + result = await request_func(uri, data, headers=headers) + break + except RequestTimedOutError: + if not cls.RETRY_ON_TIMEOUT: + raise - url_args = [ - urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS - ] + logger.warning("%s request timed out; retrying", cls.NAME) - if cls.CACHE: - txn_id = random_string(10) - url_args.append(txn_id) + # If we timed out we probably don't need to worry about backing + # off too much, but lets just wait a little anyway. + await clock.sleep(1) + except HttpResponseException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the main process that we should send to the client. (And + # importantly, not stack traces everywhere) + _outgoing_request_counter.labels(cls.NAME, e.code).inc() + raise e.to_synapse_error() + except Exception as e: + _outgoing_request_counter.labels(cls.NAME, "ERR").inc() + raise SynapseError(502, "Failed to talk to main process") from e - if cls.METHOD == "POST": - request_func = client.post_json_get_json - elif cls.METHOD == "PUT": - request_func = client.put_json - elif cls.METHOD == "GET": - request_func = client.get_json - else: - # We have already asserted in the constructor that a - # compatible was picked, but lets be paranoid. - raise Exception( - "Unknown METHOD on %s replication endpoint" % (cls.NAME,) - ) - - uri = "http://%s:%s/_synapse/replication/%s/%s" % ( - host, - port, - cls.NAME, - "/".join(url_args), - ) - - try: - # We keep retrying the same request for timeouts. This is so that we - # have a good idea that the request has either succeeded or failed on - # the master, and so whether we should clean up or not. - while True: - headers: Dict[bytes, List[bytes]] = {} - # Add an authorization header, if configured. - if replication_secret: - headers[b"Authorization"] = [b"Bearer " + replication_secret] - opentracing.inject_header_dict(headers, check_destination=False) - try: - result = await request_func(uri, data, headers=headers) - break - except RequestTimedOutError: - if not cls.RETRY_ON_TIMEOUT: - raise - - logger.warning("%s request timed out; retrying", cls.NAME) - - # If we timed out we probably don't need to worry about backing - # off too much, but lets just wait a little anyway. - await clock.sleep(1) - except HttpResponseException as e: - # We convert to SynapseError as we know that it was a SynapseError - # on the main process that we should send to the client. (And - # importantly, not stack traces everywhere) - _outgoing_request_counter.labels(cls.NAME, e.code).inc() - raise e.to_synapse_error() - except Exception as e: - _outgoing_request_counter.labels(cls.NAME, "ERR").inc() - raise SynapseError(502, "Failed to talk to main process") from e - - _outgoing_request_counter.labels(cls.NAME, 200).inc() - return result + _outgoing_request_counter.labels(cls.NAME, 200).inc() + return result return send_request diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c77acc7c84..6c1ef09049 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -538,15 +538,20 @@ class ClientIpStore(ClientIpWorkerStore): """ ret = await super().get_last_client_ip_by_device(user_id, device_id) - # Update what is retrieved from the database with data which is pending insertion. + # Update what is retrieved from the database with data which is pending + # insertion, as if it has already been stored in the database. for key in self._batch_row_update: - uid, access_token, ip = key + uid, _access_token, ip = key if uid == user_id: user_agent, did, last_seen = self._batch_row_update[key] + + if did is None: + # These updates don't make it to the `devices` table + continue + if not device_id or did == device_id: - ret[(user_id, device_id)] = { + ret[(user_id, did)] = { "user_id": user_id, - "access_token": access_token, "ip": ip, "user_agent": user_agent, "device_id": did, diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 5c713a732e..e98a45b6af 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -26,6 +26,8 @@ from typing import ( cast, ) +from synapse.api.errors import StoreError + if TYPE_CHECKING: from synapse.server import HomeServer @@ -383,7 +385,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): """Certain classes of local user are omitted from the user directory. Is this user one of them? """ - # App service users aren't usually contactable, so exclude them. + # We're opting to exclude the appservice sender (user defined by the + # `sender_localpart` in the appservice registration) even though + # technically it could be DM-able. In the future, this could potentially + # be configurable per-appservice whether the appservice sender can be + # contacted. + if self.get_app_service_by_user_id(user) is not None: + return False + + # We're opting to exclude appservice users (anyone matching the user + # namespace regex in the appservice registration) even though technically + # they could be DM-able. In the future, this could potentially + # be configurable per-appservice whether the appservice users can be + # contacted. if self.get_if_app_services_interested_in_user(user): # TODO we might want to make this configurable for each app service return False @@ -393,8 +407,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return False # Deactivated users aren't contactable, so should not appear in the user directory. - if await self.get_user_deactivated_status(user): + try: + if await self.get_user_deactivated_status(user): + return False + except StoreError: + # No such user in the users table. No need to do this when calling + # is_support_user---that returns False if the user is missing. return False + return True async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5e86befde4..b5ba1560d1 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -15,9 +15,11 @@ import logging from typing import ( TYPE_CHECKING, Awaitable, + Collection, Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -29,7 +31,7 @@ from frozendict import frozendict from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateKey, StateMap if TYPE_CHECKING: from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad @@ -134,6 +136,23 @@ class StateFilter: include_others=True, ) + @staticmethod + def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool): + """ + Returns a (frozen) StateFilter with the same contents as the parameters + specified here, which can be made of mutable types. + """ + types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} + for state_types, state_keys in types.items(): + if state_keys is not None: + types_with_frozen_values[state_types] = frozenset(state_keys) + else: + types_with_frozen_values[state_types] = None + + return StateFilter( + frozendict(types_with_frozen_values), include_others=include_others + ) + def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed (except for memberships). The returned filter is a superset of the @@ -356,6 +375,157 @@ class StateFilter: return member_filter, non_member_filter + def _decompose_into_four_parts( + self, + ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: + """ + Decomposes this state filter into 4 constituent parts, which can be + thought of as this: + all? - minus_wildcards + plus_wildcards + plus_state_keys + + where + * all represents ALL state + * minus_wildcards represents entire state types to remove + * plus_wildcards represents entire state types to add + * plus_state_keys represents individual state keys to add + + See `recompose_from_four_parts` for the other direction of this + correspondence. + """ + is_all = self.include_others + excluded_types: Set[str] = {t for t in self.types if is_all} + wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} + concrete_keys: Set[StateKey] = set(self.concrete_types()) + + return (is_all, excluded_types), (wildcard_types, concrete_keys) + + @staticmethod + def _recompose_from_four_parts( + all_part: bool, + minus_wildcards: Set[str], + plus_wildcards: Set[str], + plus_state_keys: Set[StateKey], + ) -> "StateFilter": + """ + Recomposes a state filter from 4 parts. + + See `decompose_into_four_parts` (the other direction of this + correspondence) for descriptions on each of the parts. + """ + + # {state type -> set of state keys OR None for wildcard} + # (The same structure as that of a StateFilter.) + new_types: Dict[str, Optional[Set[str]]] = {} + + # if we start with all, insert the excluded statetypes as empty sets + # to prevent them from being included + if all_part: + new_types.update({state_type: set() for state_type in minus_wildcards}) + + # insert the plus wildcards + new_types.update({state_type: None for state_type in plus_wildcards}) + + # insert the specific state keys + for state_type, state_key in plus_state_keys: + if state_type in new_types: + entry = new_types[state_type] + if entry is not None: + entry.add(state_key) + elif not all_part: + # don't insert if the entire type is already included by + # include_others as this would actually shrink the state allowed + # by this filter. + new_types[state_type] = {state_key} + + return StateFilter.freeze(new_types, include_others=all_part) + + def approx_difference(self, other: "StateFilter") -> "StateFilter": + """ + Returns a state filter which represents `self - other`. + + This is useful for determining what state remains to be pulled out of the + database if we want the state included by `self` but already have the state + included by `other`. + + The returned state filter + - MUST include all state events that are included by this filter (`self`) + unless they are included by `other`; + - MUST NOT include state events not included by this filter (`self`); and + - MAY be an over-approximation: the returned state filter + MAY additionally include some state events from `other`. + + This implementation attempts to return the narrowest such state filter. + In the case that `self` contains wildcards for state types where + `other` contains specific state keys, an approximation must be made: + the returned state filter keeps the wildcard, as state filters are not + able to express 'all state keys except some given examples'. + e.g. + StateFilter(m.room.member -> None (wildcard)) + minus + StateFilter(m.room.member -> {'@wombat:example.org'}) + is approximated as + StateFilter(m.room.member -> None (wildcard)) + """ + + # We first transform self and other into an alternative representation: + # - whether or not they include all events to begin with ('all') + # - if so, which event types are excluded? ('excludes') + # - which entire event types to include ('wildcards') + # - which concrete state keys to include ('concrete state keys') + (self_all, self_excludes), ( + self_wildcards, + self_concrete_keys, + ) = self._decompose_into_four_parts() + (other_all, other_excludes), ( + other_wildcards, + other_concrete_keys, + ) = other._decompose_into_four_parts() + + # Start with an estimate of the difference based on self + new_all = self_all + # Wildcards from the other can be added to the exclusion filter + new_excludes = self_excludes | other_wildcards + # We remove wildcards that appeared as wildcards in the other + new_wildcards = self_wildcards - other_wildcards + # We filter out the concrete state keys that appear in the other + # as wildcards or concrete state keys. + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in self_concrete_keys + if state_type not in other_wildcards + } - other_concrete_keys + + if other_all: + if self_all: + # If self starts with all, then we add as wildcards any + # types which appear in the other's exclusion filter (but + # aren't in the self exclusion filter). This is as the other + # filter will return everything BUT the types in its exclusion, so + # we need to add those excluded types that also match the self + # filter as wildcard types in the new filter. + new_wildcards |= other_excludes.difference(self_excludes) + + # If other is an `include_others` then the difference isn't. + new_all = False + # (We have no need for excludes when we don't start with all, as there + # is nothing to exclude.) + new_excludes = set() + + # We also filter out all state types that aren't in the exclusion + # list of the other. + new_wildcards &= other_excludes + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in new_concrete_keys + if state_type in other_excludes + } + + # Transform our newly-constructed state filter from the alternative + # representation back into the normal StateFilter representation. + return StateFilter._recompose_from_four_parts( + new_all, new_excludes, new_wildcards, new_concrete_keys + ) + class StateGroupStorage: """High level interface to fetching state for event.""" diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index db65253773..0120b4688b 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -63,7 +63,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, - sender="@as:test", + # Note: this user does not match the regex above, so that tests + # can distinguish the sender from the AS user. + sender="@as_main:test", ) mock_load_appservices = Mock(return_value=[self.appservice]) @@ -122,7 +124,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): {(alice, bob, private), (bob, alice, private)}, ) - # The next three tests (test_population_excludes_*) all setup + # The next four tests (test_excludes_*) all setup # - A normal user included in the user dir # - A public and private room created by that user # - A user excluded from the room dir, belonging to both rooms @@ -179,6 +181,34 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) self._check_only_one_user_in_directory(user, public) + def test_excludes_appservice_sender(self) -> None: + user = self.register_user("user", "pass") + token = self.login(user, "pass") + room = self.helper.create_room_as(user, is_public=True, tok=token) + self.helper.join(room, self.appservice.sender, tok=self.appservice.token) + self._check_only_one_user_in_directory(user, room) + + def test_user_not_in_users_table(self) -> None: + """Unclear how it happens, but on matrix.org we've seen join events + for users who aren't in the users table. Test that we don't fall over + when processing such a user. + """ + user1 = self.register_user("user1", "pass") + token1 = self.login(user1, "pass") + room = self.helper.create_room_as(user1, is_public=True, tok=token1) + + # Inject a join event for a user who doesn't exist + self.get_success(inject_member_event(self.hs, room, "@not-a-user:test", "join")) + + # Another new user registers and joins the room + user2 = self.register_user("user2", "pass") + token2 = self.login(user2, "pass") + self.helper.join(room, user2, tok=token2) + + # The dodgy event should not have stopped us from processing user2's join. + in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) + self.assertEqual(set(in_public), {(user1, room), (user2, room)}) + def _create_rooms_and_inject_memberships( self, creator: str, token: str, joiner: str ) -> Tuple[str, str]: @@ -230,7 +260,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) profile = self.get_success(self.store.get_user_in_directory(support_user_id)) - self.assertTrue(profile is None) + self.assertIsNone(profile) display_name = "display_name" profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) @@ -264,7 +294,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # profile is not in directory profile = self.get_success(self.store.get_user_in_directory(r_user_id)) - self.assertTrue(profile is None) + self.assertIsNone(profile) # update profile after deactivation self.get_success( @@ -273,7 +303,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # profile is furthermore not in directory profile = self.get_success(self.store.get_user_in_directory(r_user_id)) - self.assertTrue(profile is None) + self.assertIsNone(profile) def test_handle_local_profile_change_with_appservice_user(self) -> None: # create user @@ -283,7 +313,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # profile is not in directory profile = self.get_success(self.store.get_user_in_directory(as_user_id)) - self.assertTrue(profile is None) + self.assertIsNone(profile) # update profile profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3") @@ -293,7 +323,28 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # profile is still not in directory profile = self.get_success(self.store.get_user_in_directory(as_user_id)) - self.assertTrue(profile is None) + self.assertIsNone(profile) + + def test_handle_local_profile_change_with_appservice_sender(self) -> None: + # profile is not in directory + profile = self.get_success( + self.store.get_user_in_directory(self.appservice.sender) + ) + self.assertIsNone(profile) + + # update profile + profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3") + self.get_success( + self.handler.handle_local_profile_change( + self.appservice.sender, profile_info + ) + ) + + # profile is still not in directory + profile = self.get_success( + self.store.get_user_in_directory(self.appservice.sender) + ) + self.assertIsNone(profile) def test_handle_user_deactivated_support_user(self) -> None: s_user_id = "@support:test" diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index dada4f98c9..0e4013ebea 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -146,6 +146,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ], ) + @parameterized.expand([(False,), (True,)]) + def test_get_last_client_ip_by_device(self, after_persisting: bool): + """Test `get_last_client_ip_by_device` for persisted and unpersisted data""" + self.reactor.advance(12345678) + + user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success( + self.store.store_device( + user_id, + device_id, + "display name", + ) + ) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", device_id + ) + ) + + if after_persisting: + # Trigger the storage loop + self.reactor.advance(10) + + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + self.assertEqual( + result, + { + (user_id, device_id): { + "user_id": user_id, + "device_id": device_id, + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 12345678000, + }, + }, + ) + @parameterized.expand([(False,), (True,)]) def test_get_user_ip_and_agents(self, after_persisting: bool): """Test `get_user_ip_and_agents` for persisted and unpersisted data""" diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 32060f2abd..70d52b088c 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -21,7 +21,7 @@ from synapse.api.room_versions import RoomVersions from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, TestCase logger = logging.getLogger(__name__) @@ -105,7 +105,6 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) def test_get_state_for_event(self): - # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) @@ -483,3 +482,513 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + + +class StateFilterDifferenceTestCase(TestCase): + def assert_difference( + self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter + ): + self.assertEqual( + minuend.approx_difference(subtrahend), + expected, + f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", + ) + + def test_state_filter_difference_no_include_other_minus_no_include_other(self): + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), both a and b do not have the + include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + StateFilter.freeze({EventTypes.Create: None}, include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:spqr"}}, + include_others=False, + ), + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.CanonicalAlias: {""}}, + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_include_other_minus_no_include_other(self): + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), only a has the include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Create: None, + EventTypes.Member: set(), + EventTypes.CanonicalAlias: set(), + }, + include_others=True, + ), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + # This also shows that the resultant state filter is normalised. + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=True), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.Create: {""}, + }, + include_others=False, + ), + StateFilter(types=frozendict(), include_others=True), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter( + types=frozendict(), + include_others=True, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.CanonicalAlias: {""}, + EventTypes.Member: set(), + }, + include_others=True, + ), + ) + + # (specific state keys) - (specific state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + ) + + def test_state_filter_difference_include_other_minus_include_other(self): + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), both a and b have the include_others + flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=True, + ), + StateFilter(types=frozendict(), include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=True), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter( + types=frozendict(), + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + EventTypes.Create: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.Create: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.Create: {""}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_no_include_other_minus_include_other(self): + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), only b has the include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=True, + ), + StateFilter(types=frozendict(), include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:spqr"}}, + include_others=True, + ), + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter( + types=frozendict(), + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_simple_cases(self): + """ + Tests some very simple cases of the StateFilter approx_difference, + that are not explicitly tested by the more in-depth tests. + """ + + self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) + + self.assert_difference( + StateFilter.all(), + StateFilter.none(), + StateFilter.all(), + ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 9f483ad681..be3ed64f5e 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -256,7 +256,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) self.assertEqual(users, {u1, u2, u3}) - # The next three tests (test_population_excludes_*) all set up + # The next four tests (test_population_excludes_*) all set up # - A normal user included in the user dir # - A public and private room created by that user # - A user excluded from the room dir, belonging to both rooms @@ -364,6 +364,21 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): # Check the AS user is not in the directory. self._check_room_sharing_tables(user, public, private) + def test_population_excludes_appservice_sender(self) -> None: + user = self.register_user("user", "pass") + token = self.login(user, "pass") + + # Join the AS sender to rooms owned by the normal user. + public, private = self._create_rooms_and_inject_memberships( + user, token, self.appservice.sender + ) + + # Rebuild the directory. + self._purge_and_rebuild_user_dir() + + # Check the AS sender is not in the directory. + self._check_room_sharing_tables(user, public, private) + def test_population_conceals_private_nickname(self) -> None: # Make a private room, and set a nickname within user = self.register_user("aaaa", "pass")