Add type hints to response cache. (#8507)

This commit is contained in:
Patrick Cloke 2020-10-09 11:35:11 -04:00 committed by GitHub
parent 66ac4b1e34
commit 1781bbe319
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 48 additions and 34 deletions

1
changelog.d/8507.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to various parts of the code base.

View file

@ -65,6 +65,7 @@ files =
synapse/types.py, synapse/types.py,
synapse/util/async_helpers.py, synapse/util/async_helpers.py,
synapse/util/caches/descriptors.py, synapse/util/caches/descriptors.py,
synapse/util/caches/response_cache.py,
synapse/util/caches/stream_change_cache.py, synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py, synapse/util/metrics.py,
tests/replication, tests/replication,

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import urllib import urllib
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Tuple
from prometheus_client import Counter from prometheus_client import Counter
@ -93,7 +93,7 @@ class ApplicationServiceApi(SimpleHttpClient):
self.protocol_meta_cache = ResponseCache( self.protocol_meta_cache = ResponseCache(
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
) ) # type: ResponseCache[Tuple[str, str]]
async def query_user(self, service, user_id): async def query_user(self, service, user_id):
if service.url is None: if service.url is None:

View file

@ -116,7 +116,7 @@ class FederationServer(FederationBase):
# We cache results for transaction with the same ID # We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache( self._transaction_resp_cache = ResponseCache(
hs, "fed_txn_handler", timeout_ms=30000 hs, "fed_txn_handler", timeout_ms=30000
) ) # type: ResponseCache[Tuple[str, str]]
self.transaction_actions = TransactionActions(self.store) self.transaction_actions = TransactionActions(self.store)
@ -124,10 +124,12 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) self._state_resp_cache = ResponseCache(
hs, "state_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._state_ids_resp_cache = ResponseCache( self._state_ids_resp_cache = ResponseCache(
hs, "state_ids_resp", timeout_ms=30000 hs, "state_ids_resp", timeout_ms=30000
) ) # type: ResponseCache[Tuple[str, str]]
self._federation_metrics_domains = ( self._federation_metrics_domains = (
hs.get_config().federation.federation_metrics_domains hs.get_config().federation.federation_metrics_domains

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional, Tuple
from twisted.internet import defer from twisted.internet import defer
@ -47,12 +47,14 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator() self.validator = EventValidator()
self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") self.snapshot_cache = ResponseCache(
hs, "initial_sync_cache"
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state
def snapshot_all_rooms( async def snapshot_all_rooms(
self, self,
user_id: str, user_id: str,
pagin_config: PaginationConfig, pagin_config: PaginationConfig,
@ -84,7 +86,7 @@ class InitialSyncHandler(BaseHandler):
include_archived, include_archived,
) )
return self.snapshot_cache.wrap( return await self.snapshot_cache.wrap(
key, key,
self._snapshot_all_rooms, self._snapshot_all_rooms,
user_id, user_id,

View file

@ -120,7 +120,7 @@ class RoomCreationHandler(BaseHandler):
# subsequent requests # subsequent requests
self._upgrade_response_cache = ResponseCache( self._upgrade_response_cache = ResponseCache(
hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
) ) # type: ResponseCache[Tuple[str, str]]
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
self.third_party_event_rules = hs.get_third_party_event_rules() self.third_party_event_rules = hs.get_third_party_event_rules()

View file

@ -243,7 +243,9 @@ class SyncHandler:
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.response_cache = ResponseCache(hs, "sync") self.response_cache = ResponseCache(
hs, "sync"
) # type: ResponseCache[Tuple[Any, ...]]
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.storage = hs.get_storage() self.storage = hs.get_storage()

View file

@ -92,7 +92,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self.CACHE: if self.CACHE:
self.response_cache = ResponseCache( self.response_cache = ResponseCache(
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
) ) # type: ResponseCache[str]
# We reserve `instance_name` as a parameter to sending requests, so we # We reserve `instance_name` as a parameter to sending requests, so we
# assert here that sub classes don't try and use the name. # assert here that sub classes don't try and use the name.

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer from twisted.internet import defer
@ -20,10 +21,15 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T")
class ResponseCache:
class ResponseCache(Generic[T]):
""" """
This caches a deferred response. Until the deferred completes it will be This caches a deferred response. Until the deferred completes it will be
returned from the cache. This means that if the client retries the request returned from the cache. This means that if the client retries the request
@ -31,8 +37,9 @@ class ResponseCache:
used rather than trying to compute a new response. used rather than trying to compute a new response.
""" """
def __init__(self, hs, name, timeout_ms=0): def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
self.pending_result_cache = {} # Requests that haven't finished yet. # Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.0 self.timeout_sec = timeout_ms / 1000.0
@ -40,13 +47,13 @@ class ResponseCache:
self._name = name self._name = name
self._metrics = register_cache("response_cache", name, self, resizable=False) self._metrics = register_cache("response_cache", name, self, resizable=False)
def size(self): def size(self) -> int:
return len(self.pending_result_cache) return len(self.pending_result_cache)
def __len__(self): def __len__(self) -> int:
return self.size() return self.size()
def get(self, key): def get(self, key: T) -> Optional[defer.Deferred]:
"""Look up the given key. """Look up the given key.
Can return either a new Deferred (which also doesn't follow the synapse Can return either a new Deferred (which also doesn't follow the synapse
@ -58,12 +65,11 @@ class ResponseCache:
from an absent cache entry. from an absent cache entry.
Args: Args:
key (hashable): key: key to get/set in the cache
Returns: Returns:
twisted.internet.defer.Deferred|None|E: None if there is no entry None if there is no entry for this key; otherwise a deferred which
for this key; otherwise either a deferred result or the result resolves to the result.
itself.
""" """
result = self.pending_result_cache.get(key) result = self.pending_result_cache.get(key)
if result is not None: if result is not None:
@ -73,7 +79,7 @@ class ResponseCache:
self._metrics.inc_misses() self._metrics.inc_misses()
return None return None
def set(self, key, deferred): def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
"""Set the entry for the given key to the given deferred. """Set the entry for the given key to the given deferred.
*deferred* should run its callbacks in the sentinel logcontext (ie, *deferred* should run its callbacks in the sentinel logcontext (ie,
@ -85,12 +91,11 @@ class ResponseCache:
result. You will probably want to make_deferred_yieldable the result. result. You will probably want to make_deferred_yieldable the result.
Args: Args:
key (hashable): key: key to get/set in the cache
deferred (twisted.internet.defer.Deferred[T): deferred: The deferred which resolves to the result.
Returns: Returns:
twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual A new deferred which resolves to the actual result.
result.
""" """
result = ObservableDeferred(deferred, consumeErrors=True) result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result self.pending_result_cache[key] = result
@ -107,7 +112,9 @@ class ResponseCache:
result.addBoth(remove) result.addBoth(remove)
return result.observe() return result.observe()
def wrap(self, key, callback, *args, **kwargs): def wrap(
self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
) -> defer.Deferred:
"""Wrap together a *get* and *set* call, taking care of logcontexts """Wrap together a *get* and *set* call, taking care of logcontexts
First looks up the key in the cache, and if it is present makes it First looks up the key in the cache, and if it is present makes it
@ -118,21 +125,20 @@ class ResponseCache:
Example usage: Example usage:
@defer.inlineCallbacks async def handle_request(request):
def handle_request(request):
# etc # etc
return result return result
result = yield response_cache.wrap( result = await response_cache.wrap(
key, key,
handle_request, handle_request,
request, request,
) )
Args: Args:
key (hashable): key to get/set in the cache key: key to get/set in the cache
callback (callable): function to call if the key is not found in callback: function to call if the key is not found in
the cache the cache
*args: positional parameters to pass to the callback, if it is used *args: positional parameters to pass to the callback, if it is used
@ -140,7 +146,7 @@ class ResponseCache:
**kwargs: named parameters to pass to the callback, if it is used **kwargs: named parameters to pass to the callback, if it is used
Returns: Returns:
twisted.internet.defer.Deferred: yieldable result Deferred which resolves to the result
""" """
result = self.get(key) result = self.get(key)
if not result: if not result: