diff --git a/changelog.d/7944.misc b/changelog.d/7944.misc new file mode 100644 index 0000000000..afbc91a494 --- /dev/null +++ b/changelog.d/7944.misc @@ -0,0 +1 @@ +Convert the interactive_auth_handler wrapper to async/await. diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index b21538766d..f016b4f1bd 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -17,8 +17,7 @@ """ import logging import re - -from twisted.internet import defer +from typing import Iterable, Pattern from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.urls import CLIENT_API_PREFIX @@ -27,15 +26,23 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) -def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): +def client_patterns( + path_regex: str, + releases: Iterable[int] = (0,), + unstable: bool = True, + v1: bool = False, +) -> Iterable[Pattern]: """Creates a regex compiled client path with the correct client path prefix. Args: - path_regex (str): The regex string to match. This should NOT have a ^ + path_regex: The regex string to match. This should NOT have a ^ as this will be prefixed. + releases: An iterable of releases to include this endpoint under. + unstable: If true, include this endpoint under the "unstable" prefix. + v1: If true, include this endpoint under the "api/v1" prefix. Returns: - SRE_Pattern + An iterable of patterns. """ patterns = [] @@ -73,34 +80,22 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) def interactive_auth_handler(orig): """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors - Takes a on_POST method which returns a deferred (errcode, body) response + Takes a on_POST method which returns an Awaitable (errcode, body) response and adds exception handling to turn a InteractiveAuthIncompleteError into a 401 response. Normal usage is: @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): # ... - yield self.auth_handler.check_auth - """ + await self.auth_handler.check_auth + """ - def wrapped(*args, **kwargs): - res = defer.ensureDeferred(orig(*args, **kwargs)) - res.addErrback(_catch_incomplete_interactive_auth) - return res + async def wrapped(*args, **kwargs): + try: + return await orig(*args, **kwargs) + except InteractiveAuthIncompleteError as e: + return 401, e.result return wrapped - - -def _catch_incomplete_interactive_auth(f): - """helper for interactive_auth_handler - - Catches InteractiveAuthIncompleteErrors and turns them into 401 responses - - Args: - f (failure.Failure): - """ - f.trap(InteractiveAuthIncompleteError) - return 401, f.value.result