Make StateFilter frozen so we can hash it (#10816)

Also enables Mypy for related tests.
This commit is contained in:
reivilibre 2021-09-14 16:35:53 +01:00 committed by GitHub
parent 14b8c0476f
commit 8eb7cb2e0d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 30 deletions

1
changelog.d/10816.misc Normal file
View file

@ -0,0 +1 @@
Make `StateFilter` frozen so it is hashable.

View file

@ -86,6 +86,7 @@ files =
tests/handlers/test_sync.py, tests/handlers/test_sync.py,
tests/rest/client/test_login.py, tests/rest/client/test_login.py,
tests/rest/client/test_auth.py, tests/rest/client/test_auth.py,
tests/storage/test_state.py,
tests/util/test_itertools.py, tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py tests/util/test_stream_change_cache.py

View file

@ -25,12 +25,15 @@ from typing import (
) )
import attr import attr
from frozendict import frozendict
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases import Databases from synapse.storage.databases import Databases
@ -40,7 +43,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T") T = TypeVar("T")
@attr.s(slots=True) @attr.s(slots=True, frozen=True)
class StateFilter: class StateFilter:
"""A filter used when querying for state. """A filter used when querying for state.
@ -53,14 +56,19 @@ class StateFilter:
appear in `types`. appear in `types`.
""" """
types = attr.ib(type=Dict[str, Optional[Set[str]]]) types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
include_others = attr.ib(default=False, type=bool) include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self): def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing # If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary # wildcards from the types dictionary
if self.include_others: if self.include_others:
self.types = {k: v for k, v in self.types.items() if v is not None} # this is needed to work around the fact that StateFilter is frozen
object.__setattr__(
self,
"types",
frozendict({k: v for k, v in self.types.items() if v is not None}),
)
@staticmethod @staticmethod
def all() -> "StateFilter": def all() -> "StateFilter":
@ -69,7 +77,7 @@ class StateFilter:
Returns: Returns:
The new state filter. The new state filter.
""" """
return StateFilter(types={}, include_others=True) return StateFilter(types=frozendict(), include_others=True)
@staticmethod @staticmethod
def none() -> "StateFilter": def none() -> "StateFilter":
@ -78,7 +86,7 @@ class StateFilter:
Returns: Returns:
The new state filter. The new state filter.
""" """
return StateFilter(types={}, include_others=False) return StateFilter(types=frozendict(), include_others=False)
@staticmethod @staticmethod
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@ -103,7 +111,12 @@ class StateFilter:
type_dict.setdefault(typ, set()).add(s) # type: ignore type_dict.setdefault(typ, set()).add(s) # type: ignore
return StateFilter(types=type_dict) return StateFilter(
types=frozendict(
(k, frozenset(v) if v is not None else None)
for k, v in type_dict.items()
)
)
@staticmethod @staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
@ -116,7 +129,10 @@ class StateFilter:
Returns: Returns:
The new state filter The new state filter
""" """
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) return StateFilter(
types=frozendict({EventTypes.Member: frozenset(members)}),
include_others=True,
)
def return_expanded(self) -> "StateFilter": def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed """Creates a new StateFilter where type wild cards have been removed
@ -173,7 +189,7 @@ class StateFilter:
# We want to return all non-members, but only particular # We want to return all non-members, but only particular
# memberships # memberships
return StateFilter( return StateFilter(
types={EventTypes.Member: self.types[EventTypes.Member]}, types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True, include_others=True,
) )
@ -245,14 +261,15 @@ class StateFilter:
return len(self.concrete_types()) return len(self.concrete_types())
def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]: def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
"""Returns the state filtered with by this StateFilter """Returns the state filtered with by this StateFilter.
Args: Args:
state: The state map to filter state: The state map to filter
Returns: Returns:
The filtered state map The filtered state map.
This is a copy, so it's safe to mutate.
""" """
if self.is_full(): if self.is_full():
return dict(state_dict) return dict(state_dict)
@ -324,14 +341,16 @@ class StateFilter:
if state_keys is None: if state_keys is None:
member_filter = StateFilter.all() member_filter = StateFilter.all()
else: else:
member_filter = StateFilter({EventTypes.Member: state_keys}) member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
elif self.include_others: elif self.include_others:
member_filter = StateFilter.all() member_filter = StateFilter.all()
else: else:
member_filter = StateFilter.none() member_filter = StateFilter.none()
non_member_filter = StateFilter( non_member_filter = StateFilter(
types={k: v for k, v in self.types.items() if k != EventTypes.Member}, types=frozendict(
{k: v for k, v in self.types.items() if k != EventTypes.Member}
),
include_others=self.include_others, include_others=self.include_others,
) )

View file

@ -14,6 +14,8 @@
import logging import logging
from frozendict import frozendict
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
@ -183,7 +185,9 @@ class StateStoreTestCase(HomeserverTestCase):
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}}, types=frozendict(
{EventTypes.Member: frozenset({self.u_alice.to_string()})}
),
include_others=True, include_others=True,
), ),
) )
@ -203,7 +207,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types=frozendict({EventTypes.Member: frozenset()}),
include_others=True,
), ),
) )
) )
@ -228,7 +233,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types=frozendict({EventTypes.Member: frozenset()}), include_others=True
), ),
) )
@ -245,7 +250,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types=frozendict({EventTypes.Member: frozenset()}), include_others=True
), ),
) )
@ -258,7 +263,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types=frozendict({EventTypes.Member: None}), include_others=True
), ),
) )
@ -275,7 +280,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types=frozendict({EventTypes.Member: None}), include_others=True
), ),
) )
@ -295,7 +300,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
), ),
) )
@ -312,7 +318,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
), ),
) )
@ -325,7 +332,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
), ),
) )
@ -375,7 +383,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types=frozendict({EventTypes.Member: frozenset()}), include_others=True
), ),
) )
@ -387,7 +395,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types=frozendict({EventTypes.Member: frozenset()}), include_others=True
), ),
) )
@ -400,7 +408,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types=frozendict({EventTypes.Member: None}), include_others=True
), ),
) )
@ -411,7 +419,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types=frozendict({EventTypes.Member: None}), include_others=True
), ),
) )
@ -430,7 +438,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
), ),
) )
@ -441,7 +450,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
), ),
) )
@ -454,7 +464,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
), ),
) )
@ -465,7 +476,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
), ),
) )