mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-20 19:10:45 +03:00
Fix a race when registering via email 3pid
This commit is contained in:
parent
cf5adc80e1
commit
027b4af5ac
3 changed files with 113 additions and 2 deletions
1
changelog.d/16827.bugfix
Normal file
1
changelog.d/16827.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix a race when registering via email 3pid where 2 different user ids would be created.
|
|
@ -75,6 +75,8 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
USER_REGISTRATION_LOCK_NAME = "user_registration"
|
||||
|
||||
|
||||
class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/register/email/requestToken$")
|
||||
|
@ -417,6 +419,7 @@ class RegisterRestServlet(RestServlet):
|
|||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self.ratelimiter = hs.get_registration_ratelimiter()
|
||||
self.password_policy_handler = hs.get_password_policy_handler()
|
||||
self._worker_lock_handler = hs.get_worker_locks_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.password_auth_provider = hs.get_password_auth_provider()
|
||||
self._registration_enabled = self.hs.config.registration.enable_registration
|
||||
|
@ -508,6 +511,23 @@ class RegisterRestServlet(RestServlet):
|
|||
"An access token should not be provided on requests to /register (except if type is m.login.application_service)",
|
||||
)
|
||||
|
||||
# Take a global lock when doing user registration to avoid races,
|
||||
# for example when doing 3pid email binding.
|
||||
async with self._worker_lock_handler.acquire_lock(
|
||||
USER_REGISTRATION_LOCK_NAME, ""
|
||||
):
|
||||
return await self._do_user_register(
|
||||
desired_username, client_addr, body, should_issue_refresh_token, request
|
||||
)
|
||||
|
||||
async def _do_user_register(
|
||||
self,
|
||||
desired_username: Optional[str],
|
||||
address: str,
|
||||
body: JsonDict,
|
||||
should_issue_refresh_token: bool,
|
||||
request: SynapseRequest,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
# == Normal User Registration == (everyone else)
|
||||
if not self._registration_enabled:
|
||||
raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
|
||||
|
@ -702,7 +722,7 @@ class RegisterRestServlet(RestServlet):
|
|||
guest_access_token=guest_access_token,
|
||||
threepid=threepid,
|
||||
default_display_name=display_name,
|
||||
address=client_addr,
|
||||
address=address,
|
||||
user_agent_ips=entries,
|
||||
)
|
||||
# Necessary due to auth checks prior to the threepid being
|
||||
|
|
|
@ -21,7 +21,8 @@
|
|||
#
|
||||
import datetime
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import pkg_resources
|
||||
|
||||
|
@ -42,6 +43,7 @@ from synapse.types import JsonDict
|
|||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import ThreadedMemoryReactorClock
|
||||
from tests.unittest import override_config
|
||||
|
||||
|
||||
|
@ -1248,3 +1250,91 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
|
|||
f"{self.url}?token={token}",
|
||||
)
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
|
||||
|
||||
class EmailRegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [register.register_servlets]
|
||||
|
||||
def make_homeserver(
|
||||
self, reactor: ThreadedMemoryReactorClock, clock: Clock
|
||||
) -> HomeServer:
|
||||
hs = super().make_homeserver(reactor, clock)
|
||||
|
||||
async def send_email(
|
||||
email_address: str,
|
||||
subject: str,
|
||||
app_name: str,
|
||||
html: str,
|
||||
text: str,
|
||||
additional_headers: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
self.email_attempts.append(text)
|
||||
|
||||
self.email_attempts: List[str] = []
|
||||
hs.get_send_email_handler().send_email = send_email # type: ignore[method-assign]
|
||||
return hs
|
||||
|
||||
@unittest.override_config(
|
||||
{
|
||||
"public_baseurl": "https://test_server",
|
||||
"registrations_require_3pid": ["email"],
|
||||
"disable_msisdn_registration": True,
|
||||
"email": {
|
||||
"smtp_host": "mail_server",
|
||||
"smtp_port": 2525,
|
||||
"notif_from": "sender@host",
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_email_3pid_registration_race(self) -> None:
|
||||
channel = self.make_request("POST", b"register", {"password": "password"})
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# request a token to be sent by email for validation
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
b"register/email/requestToken",
|
||||
{
|
||||
"client_secret": "client_secret",
|
||||
"email": "email@email",
|
||||
"send_attempt": 1,
|
||||
},
|
||||
)
|
||||
sid = channel.json_body["sid"]
|
||||
|
||||
email_text = self.email_attempts[0]
|
||||
match = re.search("https://test_server(.*)", email_text)
|
||||
assert match is not None
|
||||
validation_url = match.group(1)
|
||||
|
||||
# "Click" the link in the email to validate the adress
|
||||
self.make_request("GET", validation_url.encode("utf-8"))
|
||||
|
||||
# launch 2 simultaneous register request, only one account
|
||||
# should be created after that.
|
||||
register_content = {
|
||||
"auth": {
|
||||
"session": session,
|
||||
"threepid_creds": {
|
||||
"client_secret": "client_secret",
|
||||
"sid": sid,
|
||||
},
|
||||
"type": "m.login.email.identity",
|
||||
},
|
||||
"password": "password",
|
||||
}
|
||||
register1_channel = self.make_request(
|
||||
"POST", b"register", register_content, await_result=False
|
||||
)
|
||||
register2_channel = self.make_request(
|
||||
"POST", b"register", register_content, await_result=False
|
||||
)
|
||||
while (
|
||||
not register1_channel.is_finished() or not register2_channel.is_finished()
|
||||
):
|
||||
self.pump()
|
||||
|
||||
self.assertEqual(
|
||||
register1_channel.json_body["user_id"],
|
||||
register2_channel.json_body["user_id"],
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue