diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 1a859352b6..1f6c93b73b 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -37,7 +37,57 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) -class RegistrationWorkerStore(SQLBaseStore): +class RegistrationDeactivationStore(SQLBaseStore): + @cachedInlineCallbacks() + def get_user_deactivated_status(self, user_id): + """Retrieve the value for the `deactivated` property for the provided user. + + Args: + user_id (str): The ID of the user to retrieve the status for. + + Returns: + defer.Deferred(bool): The requested value. + """ + + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="deactivated", + desc="get_user_deactivated_status", + ) + + # Convert the integer into a boolean. + return res == 1 + + @defer.inlineCallbacks + def set_user_deactivated_status(self, user_id, deactivated): + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id (str): The ID of the user to set the status for. + deactivated (bool): The value to set for `deactivated`. + """ + + yield self.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) + + def set_user_deactivated_status_txn(self, txn, user_id, deactivated): + self._simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) + + +class RegistrationWorkerStore(RegistrationDeactivationStore): def __init__(self, db_conn, hs): super(RegistrationWorkerStore, self).__init__(db_conn, hs) @@ -673,27 +723,6 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_id_servers_user_bound", ) - @cachedInlineCallbacks() - def get_user_deactivated_status(self, user_id): - """Retrieve the value for the `deactivated` property for the provided user. - - Args: - user_id (str): The ID of the user to retrieve the status for. - - Returns: - defer.Deferred(bool): The requested value. - """ - - res = yield self._simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="deactivated", - desc="get_user_deactivated_status", - ) - - # Convert the integer into a boolean. - return res == 1 - def get_threepid_validation_session( self, medium, client_secret, address=None, sid=None, validated=True ): @@ -787,13 +816,14 @@ class RegistrationWorkerStore(SQLBaseStore): ) -class RegistrationStore( - RegistrationWorkerStore, background_updates.BackgroundUpdateStore +class RegistrationBackgroundUpdateStore( + RegistrationDeactivationStore, background_updates.BackgroundUpdateStore ): def __init__(self, db_conn, hs): - super(RegistrationStore, self).__init__(db_conn, hs) + super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) self.clock = hs.get_clock() + self.config = hs.config self.register_background_index_update( "access_tokens_device_index", @@ -809,8 +839,6 @@ class RegistrationStore( columns=["creation_ts"], ) - self._account_validity = hs.config.account_validity - # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. @@ -824,17 +852,6 @@ class RegistrationStore( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) - # Create a background job for culling expired 3PID validity tokens - def start_cull(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "cull_expired_threepid_validation_tokens", - self.cull_expired_threepid_validation_tokens, - ) - - hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - @defer.inlineCallbacks def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 @@ -896,6 +913,54 @@ class RegistrationStore( return nb_processed + @defer.inlineCallbacks + def _bg_user_threepids_grandfather(self, progress, batch_size): + """We now track which identity servers a user binds their 3PID to, so + we need to handle the case of existing bindings where we didn't track + this. + + We do this by grandfathering in existing user threepids assuming that + they used one of the server configured trusted identity servers. + """ + id_servers = set(self.config.trusted_third_party_id_servers) + + def _bg_user_threepids_grandfather_txn(txn): + sql = """ + INSERT INTO user_threepid_id_server + (user_id, medium, address, id_server) + SELECT user_id, medium, address, ? + FROM user_threepids + """ + + txn.executemany(sql, [(id_server,) for id_server in id_servers]) + + if id_servers: + yield self.runInteraction( + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn + ) + + yield self._end_background_update("user_threepids_grandfather") + + return 1 + + +class RegistrationStore(RegistrationWorkerStore, RegistrationBackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(RegistrationStore, self).__init__(db_conn, hs) + + self._account_validity = hs.config.account_validity + + # Create a background job for culling expired 3PID validity tokens + def start_cull(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "cull_expired_threepid_validation_tokens", + self.cull_expired_threepid_validation_tokens, + ) + + hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): """Adds an access token for the given user. @@ -1244,36 +1309,6 @@ class RegistrationStore( desc="get_users_pending_deactivation", ) - @defer.inlineCallbacks - def _bg_user_threepids_grandfather(self, progress, batch_size): - """We now track which identity servers a user binds their 3PID to, so - we need to handle the case of existing bindings where we didn't track - this. - - We do this by grandfathering in existing user threepids assuming that - they used one of the server configured trusted identity servers. - """ - id_servers = set(self.config.trusted_third_party_id_servers) - - def _bg_user_threepids_grandfather_txn(txn): - sql = """ - INSERT INTO user_threepid_id_server - (user_id, medium, address, id_server) - SELECT user_id, medium, address, ? - FROM user_threepids - """ - - txn.executemany(sql, [(id_server,) for id_server in id_servers]) - - if id_servers: - yield self.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn - ) - - yield self._end_background_update("user_threepids_grandfather") - - return 1 - def validate_threepid_session(self, session_id, client_secret, token, current_ts): """Attempt to validate a threepid session using a token @@ -1464,30 +1499,3 @@ class RegistrationStore( cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), ) - - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self._simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) - - @defer.inlineCallbacks - def set_user_deactivated_status(self, user_id, deactivated): - """Set the `deactivated` property for the provided user to the provided value. - - Args: - user_id (str): The ID of the user to set the status for. - deactivated (bool): The value to set for `deactivated`. - """ - - yield self.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, - )