Convert more cached return values to immutable types (#16356)

This commit is contained in:
Patrick Cloke 2023-09-20 07:48:55 -04:00 committed by GitHub
parent d7c89c5908
commit 7ec0a141b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 52 additions and 36 deletions

1
changelog.d/16356.misc Normal file
View file

@ -0,0 +1 @@
Improve type hints.

View file

@ -37,7 +37,7 @@ from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.events import EventBase, relation_from_event from synapse.events import EventBase, relation_from_event
from synapse.types import JsonDict, RoomID, UserID from synapse.types import JsonDict, JsonMapping, RoomID, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -191,7 +191,7 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
class FilterCollection: class FilterCollection:
def __init__(self, hs: "HomeServer", filter_json: JsonDict): def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
self._filter_json = filter_json self._filter_json = filter_json
room_filter_json = self._filter_json.get("room", {}) room_filter_json = self._filter_json.get("room", {})
@ -219,7 +219,7 @@ class FilterCollection:
def __repr__(self) -> str: def __repr__(self) -> str:
return "<FilterCollection %s>" % (json.dumps(self._filter_json),) return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
def get_filter_json(self) -> JsonDict: def get_filter_json(self) -> JsonMapping:
return self._filter_json return self._filter_json
def timeline_limit(self) -> int: def timeline_limit(self) -> int:
@ -313,7 +313,7 @@ class FilterCollection:
class Filter: class Filter:
def __init__(self, hs: "HomeServer", filter_json: JsonDict): def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
self._hs = hs self._hs = hs
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
self.filter_json = filter_json self.filter_json = filter_json

View file

@ -64,7 +64,7 @@ from synapse.federation.transport.client import SendJoinResponse
from synapse.http.client import is_unknown_endpoint from synapse.http.client import is_unknown_endpoint
from synapse.http.types import QueryParams from synapse.http.types import QueryParams
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -1704,7 +1704,7 @@ class FederationClient(FederationBase):
async def timestamp_to_event( async def timestamp_to_event(
self, self,
*, *,
destinations: List[str], destinations: StrCollection,
room_id: str, room_id: str,
timestamp: int, timestamp: int,
direction: Direction, direction: Direction,

View file

@ -1538,7 +1538,7 @@ class FederationEventHandler:
logger.exception("Failed to resync device for %s", sender) logger.exception("Failed to resync device for %s", sender)
async def backfill_event_id( async def backfill_event_id(
self, destinations: List[str], room_id: str, event_id: str self, destinations: StrCollection, room_id: str, event_id: str
) -> PulledPduInfo: ) -> PulledPduInfo:
"""Backfill a single event and persist it as a non-outlier which means """Backfill a single event and persist it as a non-outlier which means
we also pull in all of the state and auth events necessary for it. we also pull in all of the state and auth events necessary for it.

View file

@ -13,7 +13,17 @@
# limitations under the License. # limitations under the License.
import enum import enum
import logging import logging
from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Mapping,
Optional,
Sequence,
)
import attr import attr
@ -245,7 +255,7 @@ class RelationsHandler:
async def get_references_for_events( async def get_references_for_events(
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
) -> Dict[str, List[_RelatedEvent]]: ) -> Mapping[str, Sequence[_RelatedEvent]]:
"""Get a list of references to the given events. """Get a list of references to the given events.
Args: Args:

View file

@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseErro
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, JsonMapping, UserID
from ._base import client_patterns, set_timeline_upper_limit from ._base import client_patterns, set_timeline_upper_limit
@ -41,7 +41,7 @@ class GetFilterRestServlet(RestServlet):
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str, filter_id: str self, request: SynapseRequest, user_id: str, filter_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonMapping]:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)

View file

@ -582,7 +582,7 @@ class StateStorageController:
@trace @trace
@tag_args @tag_args
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
"""Get current hosts in room based on current state. """Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms Blocks until we have full state for the given room. This only happens for rooms

View file

@ -25,7 +25,7 @@ from synapse.storage.database import (
LoggingTransaction, LoggingTransaction,
) )
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, JsonMapping, UserID
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING: if TYPE_CHECKING:
@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore):
@cached(num_args=2) @cached(num_args=2)
async def get_user_filter( async def get_user_filter(
self, user_id: UserID, filter_id: Union[int, str] self, user_id: UserID, filter_id: Union[int, str]
) -> JsonDict: ) -> JsonMapping:
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN. # with a coherent error message rather than 500 M_UNKNOWN.
try: try:

View file

@ -465,7 +465,7 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_references_for_event", list_name="event_ids") @cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
async def get_references_for_events( async def get_references_for_events(
self, event_ids: Collection[str] self, event_ids: Collection[str]
) -> Mapping[str, Optional[List[_RelatedEvent]]]: ) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]:
"""Get a list of references to the given events. """Get a list of references to the given events.
Args: Args:
@ -931,7 +931,7 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: str, room_id: str,
limit: int = 5, limit: int = 5,
from_token: Optional[ThreadsNextBatch] = None, from_token: Optional[ThreadsNextBatch] = None,
) -> Tuple[List[str], Optional[ThreadsNextBatch]]: ) -> Tuple[Sequence[str], Optional[ThreadsNextBatch]]:
"""Get a list of thread IDs, ordered by topological ordering of their """Get a list of thread IDs, ordered by topological ordering of their
latest reply. latest reply.

View file

@ -984,7 +984,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
) )
@cached(iterable=True, max_entries=10000) @cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
""" """
Get current hosts in room based on current state. Get current hosts in room based on current state.
@ -1013,12 +1013,14 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
# `get_users_in_room` rather than funky SQL. # `get_users_in_room` rather than funky SQL.
domains = await self.get_current_hosts_in_room(room_id) domains = await self.get_current_hosts_in_room(room_id)
return list(domains) return tuple(domains)
# For PostgreSQL we can use a regex to pull out the domains from the # For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex. # joined users in `current_state_events` via regex.
def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]: def get_current_hosts_in_room_ordered_txn(
txn: LoggingTransaction,
) -> Tuple[str, ...]:
# Returns a list of servers currently joined in the room sorted by # Returns a list of servers currently joined in the room sorted by
# longest in the room first (aka. with the lowest depth). The # longest in the room first (aka. with the lowest depth). The
# heuristic of sorting by servers who have been in the room the # heuristic of sorting by servers who have been in the room the
@ -1043,7 +1045,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
""" """
txn.execute(sql, (room_id,)) txn.execute(sql, (room_id,))
# `server_domain` will be `NULL` for malformed MXIDs with no colons. # `server_domain` will be `NULL` for malformed MXIDs with no colons.
return [d for d, in txn if d is not None] return tuple(d for d, in txn if d is not None)
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn

View file

