Preparatory refactoring of the SamlHandlerTestCase (#8938)

* move simple_async_mock to test_utils

... so that it can be re-used

* Remove references to `SamlHandler._map_saml_response_to_user` from tests

This method is going away, so we can no longer use it as a test point. Instead,
factor out a higher-level method which takes a SAML object, and verify correct
behaviour by mocking out `AuthHandler.complete_sso_login`.

* changelog
This commit is contained in:
Richard van der Hoff 2020-12-15 20:56:10 +00:00 committed by GitHub
parent b3a4b53587
commit 01333681bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 126 additions and 54 deletions

1
changelog.d/8938.feature Normal file
View file

@ -0,0 +1 @@
Add support for allowing users to pick their own user ID during a single-sign-on login.

View file

@ -163,6 +163,29 @@ class SamlHandler(BaseHandler):
return return
logger.debug("SAML2 response: %s", saml2_auth.origxml) logger.debug("SAML2 response: %s", saml2_auth.origxml)
await self._handle_authn_response(request, saml2_auth, relay_state)
async def _handle_authn_response(
self,
request: SynapseRequest,
saml2_auth: saml2.response.AuthnResponse,
relay_state: str,
) -> None:
"""Handle an AuthnResponse, having parsed it from the request params
Assumes that the signature on the response object has been checked. Maps
the user onto an MXID, registering them if necessary, and returns a response
to the browser.
Args:
request: the incoming request from the browser. We'll respond to it with an
HTML page or a redirect
saml2_auth: the parsed AuthnResponse object
relay_state: the RelayState query param, which encodes the URI to rediret
back to
"""
for assertion in saml2_auth.assertions: for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather # kibana limits the length of a log field, whereas this is all rather
# useful, so split it up. # useful, so split it up.

View file

@ -23,7 +23,7 @@ from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException from synapse.handlers.sso import MappingException
from synapse.types import UserID from synapse.types import UserID
from tests.test_utils import FakeResponse from tests.test_utils import FakeResponse, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests. # These are a few constants that are used as config parameters in the tests.
@ -82,16 +82,6 @@ class TestMappingProviderFailures(TestMappingProvider):
} }
def simple_async_mock(return_value=None, raises=None) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
if raises:
raise raises
return return_value
return Mock(side_effect=cb)
async def get_json(url): async def get_json(url):
# Mock get_json calls to handle jwks & oidc discovery endpoints # Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN: if url == WELL_KNOWN:

View file

