Merge pull request #2727 from matrix-org/rav/refactor_ui_auth_return

Refactor UI auth implementation
This commit is contained in:
Richard van der Hoff 2017-12-05 09:40:38 +00:00 committed by GitHub
commit aa6ecf0984
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 103 additions and 48 deletions

View file

@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
pass
class InteractiveAuthIncompleteError(Exception):
"""An error raised when UI auth is not yet complete
(This indicates we should return a 401 with 'result' as the body)
Attributes:
result (dict): the server response to the request, which should be
passed back to the client
"""
def __init__(self, result):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete",
)
self.result = result
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):

View file

@ -17,7 +17,10 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.api.errors import (
AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
SynapseError,
)
from synapse.module_api import ModuleApi
from synapse.types import UserID
from synapse.util.async import run_on_reactor
@ -95,26 +98,36 @@ class AuthHandler(BaseHandler):
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
If no auth flows have been completed successfully, raises an
InteractiveAuthIncompleteError. To handle this, you can use
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
decorator.
Args:
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
Returns:
A tuple of (authed, dict, dict, session_id) where authed is true if
the client has successfully completed an auth flow. If it is true
the first dict contains the authenticated credentials of each stage.
defer.Deferred[dict, dict, str]: a deferred tuple of
(creds, params, session_id).
If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client.
'creds' contains the authenticated credentials of each stage.
In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call).
'params' contains the parameters for this request (which may
have been given only in a previous call).
session_id is the ID of this session, either passed in by the client
or assigned by the call to check_auth
'session_id' is the ID of this session, either passed in by the
client or assigned by this call
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
all the stages in any of the permitted flows.
"""
authdict = None
@ -142,11 +155,8 @@ class AuthHandler(BaseHandler):
clientdict = session['clientdict']
if not authdict:
defer.returnValue(
(
False, self._auth_dict_for_flows(flows, session),
clientdict, session['id']
)
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session),
)
if 'creds' not in session:
@ -190,12 +200,14 @@ class AuthHandler(BaseHandler):
"Auth completed with creds: %r. Client dict has keys: %r",
creds, clientdict.keys()
)
defer.returnValue((True, creds, clientdict, session['id']))
defer.returnValue((creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys()
ret.update(errordict)
defer.returnValue((False, ret, clientdict, session['id']))
raise InteractiveAuthIncompleteError(
ret,
)
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):

View file

@ -15,12 +15,13 @@
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
import logging
import re
import logging
from twisted.internet import defer
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
logger = logging.getLogger(__name__)
@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
filter_json['room']['timeline']["limit"] = min(
filter_json['room']['timeline']['limit'],
filter_timeline_limit)
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns a deferred (errcode, body) response
and adds exception handling to turn a InteractiveAuthIncompleteError into
a 401 response.
Normal usage is:
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
# ...
yield self.auth_handler.check_auth
"""
def wrapped(*args, **kwargs):
res = defer.maybeDeferred(orig, *args, **kwargs)
res.addErrback(_catch_incomplete_interactive_auth)
return res
return wrapped
def _catch_incomplete_interactive_auth(f):
"""helper for interactive_auth_handler
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
Args:
f (failure.Failure):
"""
f.trap(InteractiveAuthIncompleteError)
return 401, f.value.result

View file

@ -26,7 +26,7 @@ from synapse.http.servlet import (
)
from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@ -100,21 +100,19 @@ class PasswordRestServlet(RestServlet):
self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_object_from_request(request)
authed, result, params, _ = yield self.auth_handler.check_auth([
result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY],
[LoginType.MSISDN],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
user_id = None
requester = None
@ -168,6 +166,7 @@ class DeactivateAccountRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
@ -186,13 +185,10 @@ class DeactivateAccountRestServlet(RestServlet):
)
defer.returnValue((200, {}))
authed, result, params, _ = yield self.auth_handler.check_auth([
result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
if LoginType.PASSWORD in result:
user_id = result[LoginType.PASSWORD]
# if using password, they should also be logged in

View file

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api import constants, errors
from synapse.http import servlet
from ._base import client_v2_patterns
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@ -60,6 +60,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
try:
@ -77,13 +78,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
)
authed, result, params, _ = yield self.auth_handler.check_auth([
result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_devices(
requester.user.to_string(),
@ -115,6 +113,7 @@ class DeviceRestServlet(servlet.RestServlet):
)
defer.returnValue((200, device))
@interactive_auth_handler
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
@ -130,13 +129,10 @@ class DeviceRestServlet(servlet.RestServlet):
else:
raise
authed, result, params, _ = yield self.auth_handler.check_auth([
result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
# check that the UI auth matched the access token
user_id = result[constants.LoginType.PASSWORD]
if user_id != requester.user.to_string():

View file

@ -27,7 +27,7 @@ from synapse.http.servlet import (
)
from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns
from ._base import client_v2_patterns, interactive_auth_handler
import logging
import hmac
@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
])
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, auth_result))
return
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",

View file

@ -1,5 +1,7 @@
from twisted.python import failure
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.api.errors import SynapseError
from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
from twisted.internet import defer
from mock import Mock
from tests import unittest
@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
side_effect=lambda x: self.appservice)
)
self.auth_result = (False, None, None, None)
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None)
@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"username": "kermit"
})
@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"device_id": device_id,
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
self.auth_result = (None, {
"username": "kermit",
"password": "monkey"
}, None)
@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
self.auth_result = (None, {
"username": "kermit",
"password": "monkey"
}, None)