@ -15,10 +15,10 @@
import logging import logging
from typing import ( from typing import (
Any, Any,
Dict,
Generator, Generator,
Iterable, Iterable,
List, List,
Mapping,
NoReturn, NoReturn,
Optional, Optional,
Set, Set,
@ -96,7 +96,7 @@ class DescriptorTestCase(unittest.TestCase):
self.mock = mock.Mock() self.mock = mock.Mock()
@descriptors.cached(num_args=1) @descriptors.cached(num_args=1)
def fn(self, arg1: int, arg2: int) -> mock.Mock: def fn(self, arg1: int, arg2: int) -> str:
return self.mock(arg1, arg2) return self.mock(arg1, arg2)
obj = Cls() obj = Cls()
@ -228,8 +228,9 @@ class DescriptorTestCase(unittest.TestCase):
call_count = 0 call_count = 0
@cached() @cached()
def fn(self, arg1: int) -> Optional[Deferred]: def fn(self, arg1: int) -> Deferred:
self.call_count += 1 self.call_count += 1
assert self.result is not None
return self.result return self.result
obj = Cls() obj = Cls()
@ -401,21 +402,21 @@ class DescriptorTestCase(unittest.TestCase):
self.mock = mock.Mock() self.mock = mock.Mock()
@descriptors.cached(iterable=True) @descriptors.cached(iterable=True)
def fn(self, arg1: int, arg2: int) -> List[str]: def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]:
return self.mock(arg1, arg2) return self.mock(arg1, arg2)
obj = Cls() obj = Cls()
obj.mock.return_value = ["spam", "eggs"] obj.mock.return_value = ("spam", "eggs")
r = obj.fn(1, 2) r = obj.fn(1, 2)
self.assertEqual(r.result, ["spam", "eggs"]) self.assertEqual(r.result, ("spam", "eggs"))
obj.mock.assert_called_once_with(1, 2) obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock() obj.mock.reset_mock()
# a call with different params should call the mock again # a call with different params should call the mock again
obj.mock.return_value = ["chips"] obj.mock.return_value = ("chips",)
r = obj.fn(1, 3) r = obj.fn(1, 3)
self.assertEqual(r.result, ["chips"]) self.assertEqual(r.result, ("chips",))
obj.mock.assert_called_once_with(1, 3) obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock() obj.mock.reset_mock()
@ -423,9 +424,9 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(len(obj.fn.cache.cache), 3) self.assertEqual(len(obj.fn.cache.cache), 3)
r = obj.fn(1, 2) r = obj.fn(1, 2)
self.assertEqual(r.result, ["spam", "eggs"]) self.assertEqual(r.result, ("spam", "eggs"))
r = obj.fn(1, 3) r = obj.fn(1, 3)
self.assertEqual(r.result, ["chips"]) self.assertEqual(r.result, ("chips",))
obj.mock.assert_not_called() obj.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self) -> None: def test_cache_iterable_with_sync_exception(self) -> None:
@ -784,7 +785,9 @@ class CachedListDescriptorTestCase(unittest.TestCase):
pass pass
@descriptors.cachedList(cached_method_name="fn", list_name="args1") @descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1: Iterable[int], arg2: int) -> Dict[int, str]: async def list_fn(
self, args1: Iterable[int], arg2: int
) -> Mapping[int, str]:
context = current_context() context = current_context()
assert isinstance(context, LoggingContext) assert isinstance(context, LoggingContext)
assert context.name == "c1" assert context.name == "c1"
@ -847,11 +850,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
pass pass
@descriptors.cachedList(cached_method_name="fn", list_name="args1") @descriptors.cachedList(cached_method_name="fn", list_name="args1")
def list_fn(self, args1: List[int]) -> "Deferred[dict]": def list_fn(self, args1: List[int]) -> "Deferred[Mapping[int, str]]":
return self.mock(args1) return self.mock(args1)
obj = Cls() obj = Cls()
deferred_result: "Deferred[dict]" = Deferred() deferred_result: "Deferred[Mapping[int, str]]" = Deferred()
obj.mock.return_value = deferred_result obj.mock.return_value = deferred_result
# start off several concurrent lookups of the same key # start off several concurrent lookups of the same key
@ -890,7 +893,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
pass pass
@descriptors.cachedList(cached_method_name="fn", list_name="args1") @descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1: List[int], arg2: int) -> Dict[int, str]: async def list_fn(self, args1: List[int], arg2: int) -> Mapping[int, str]:
# we want this to behave like an asynchronous function # we want this to behave like an asynchronous function
await run_on_reactor() await run_on_reactor()
return self.mock(args1, arg2) return self.mock(args1, arg2)
@ -929,7 +932,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
pass pass
@cachedList(cached_method_name="fn", list_name="args") @cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args: List[int]) -> Dict[int, str]: async def list_fn(self, args: List[int]) -> Mapping[int, str]:
await complete_lookup await complete_lookup
return {arg: str(arg) for arg in args} return {arg: str(arg) for arg in args}
@ -964,7 +967,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
pass pass
@cachedList(cached_method_name="fn", list_name="args") @cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args: List[int]) -> Dict[int, str]: async def list_fn(self, args: List[int]) -> Mapping[int, str]:
await make_deferred_yieldable(complete_lookup) await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished self.inner_context_was_finished = current_context().finished
return {arg: str(arg) for arg in args} return {arg: str(arg) for arg in args}