Fix a race when registering via email 3pid

This commit is contained in:
Mathieu Velten 2024-01-16 21:28:50 +01:00
parent cf5adc80e1
commit 027b4af5ac
3 changed files with 113 additions and 2 deletions

1
changelog.d/16827.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a race when registering via email 3pid where 2 different user ids would be created.

View file

@ -75,6 +75,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
USER_REGISTRATION_LOCK_NAME = "user_registration"
class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/email/requestToken$")
@ -417,6 +419,7 @@ class RegisterRestServlet(RestServlet):
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self._worker_lock_handler = hs.get_worker_locks_handler()
self.clock = hs.get_clock()
self.password_auth_provider = hs.get_password_auth_provider()
self._registration_enabled = self.hs.config.registration.enable_registration
@ -508,6 +511,23 @@ class RegisterRestServlet(RestServlet):
"An access token should not be provided on requests to /register (except if type is m.login.application_service)",
)
# Take a global lock when doing user registration to avoid races,
# for example when doing 3pid email binding.
async with self._worker_lock_handler.acquire_lock(
USER_REGISTRATION_LOCK_NAME, ""
):
return await self._do_user_register(
desired_username, client_addr, body, should_issue_refresh_token, request
)
async def _do_user_register(
self,
desired_username: Optional[str],
address: str,
body: JsonDict,
should_issue_refresh_token: bool,
request: SynapseRequest,
) -> Tuple[int, JsonDict]:
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
@ -702,7 +722,7 @@ class RegisterRestServlet(RestServlet):
guest_access_token=guest_access_token,
threepid=threepid,
default_display_name=display_name,
address=client_addr,
address=address,
user_agent_ips=entries,
)
# Necessary due to auth checks prior to the threepid being

View file

@ -21,7 +21,8 @@
#
import datetime
import os
from typing import Any, Dict, List, Tuple
import re
from typing import Any, Dict, List, Optional, Tuple
import pkg_resources
@ -42,6 +43,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.server import ThreadedMemoryReactorClock
from tests.unittest import override_config
@ -1248,3 +1250,91 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
f"{self.url}?token={token}",
)
self.assertEqual(channel.code, 200, msg=channel.result)
class EmailRegisterRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]
def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
hs = super().make_homeserver(reactor, clock)
async def send_email(
email_address: str,
subject: str,
app_name: str,
html: str,
text: str,
additional_headers: Optional[Dict[str, str]] = None,
) -> None:
self.email_attempts.append(text)
self.email_attempts: List[str] = []
hs.get_send_email_handler().send_email = send_email # type: ignore[method-assign]
return hs
@unittest.override_config(
{
"public_baseurl": "https://test_server",
"registrations_require_3pid": ["email"],
"disable_msisdn_registration": True,
"email": {
"smtp_host": "mail_server",
"smtp_port": 2525,
"notif_from": "sender@host",
},
}
)
def test_email_3pid_registration_race(self) -> None:
channel = self.make_request("POST", b"register", {"password": "password"})
session = channel.json_body["session"]
# request a token to be sent by email for validation
channel = self.make_request(
"POST",
b"register/email/requestToken",
{
"client_secret": "client_secret",
"email": "email@email",
"send_attempt": 1,
},
)
sid = channel.json_body["sid"]
email_text = self.email_attempts[0]
match = re.search("https://test_server(.*)", email_text)
assert match is not None
validation_url = match.group(1)
# "Click" the link in the email to validate the adress
self.make_request("GET", validation_url.encode("utf-8"))
# launch 2 simultaneous register request, only one account
# should be created after that.
register_content = {
"auth": {
"session": session,
"threepid_creds": {
"client_secret": "client_secret",
"sid": sid,
},
"type": "m.login.email.identity",
},
"password": "password",
}
register1_channel = self.make_request(
"POST", b"register", register_content, await_result=False
)
register2_channel = self.make_request(
"POST", b"register", register_content, await_result=False
)
while (
not register1_channel.is_finished() or not register2_channel.is_finished()
):
self.pump()
self.assertEqual(
register1_channel.json_body["user_id"],
register2_channel.json_body["user_id"],
)