diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 2eb795192f..22a6abd7d2 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -48,6 +48,7 @@ from synapse.storage.data_stores.main.media_repository import ( ) from synapse.storage.data_stores.main.registration import ( RegistrationBackgroundUpdateStore, + find_max_generated_user_id_localpart, ) from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore @@ -622,8 +623,10 @@ class Porter(object): ) ) - # Step 5. Do final post-processing + # Step 5. Set up sequences + self.progress.set_state("Setting up sequence generators") await self._setup_state_group_id_seq() + await self._setup_user_id_seq() self.progress.done() except Exception as e: @@ -793,6 +796,13 @@ class Porter(object): return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) + def _setup_user_id_seq(self): + def r(txn): + next_id = find_max_generated_user_id_localpart(txn) + 1 + txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) + + return self.postgres_store.db.runInteraction("setup_user_id_seq", r) + ############################################## # The following is simply UI stuff diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 78c3772ac1..501f0fe795 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -28,7 +28,6 @@ from synapse.replication.http.register import ( ) from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester -from synapse.util.async_helpers import Linearizer from ._base import BaseHandler @@ -50,14 +49,7 @@ class RegistrationHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() self.identity_handler = self.hs.get_handlers().identity_handler self.ratelimiter = hs.get_registration_ratelimiter() - - self._next_generated_user_id = None - self.macaroon_gen = hs.get_macaroon_generator() - - self._generate_user_id_linearizer = Linearizer( - name="_generate_user_id_linearizer" - ) self._server_notices_mxid = hs.config.server_notices_mxid if hs.config.worker_app: @@ -219,7 +211,7 @@ class RegistrationHandler(BaseHandler): if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") - localpart = await self._generate_user_id() + localpart = await self.store.generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) @@ -510,18 +502,6 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - async def _generate_user_id(self): - if self._next_generated_user_id is None: - with await self._generate_user_id_linearizer.queue(()): - if self._next_generated_user_id is None: - self._next_generated_user_id = ( - await self.store.find_next_generated_user_id_localpart() - ) - - id = self._next_generated_user_id - self._next_generated_user_id += 1 - return str(id) - def check_registration_ratelimit(self, address): """A simple helper method to check whether the registration rate limit has been hit for a given IP address diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 587d4b91c1..27d2c5028c 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -27,6 +27,8 @@ from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidati from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -42,6 +44,10 @@ class RegistrationWorkerStore(SQLBaseStore): self.config = hs.config self.clock = hs.get_clock() + self._user_id_seq = build_sequence_generator( + database.engine, find_max_generated_user_id_localpart, "user_id_seq", + ) + @cached() def get_user_by_id(self, user_id): return self.db.simple_select_one( @@ -481,39 +487,17 @@ class RegistrationWorkerStore(SQLBaseStore): ret = yield self.db.runInteraction("count_real_users", _count_users) return ret - @defer.inlineCallbacks - def find_next_generated_user_id_localpart(self): + async def generate_user_id(self) -> str: + """Generate a suitable localpart for a guest user + + Returns: a (hopefully) free localpart """ - Gets the localpart of the next generated user ID. - - Generated user IDs are integers, so we find the largest integer user ID - already taken and return that plus one. - """ - - def _find_next_generated_user_id(txn): - # We bound between '@0' and '@a' to avoid pulling the entire table - # out. - txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") - - regex = re.compile(r"^@(\d+):") - - max_found = 0 - - for (user_id,) in txn: - match = regex.search(user_id) - if match: - max_found = max(int(match.group(1)), max_found) - - return max_found + 1 - - return ( - ( - yield self.db.runInteraction( - "find_next_generated_user_id", _find_next_generated_user_id - ) - ) + next_id = await self.db.runInteraction( + "generate_user_id", self._user_id_seq.get_next_id_txn ) + return str(next_id) + async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: """Returns user id from threepid @@ -1573,3 +1557,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): keyvalues={"user_id": user_id}, values={"expiration_ts_ms": expiration_ts, "email_sent": False}, ) + + +def find_max_generated_user_id_localpart(cur: Cursor) -> int: + """ + Gets the localpart of the max current generated user ID. + + Generated user IDs are integers, so we find the largest integer user ID + already taken and return that. + """ + + # We bound between '@0' and '@a' to avoid pulling the entire table + # out. + cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") + + regex = re.compile(r"^@(\d+):") + + max_found = 0 + + for (user_id,) in cur: + match = regex.search(user_id) + if match: + max_found = max(int(match.group(1)), max_found) + return max_found diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py new file mode 100644 index 0000000000..2011f6bceb --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py @@ -0,0 +1,34 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adds a postgres SEQUENCE for generating guest user IDs. +""" + +from synapse.storage.data_stores.main.registration import ( + find_max_generated_user_id_localpart, +) +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + next_id = find_max_generated_user_id_localpart(cur) + 1 + cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,)) + + +def run_upgrade(*args, **kwargs): + pass