prepare_database() on db_conn, not plain name, so we can pass in the connection from outside

This commit is contained in:
Paul "LeoNerd" Evans 2014-09-10 16:23:58 +01:00
parent 2faffc52ee
commit 55397f6347
2 changed files with 36 additions and 33 deletions

View file

@ -39,6 +39,7 @@ import logging
import os import os
import re import re
import sys import sys
import sqlite3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -208,7 +209,14 @@ def setup():
redirect_root_to_web_client=True, redirect_root_to_web_client=True,
) )
prepare_database(hs.get_db_name()) db_name = hs.get_db_name()
logging.info("Preparing database: %s...", db_name)
with sqlite3.connect(db_name) as db_conn:
prepare_database(db_conn)
logging.info("Database prepared in %s.", db_name)
hs.get_db_pool() hs.get_db_pool()

View file

@ -43,7 +43,6 @@ from .keys import KeyStore
import json import json
import logging import logging
import os import os
import sqlite3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -370,44 +369,40 @@ def read_schema(schema):
return schema_file.read() return schema_file.read()
def prepare_database(db_name): def prepare_database(db_conn):
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
don't have to worry about overwriting existing content. don't have to worry about overwriting existing content.
""" """
logging.info("Preparing database: %s...", db_name) c = db_conn.cursor()
c.execute("PRAGMA user_version")
row = c.fetchone()
with sqlite3.connect(db_name) as db_conn: if row and row[0]:
c = db_conn.cursor() user_version = row[0]
c.execute("PRAGMA user_version")
row = c.fetchone()
if row and row[0]: if user_version > SCHEMA_VERSION:
user_version = row[0] raise ValueError("Cannot use this database as it is too " +
"new for the server to understand"
if user_version > SCHEMA_VERSION: )
raise ValueError("Cannot use this database as it is too " + elif user_version < SCHEMA_VERSION:
"new for the server to understand" logging.info("Upgrading database from version %d",
) user_version
elif user_version < SCHEMA_VERSION: )
logging.info("Upgrading database from version %d",
user_version
)
# Run every version since after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1):
sql_script = read_schema("delta/v%d" % (v))
c.executescript(sql_script)
db_conn.commit()
else:
for sql_loc in SCHEMAS:
sql_script = read_schema(sql_loc)
# Run every version since after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1):
sql_script = read_schema("delta/v%d" % (v))
c.executescript(sql_script) c.executescript(sql_script)
db_conn.commit() db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
c.close() else:
for sql_loc in SCHEMAS:
sql_script = read_schema(sql_loc)
c.executescript(sql_script)
db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
c.close()
logging.info("Database prepared in %s.", db_name)