Refactor SsoHandler.get_mxid_from_sso (#8900)

* Factor out _call_attribute_mapper and _register_mapped_user

This is mostly an attempt to simplify `get_mxid_from_sso`.

* Move mapping_lock down into SsoHandler.
This commit is contained in:
Richard van der Hoff 2020-12-10 12:43:58 +00:00 committed by GitHub
parent 1821f7cc26
commit c64002e1c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 28 deletions

1
changelog.d/8900.feature Normal file
View file

@ -0,0 +1 @@
Add support for allowing users to pick their own user ID during a single-sign-on login.

View file

@ -34,7 +34,6 @@ from synapse.types import (
map_username_to_mxid_localpart, map_username_to_mxid_localpart,
mxid_localpart_allowed_characters, mxid_localpart_allowed_characters,
) )
from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING: if TYPE_CHECKING:
@ -81,9 +80,6 @@ class SamlHandler(BaseHandler):
# a map from saml session id to Saml2SessionData object # a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
def handle_redirect_request( def handle_redirect_request(
@ -299,15 +295,14 @@ class SamlHandler(BaseHandler):
return None return None
with (await self._mapping_lock.queue(self._auth_provider_id)): return await self._sso_handler.get_mxid_from_sso(
return await self._sso_handler.get_mxid_from_sso( self._auth_provider_id,
self._auth_provider_id, remote_user_id,
remote_user_id, user_agent,
user_agent, ip_address,
ip_address, saml_response_to_remapped_user_attributes,
saml_response_to_remapped_user_attributes, grandfather_existing_users,
grandfather_existing_users, )
)
def _remote_id_from_saml_response( def _remote_id_from_saml_response(
self, self,

View file

@ -22,6 +22,7 @@ from twisted.web.http import Request
from synapse.api.errors import RedirectException from synapse.api.errors import RedirectException
from synapse.http.server import respond_with_html from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters from synapse.types import UserID, contains_invalid_mxid_characters
from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -54,6 +55,9 @@ class SsoHandler:
self._error_template = hs.config.sso_error_template self._error_template = hs.config.sso_error_template
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
# a lock on the mappings
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
def render_error( def render_error(
self, request, error: str, error_description: Optional[str] = None self, request, error: str, error_description: Optional[str] = None
) -> None: ) -> None:
@ -172,24 +176,38 @@ class SsoHandler:
to an additional page. (e.g. to prompt for more information) to an additional page. (e.g. to prompt for more information)
""" """
# first of all, check if we already have a mapping for this user # grab a lock while we try to find a mapping for this user. This seems...
previously_registered_user_id = await self.get_sso_user_by_remote_user_id( # optimistic, especially for implementations that end up redirecting to
auth_provider_id, remote_user_id, # interstitial pages.
) with await self._mapping_lock.queue(auth_provider_id):
if previously_registered_user_id: # first of all, check if we already have a mapping for this user
return previously_registered_user_id previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id,
# Check for grandfathering of users. )
if grandfather_existing_users:
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id: if previously_registered_user_id:
# Future logins should also match this user ID.
await self._store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id return previously_registered_user_id
# Otherwise, generate a new user. # Check for grandfathering of users.
if grandfather_existing_users:
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id:
# Future logins should also match this user ID.
await self._store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
# Otherwise, generate a new user.
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
user_id = await self._register_mapped_user(
attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
)
return user_id
async def _call_attribute_mapper(
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
) -> UserAttributes:
"""Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES): for i in range(self._MAP_USERNAME_RETRIES):
try: try:
attributes = await sso_to_matrix_id_mapper(i) attributes = await sso_to_matrix_id_mapper(i)
@ -227,7 +245,16 @@ class SsoHandler:
raise MappingException( raise MappingException(
"Unable to generate a Matrix ID from the SSO response" "Unable to generate a Matrix ID from the SSO response"
) )
return attributes
async def _register_mapped_user(
self,
attributes: UserAttributes,
auth_provider_id: str,
remote_user_id: str,
user_agent: str,
ip_address: str,
) -> str:
# Since the localpart is provided via a potentially untrusted module, # Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering. # ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(attributes.localpart): if contains_invalid_mxid_characters(attributes.localpart):