Merge pull request #2727 from matrix-org/rav/refactor_ui_auth_return

Refactor UI auth implementation
This commit is contained in:
Richard van der Hoff 2017-12-05 09:40:38 +00:00 committed by GitHub
commit aa6ecf0984
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 103 additions and 48 deletions

View file

@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
pass pass
class InteractiveAuthIncompleteError(Exception):
"""An error raised when UI auth is not yet complete
(This indicates we should return a 401 with 'result' as the body)
Attributes:
result (dict): the server response to the request, which should be
passed back to the client
"""
def __init__(self, result):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete",
)
self.result = result
class UnrecognizedRequestError(SynapseError): class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make""" """An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

View file

@ -17,7 +17,10 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.api.errors import (
AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
SynapseError,
)
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -95,26 +98,36 @@ class AuthHandler(BaseHandler):
session with a map, which maps each auth-type (str) to the relevant session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool). identity authenticated by that auth-type (mostly str, but for captcha, bool).
If no auth flows have been completed successfully, raises an
InteractiveAuthIncompleteError. To handle this, you can use
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
decorator.
Args: Args:
flows (list): A list of login flows. Each flow is an ordered list of flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full strings representing auth-types. At least one full
flow must be completed in order for auth to be successful. flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent. 'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client. clientip (str): The IP address of the client.
Returns: Returns:
A tuple of (authed, dict, dict, session_id) where authed is true if defer.Deferred[dict, dict, str]: a deferred tuple of
the client has successfully completed an auth flow. If it is true (creds, params, session_id).
the first dict contains the authenticated credentials of each stage.
If authed is false, the first dictionary is the server response to 'creds' contains the authenticated credentials of each stage.
the login request and should be passed back to the client.
In either case, the second dict contains the parameters for this 'params' contains the parameters for this request (which may
request (which may have been given only in a previous call). have been given only in a previous call).
session_id is the ID of this session, either passed in by the client 'session_id' is the ID of this session, either passed in by the
or assigned by the call to check_auth client or assigned by this call
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
all the stages in any of the permitted flows.
""" """
authdict = None authdict = None
@ -142,11 +155,8 @@ class AuthHandler(BaseHandler):
clientdict = session['clientdict'] clientdict = session['clientdict']
if not authdict: if not authdict:
defer.returnValue( raise InteractiveAuthIncompleteError(
( self._auth_dict_for_flows(flows, session),
False, self._auth_dict_for_flows(flows, session),
clientdict, session['id']
)
) )
if 'creds' not in session: if 'creds' not in session:
@ -190,12 +200,14 @@ class AuthHandler(BaseHandler):
"Auth completed with creds: %r. Client dict has keys: %r", "Auth completed with creds: %r. Client dict has keys: %r",
creds, clientdict.keys() creds, clientdict.keys()
) )
defer.returnValue((True, creds, clientdict, session['id'])) defer.returnValue((creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
ret.update(errordict) ret.update(errordict)
defer.returnValue((False, ret, clientdict, session['id'])) raise InteractiveAuthIncompleteError(
ret,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip): def add_oob_auth(self, stagetype, authdict, clientip):

View file

@ -15,12 +15,13 @@
"""This module contains base REST classes for constructing client v1 servlets. """This module contains base REST classes for constructing client v1 servlets.
""" """
import logging
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
import re import re
import logging from twisted.internet import defer
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
filter_json['room']['timeline']["limit"] = min( filter_json['room']['timeline']["limit"] = min(
filter_json['room']['timeline']['limit'], filter_json['room']['timeline']['limit'],
filter_timeline_limit) filter_timeline_limit)
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns a deferred (errcode, body) response
and adds exception handling to turn a InteractiveAuthIncompleteError into
a 401 response.
Normal usage is:
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
# ...
yield self.auth_handler.check_auth
"""
def wrapped(*args, **kwargs):
res = defer.maybeDeferred(orig, *args, **kwargs)
res.addErrback(_catch_incomplete_interactive_auth)
return res
return wrapped
def _catch_incomplete_interactive_auth(f):
"""helper for interactive_auth_handler
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
Args:
f (failure.Failure):
"""
f.trap(InteractiveAuthIncompleteError)
return 401, f.value.result

View file

@ -26,7 +26,7 @@ from synapse.http.servlet import (
) )
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -100,21 +100,19 @@ class PasswordRestServlet(RestServlet):
self.datastore = self.hs.get_datastore() self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler() self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
authed, result, params, _ = yield self.auth_handler.check_auth([ result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD], [LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY], [LoginType.EMAIL_IDENTITY],
[LoginType.MSISDN], [LoginType.MSISDN],
], body, self.hs.get_ip_from_request(request)) ], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
user_id = None user_id = None
requester = None requester = None
@ -168,6 +166,7 @@ class DeactivateAccountRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self._deactivate_account_handler = hs.get_deactivate_account_handler() self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -186,13 +185,10 @@ class DeactivateAccountRestServlet(RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
authed, result, params, _ = yield self.auth_handler.check_auth([ result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD], [LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request)) ], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
if LoginType.PASSWORD in result: if LoginType.PASSWORD in result:
user_id = result[LoginType.PASSWORD] user_id = result[LoginType.PASSWORD]
# if using password, they should also be logged in # if using password, they should also be logged in

View file

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api import constants, errors from synapse.api import constants, errors
from synapse.http import servlet from synapse.http import servlet
from ._base import client_v2_patterns from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -60,6 +60,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
try: try:
@ -77,13 +78,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
) )
authed, result, params, _ = yield self.auth_handler.check_auth([ result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD], [constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request)) ], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_devices( yield self.device_handler.delete_devices(
requester.user.to_string(), requester.user.to_string(),
@ -115,6 +113,7 @@ class DeviceRestServlet(servlet.RestServlet):
) )
defer.returnValue((200, device)) defer.returnValue((200, device))
@interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, device_id): def on_DELETE(self, request, device_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
@ -130,13 +129,10 @@ class DeviceRestServlet(servlet.RestServlet):
else: else:
raise raise
authed, result, params, _ = yield self.auth_handler.check_auth([ result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD], [constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request)) ], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
# check that the UI auth matched the access token # check that the UI auth matched the access token
user_id = result[constants.LoginType.PASSWORD] user_id = result[constants.LoginType.PASSWORD]
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():

View file

@ -27,7 +27,7 @@ from synapse.http.servlet import (
) )
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns from ._base import client_v2_patterns, interactive_auth_handler
import logging import logging
import hmac import hmac
@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
@interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY], [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
]) ])
authed, auth_result, params, session_id = yield self.auth_handler.check_auth( auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request) flows, body, self.hs.get_ip_from_request(request)
) )
if not authed:
defer.returnValue((401, auth_result))
return
if registered_user_id is not None: if registered_user_id is not None:
logger.info( logger.info(
"Already registered user ID %r for this session", "Already registered user ID %r for this session",

View file

@ -1,5 +1,7 @@
from twisted.python import failure
from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
from twisted.internet import defer from twisted.internet import defer
from mock import Mock from mock import Mock
from tests import unittest from tests import unittest
@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
side_effect=lambda x: self.appservice) side_effect=lambda x: self.appservice)
) )
self.auth_result = (False, None, None, None) self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock( self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None) get_session_data=Mock(return_value=None)
@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.request.args = { self.request.args = {
"access_token": "i_am_an_app_service" "access_token": "i_am_an_app_service"
} }
self.request_data = json.dumps({ self.request_data = json.dumps({
"username": "kermit" "username": "kermit"
}) })
@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"device_id": device_id, "device_id": device_id,
}) })
self.registration_handler.check_username = Mock(return_value=True) self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, { self.auth_result = (None, {
"username": "kermit", "username": "kermit",
"password": "monkey" "password": "monkey"
}, None) }, None)
@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey" "password": "monkey"
}) })
self.registration_handler.check_username = Mock(return_value=True) self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, { self.auth_result = (None, {
"username": "kermit", "username": "kermit",
"password": "monkey" "password": "monkey"
}, None) }, None)