@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
from mock import Mock
import attr import attr
from synapse.api.errors import RedirectException from synapse.api.errors import RedirectException
from synapse.handlers.sso import MappingException
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests. # Check if we have the dependencies to run the tests.
@ -44,6 +48,8 @@ BASE_URL = "https://synapse/"
@attr.s @attr.s
class FakeAuthnResponse: class FakeAuthnResponse:
ava = attr.ib(type=dict) ava = attr.ib(type=dict)
assertions = attr.ib(type=list, factory=list)
in_response_to = attr.ib(type=Optional[str], default=None)
class TestMappingProvider: class TestMappingProvider:
@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase):
def test_map_saml_response_to_user(self): def test_map_saml_response_to_user(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
# The redirect_url doesn't matter with the default user mapping provider. request = _mock_request()
redirect_url = "" self.get_success(
mxid = self.get_success( self.handler._handle_authn_response(request, saml_response, "redirect_uri")
self.handler._map_saml_response_to_user( )
saml_response, redirect_url, "user-agent", "10.10.10.10"
) # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri"
) )
self.assertEqual(mxid, "@test_user:test")
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self): def test_map_saml_response_to_existing_user(self):
@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
store.register_user(user_id="@test_user:test", password_hash=None) store.register_user(user_id="@test_user:test", password_hash=None)
) )
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# Map a user via SSO. # Map a user via SSO.
saml_response = FakeAuthnResponse( saml_response = FakeAuthnResponse(
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"} {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
) )
redirect_url = "" request = _mock_request()
mxid = self.get_success( self.get_success(
self.handler._map_saml_response_to_user( self.handler._handle_authn_response(request, saml_response, "")
saml_response, redirect_url, "user-agent", "10.10.10.10" )
)
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, ""
) )
self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid. # Subsequent calls should map to the same mxid.
mxid = self.get_success( auth_handler.complete_sso_login.reset_mock()
self.handler._map_saml_response_to_user( self.get_success(
saml_response, redirect_url, "user-agent", "10.10.10.10" self.handler._handle_authn_response(request, saml_response, "")
) )
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, ""
) )
self.assertEqual(mxid, "@test_user:test")
def test_map_saml_response_to_invalid_localpart(self): def test_map_saml_response_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected.""" """If the mapping provider generates an invalid localpart it should be rejected."""
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# mock out the error renderer too
sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None)
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
redirect_url = "" request = _mock_request()
e = self.get_failure( self.get_success(
self.handler._map_saml_response_to_user( self.handler._handle_authn_response(request, saml_response, ""),
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
MappingException,
) )
self.assertEqual(str(e.value), "localpart is invalid: föö") sso_handler.render_error.assert_called_once_with(
request, "mapping_error", "localpart is invalid: föö"
)
auth_handler.complete_sso_login.assert_not_called()
def test_map_saml_response_to_user_retries(self): def test_map_saml_response_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use.""" """The mapping provider can retry generating an MXID if the MXID is already in use."""
# stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None)
# register a user to occupy the first-choice MXID
store = self.hs.get_datastore() store = self.hs.get_datastore()
self.get_success( self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None) store.register_user(user_id="@test_user:test", password_hash=None)
) )
# send the fake SAML response
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
redirect_url = "" request = _mock_request()
mxid = self.get_success( self.get_success(
self.handler._map_saml_response_to_user( self.handler._handle_authn_response(request, saml_response, ""),
saml_response, redirect_url, "user-agent", "10.10.10.10"
)
) )
# test_user is already taken, so test_user1 gets registered instead. # test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test") auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", request, ""
)
auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular SAML username. # Register all of the potential mxids for a particular SAML username.
self.get_success( self.get_success(
@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
# Now attempt to map to a username, this will fail since all potential usernames are taken. # Now attempt to map to a username, this will fail since all potential usernames are taken.
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
e = self.get_failure( self.get_success(
self.handler._map_saml_response_to_user( self.handler._handle_authn_response(request, saml_response, ""),
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
MappingException,
) )
self.assertEqual( sso_handler.render_error.assert_called_once_with(
str(e.value), "Unable to generate a Matrix ID from the SSO response" request,
"mapping_error",
"Unable to generate a Matrix ID from the SSO response",
) )
auth_handler.complete_sso_login.assert_not_called()
@override_config( @override_config(
{ {
@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
} }
) )
def test_map_saml_response_redirect(self): def test_map_saml_response_redirect(self):
"""Test a mapping provider that raises a RedirectException"""
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
redirect_url = "" request = _mock_request()
e = self.get_failure( e = self.get_failure(
self.handler._map_saml_response_to_user( self.handler._handle_authn_response(request, saml_response, ""),
saml_response, redirect_url, "user-agent", "10.10.10.10"
),
RedirectException, RedirectException,
) )
self.assertEqual(e.value.location, b"https://custom-saml-redirect/") self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
return Mock(spec=["getClientIP", "get_user_agent"])

View file

@ -22,6 +22,8 @@ import warnings
from asyncio import Future from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar from typing import Any, Awaitable, Callable, TypeVar
from mock import Mock
import attr import attr
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -87,6 +89,16 @@ def setup_awaitable_errors() -> Callable[[], None]:
return cleanup return cleanup
def simple_async_mock(return_value=None, raises=None) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
if raises:
raise raises
return return_value
return Mock(side_effect=cb)
@attr.s @attr.s
class FakeResponse: class FakeResponse:
"""A fake twisted.web.IResponse object """A fake twisted.web.IResponse object