diff --git a/synapse/api/errors.py b/synapse/api/errors.py index d0dfa959dc..79b35b3e7c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -140,6 +140,22 @@ class RegistrationError(SynapseError): pass +class InteractiveAuthIncompleteError(Exception): + """An error raised when UI auth is not yet complete + + (This indicates we should return a 401 with 'result' as the body) + + Attributes: + result (dict): the server response to the request, which should be + passed back to the client + """ + def __init__(self, result): + super(InteractiveAuthIncompleteError, self).__init__( + "Interactive auth not yet complete", + ) + self.result = result + + class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" def __init__(self, *args, **kwargs): diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2f30f183ce..28c80608a7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -17,7 +17,10 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType -from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError +from synapse.api.errors import ( + AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError, + SynapseError, +) from synapse.module_api import ModuleApi from synapse.types import UserID from synapse.util.async import run_on_reactor @@ -95,26 +98,36 @@ class AuthHandler(BaseHandler): session with a map, which maps each auth-type (str) to the relevant identity authenticated by that auth-type (mostly str, but for captcha, bool). + If no auth flows have been completed successfully, raises an + InteractiveAuthIncompleteError. To handle this, you can use + synapse.rest.client.v2_alpha._base.interactive_auth_handler as a + decorator. + Args: flows (list): A list of login flows. Each flow is an ordered list of strings representing auth-types. At least one full flow must be completed in order for auth to be successful. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. + clientip (str): The IP address of the client. + Returns: - A tuple of (authed, dict, dict, session_id) where authed is true if - the client has successfully completed an auth flow. If it is true - the first dict contains the authenticated credentials of each stage. + defer.Deferred[dict, dict, str]: a deferred tuple of + (creds, params, session_id). - If authed is false, the first dictionary is the server response to - the login request and should be passed back to the client. + 'creds' contains the authenticated credentials of each stage. - In either case, the second dict contains the parameters for this - request (which may have been given only in a previous call). + 'params' contains the parameters for this request (which may + have been given only in a previous call). - session_id is the ID of this session, either passed in by the client - or assigned by the call to check_auth + 'session_id' is the ID of this session, either passed in by the + client or assigned by this call + + Raises: + InteractiveAuthIncompleteError if the client has not yet completed + all the stages in any of the permitted flows. """ authdict = None @@ -142,11 +155,8 @@ class AuthHandler(BaseHandler): clientdict = session['clientdict'] if not authdict: - defer.returnValue( - ( - False, self._auth_dict_for_flows(flows, session), - clientdict, session['id'] - ) + raise InteractiveAuthIncompleteError( + self._auth_dict_for_flows(flows, session), ) if 'creds' not in session: @@ -190,12 +200,14 @@ class AuthHandler(BaseHandler): "Auth completed with creds: %r. Client dict has keys: %r", creds, clientdict.keys() ) - defer.returnValue((True, creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() ret.update(errordict) - defer.returnValue((False, ret, clientdict, session['id'])) + raise InteractiveAuthIncompleteError( + ret, + ) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 1f5bc24cc3..77434937ff 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -15,12 +15,13 @@ """This module contains base REST classes for constructing client v1 servlets. """ - -from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +import logging import re -import logging +from twisted.internet import defer +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX logger = logging.getLogger(__name__) @@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit): filter_json['room']['timeline']["limit"] = min( filter_json['room']['timeline']['limit'], filter_timeline_limit) + + +def interactive_auth_handler(orig): + """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors + + Takes a on_POST method which returns a deferred (errcode, body) response + and adds exception handling to turn a InteractiveAuthIncompleteError into + a 401 response. + + Normal usage is: + + @interactive_auth_handler + @defer.inlineCallbacks + def on_POST(self, request): + # ... + yield self.auth_handler.check_auth + """ + def wrapped(*args, **kwargs): + res = defer.maybeDeferred(orig, *args, **kwargs) + res.addErrback(_catch_incomplete_interactive_auth) + return res + return wrapped + + +def _catch_incomplete_interactive_auth(f): + """helper for interactive_auth_handler + + Catches InteractiveAuthIncompleteErrors and turns them into 401 responses + + Args: + f (failure.Failure): + """ + f.trap(InteractiveAuthIncompleteError) + return 401, f.value.result diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c26ce63bcf..0d59a93222 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -26,7 +26,7 @@ from synapse.http.servlet import ( ) from synapse.util.async import run_on_reactor from synapse.util.msisdn import phone_number_to_msisdn -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -100,21 +100,19 @@ class PasswordRestServlet(RestServlet): self.datastore = self.hs.get_datastore() self._set_password_handler = hs.get_set_password_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): yield run_on_reactor() body = parse_json_object_from_request(request) - authed, result, params, _ = yield self.auth_handler.check_auth([ + result, params, _ = yield self.auth_handler.check_auth([ [LoginType.PASSWORD], [LoginType.EMAIL_IDENTITY], [LoginType.MSISDN], ], body, self.hs.get_ip_from_request(request)) - if not authed: - defer.returnValue((401, result)) - user_id = None requester = None @@ -168,6 +166,7 @@ class DeactivateAccountRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self._deactivate_account_handler = hs.get_deactivate_account_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) @@ -186,13 +185,10 @@ class DeactivateAccountRestServlet(RestServlet): ) defer.returnValue((200, {})) - authed, result, params, _ = yield self.auth_handler.check_auth([ + result, params, _ = yield self.auth_handler.check_auth([ [LoginType.PASSWORD], ], body, self.hs.get_ip_from_request(request)) - if not authed: - defer.returnValue((401, result)) - if LoginType.PASSWORD in result: user_id = result[LoginType.PASSWORD] # if using password, they should also be logged in diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 5321e5abbb..909f9c087b 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.api import constants, errors from synapse.http import servlet -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -60,6 +60,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): try: @@ -77,13 +78,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet): 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM ) - authed, result, params, _ = yield self.auth_handler.check_auth([ + result, params, _ = yield self.auth_handler.check_auth([ [constants.LoginType.PASSWORD], ], body, self.hs.get_ip_from_request(request)) - if not authed: - defer.returnValue((401, result)) - requester = yield self.auth.get_user_by_req(request) yield self.device_handler.delete_devices( requester.user.to_string(), @@ -115,6 +113,7 @@ class DeviceRestServlet(servlet.RestServlet): ) defer.returnValue((200, device)) + @interactive_auth_handler @defer.inlineCallbacks def on_DELETE(self, request, device_id): requester = yield self.auth.get_user_by_req(request) @@ -130,13 +129,10 @@ class DeviceRestServlet(servlet.RestServlet): else: raise - authed, result, params, _ = yield self.auth_handler.check_auth([ + result, params, _ = yield self.auth_handler.check_auth([ [constants.LoginType.PASSWORD], ], body, self.hs.get_ip_from_request(request)) - if not authed: - defer.returnValue((401, result)) - # check that the UI auth matched the access token user_id = result[constants.LoginType.PASSWORD] if user_id != requester.user.to_string(): diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 9e2f7308ce..e9d88a8895 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -27,7 +27,7 @@ from synapse.http.servlet import ( ) from synapse.util.msisdn import phone_number_to_msisdn -from ._base import client_v2_patterns +from ._base import client_v2_patterns, interactive_auth_handler import logging import hmac @@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet): self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): yield run_on_reactor() @@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet): [LoginType.MSISDN, LoginType.EMAIL_IDENTITY], ]) - authed, auth_result, params, session_id = yield self.auth_handler.check_auth( + auth_result, params, session_id = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) ) - if not authed: - defer.returnValue((401, auth_result)) - return - if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 821c735528..096f771bea 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,5 +1,7 @@ +from twisted.python import failure + from synapse.rest.client.v2_alpha.register import RegisterRestServlet -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError from twisted.internet import defer from mock import Mock from tests import unittest @@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase): side_effect=lambda x: self.appservice) ) - self.auth_result = (False, None, None, None) + self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None)) self.auth_handler = Mock( check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), get_session_data=Mock(return_value=None) @@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.request.args = { "access_token": "i_am_an_app_service" } + self.request_data = json.dumps({ "username": "kermit" }) @@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase): "device_id": device_id, }) self.registration_handler.check_username = Mock(return_value=True) - self.auth_result = (True, None, { + self.auth_result = (None, { "username": "kermit", "password": "monkey" }, None) @@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase): "password": "monkey" }) self.registration_handler.check_username = Mock(return_value=True) - self.auth_result = (True, None, { + self.auth_result = (None, { "username": "kermit", "password": "monkey" }, None)