From e97d1cf0014668b9d4883d4175b783088444b24b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 9 Jan 2020 17:21:30 +0000 Subject: [PATCH] Modify check_database to take a connection rather than a cursor We might not need the cursor at all. --- scripts/synapse_port_db | 25 +++++++------------------ synapse/storage/data_stores/__init__.py | 2 +- synapse/storage/engines/postgres.py | 17 +++++++++-------- synapse/storage/engines/sqlite.py | 2 +- 4 files changed, 18 insertions(+), 28 deletions(-) diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index cb77314f1e..a3dafaffc9 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -447,15 +447,6 @@ class Porter(object): else: return - def setup_db(self, db_config: DatabaseConnectionConfig, engine): - db_conn = make_conn(db_config, engine) - prepare_database(db_conn, engine, config=None) - - db_conn.commit() - - return db_conn - - @defer.inlineCallbacks def build_db_store(self, db_config: DatabaseConnectionConfig): """Builds and returns a database store using the provided configuration. @@ -468,16 +459,14 @@ class Porter(object): self.progress.set_state("Preparing %s" % db_config.config["name"]) engine = create_engine(db_config.config) - conn = self.setup_db(db_config, engine) hs = MockHomeserver(self.hs_config) - store = Store(Database(hs, db_config, engine), conn, hs) - - yield store.db.runInteraction( - "%s_engine.check_database" % db_config.config["name"], - engine.check_database, - ) + with make_conn(db_config, engine) as db_conn: + engine.check_database(db_conn) + prepare_database(db_conn, engine, config=None) + store = Store(Database(hs, db_config, engine), db_conn, hs) + db_conn.commit() return store @@ -502,7 +491,7 @@ class Porter(object): @defer.inlineCallbacks def run(self): try: - self.sqlite_store = yield self.build_db_store( + self.sqlite_store = self.build_db_store( DatabaseConnectionConfig("master-sqlite", self.sqlite_config) ) @@ -518,7 +507,7 @@ class Porter(object): ) defer.returnValue(None) - self.postgres_store = yield self.build_db_store( + self.postgres_store = self.build_db_store( self.hs_config.get_single_database() ) diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 092e803799..e1d03429ca 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -47,7 +47,7 @@ class DataStores(object): with make_conn(database_config, engine) as db_conn: logger.info("Preparing database %r...", db_name) - engine.check_database(db_conn.cursor()) + engine.check_database(db_conn) prepare_database( db_conn, engine, hs.config, data_stores=database_config.data_stores, ) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index b7c4eda338..ba19785fd7 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -32,14 +32,15 @@ class PostgresEngine(object): self.synchronous_commit = database_config.get("synchronous_commit", True) self._version = None # unknown as yet - def check_database(self, txn): - txn.execute("SHOW SERVER_ENCODING") - rows = txn.fetchall() - if rows and rows[0][0] != "UTF8": - raise IncorrectDatabaseSetup( - "Database has incorrect encoding: '%s' instead of 'UTF8'\n" - "See docs/postgres.rst for more information." % (rows[0][0],) - ) + def check_database(self, db_conn): + with db_conn.cursor() as txn: + txn.execute("SHOW SERVER_ENCODING") + rows = txn.fetchall() + if rows and rows[0][0] != "UTF8": + raise IncorrectDatabaseSetup( + "Database has incorrect encoding: '%s' instead of 'UTF8'\n" + "See docs/postgres.rst for more information." % (rows[0][0],) + ) def convert_param_style(self, sql): return sql.replace("?", "%s") diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index df039a072d..3b3c13360b 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -53,7 +53,7 @@ class Sqlite3Engine(object): """ return False - def check_database(self, txn): + def check_database(self, db_conn): pass def convert_param_style(self, sql):