diff --git a/changelog.d/16827.bugfix b/changelog.d/16827.bugfix new file mode 100644 index 0000000000..e0ed9e262a --- /dev/null +++ b/changelog.d/16827.bugfix @@ -0,0 +1 @@ +Fix a race when registering via email 3pid where 2 different user ids would be created. diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 634ebed2be..01541a1a9b 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -75,6 +75,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +USER_REGISTRATION_LOCK_NAME = "user_registration" + class EmailRegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/register/email/requestToken$") @@ -417,6 +419,7 @@ class RegisterRestServlet(RestServlet): self.macaroon_gen = hs.get_macaroon_generator() self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() + self._worker_lock_handler = hs.get_worker_locks_handler() self.clock = hs.get_clock() self.password_auth_provider = hs.get_password_auth_provider() self._registration_enabled = self.hs.config.registration.enable_registration @@ -508,6 +511,23 @@ class RegisterRestServlet(RestServlet): "An access token should not be provided on requests to /register (except if type is m.login.application_service)", ) + # Take a global lock when doing user registration to avoid races, + # for example when doing 3pid email binding. + async with self._worker_lock_handler.acquire_lock( + USER_REGISTRATION_LOCK_NAME, "" + ): + return await self._do_user_register( + desired_username, client_addr, body, should_issue_refresh_token, request + ) + + async def _do_user_register( + self, + desired_username: Optional[str], + address: str, + body: JsonDict, + should_issue_refresh_token: bool, + request: SynapseRequest, + ) -> Tuple[int, JsonDict]: # == Normal User Registration == (everyone else) if not self._registration_enabled: raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) @@ -702,7 +722,7 @@ class RegisterRestServlet(RestServlet): guest_access_token=guest_access_token, threepid=threepid, default_display_name=display_name, - address=client_addr, + address=address, user_agent_ips=entries, ) # Necessary due to auth checks prior to the threepid being diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 859051cdda..1a2c0594fb 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -21,7 +21,8 @@ # import datetime import os -from typing import Any, Dict, List, Tuple +import re +from typing import Any, Dict, List, Optional, Tuple import pkg_resources @@ -42,6 +43,7 @@ from synapse.types import JsonDict from synapse.util import Clock from tests import unittest +from tests.server import ThreadedMemoryReactorClock from tests.unittest import override_config @@ -1248,3 +1250,91 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): f"{self.url}?token={token}", ) self.assertEqual(channel.code, 200, msg=channel.result) + + +class EmailRegisterRestServletTestCase(unittest.HomeserverTestCase): + servlets = [register.register_servlets] + + def make_homeserver( + self, reactor: ThreadedMemoryReactorClock, clock: Clock + ) -> HomeServer: + hs = super().make_homeserver(reactor, clock) + + async def send_email( + email_address: str, + subject: str, + app_name: str, + html: str, + text: str, + additional_headers: Optional[Dict[str, str]] = None, + ) -> None: + self.email_attempts.append(text) + + self.email_attempts: List[str] = [] + hs.get_send_email_handler().send_email = send_email # type: ignore[method-assign] + return hs + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "registrations_require_3pid": ["email"], + "disable_msisdn_registration": True, + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_email_3pid_registration_race(self) -> None: + channel = self.make_request("POST", b"register", {"password": "password"}) + session = channel.json_body["session"] + + # request a token to be sent by email for validation + channel = self.make_request( + "POST", + b"register/email/requestToken", + { + "client_secret": "client_secret", + "email": "email@email", + "send_attempt": 1, + }, + ) + sid = channel.json_body["sid"] + + email_text = self.email_attempts[0] + match = re.search("https://test_server(.*)", email_text) + assert match is not None + validation_url = match.group(1) + + # "Click" the link in the email to validate the adress + self.make_request("GET", validation_url.encode("utf-8")) + + # launch 2 simultaneous register request, only one account + # should be created after that. + register_content = { + "auth": { + "session": session, + "threepid_creds": { + "client_secret": "client_secret", + "sid": sid, + }, + "type": "m.login.email.identity", + }, + "password": "password", + } + register1_channel = self.make_request( + "POST", b"register", register_content, await_result=False + ) + register2_channel = self.make_request( + "POST", b"register", register_content, await_result=False + ) + while ( + not register1_channel.is_finished() or not register2_channel.is_finished() + ): + self.pump() + + self.assertEqual( + register1_channel.json_body["user_id"], + register2_channel.json_body["user_id"], + )