Remove race condition

This commit is contained in:
Erik Johnston 2015-05-14 16:54:35 +01:00
parent ef3d8754f5
commit 1d566edb81
4 changed files with 161 additions and 100 deletions

View file

@@ -26,6 +26,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
import contextlib
import functools import functools
import sys import sys
import time import time
@@ -299,7 +301,7 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._event_fetch_lock = threading.Lock() self._event_fetch_lock = threading.Condition()
self._event_fetch_list = [] self._event_fetch_list = []
self._event_fetch_ongoing = 0 self._event_fetch_ongoing = 0
@@ -342,22 +344,8 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000) self._clock.looping_call(loop, 10000)
@defer.inlineCallbacks @contextlib.contextmanager
def runInteraction(self, desc, func, *args, **kwargs): def _new_transaction(self, conn, desc, after_callbacks):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
after_callbacks = []
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
start = time.time() * 1000 start = time.time() * 1000
txn_id = self._TXN_ID txn_id = self._TXN_ID
@@ -367,8 +355,8 @@ class SQLBaseStore(object):
name = "%s-%x" % (desc, txn_id, ) name = "%s-%x" % (desc, txn_id, )
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name) transaction_logger.debug("[TXN START] {%s}", name)
try: try:
i = 0 i = 0
N = 5 N = 5
@@ -378,7 +366,6 @@ class SQLBaseStore(object):
txn = LoggingTransaction( txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks txn, name, self.database_engine, after_callbacks
) )
return func(txn, *args, **kwargs)
except self.database_engine.module.OperationalError as e: except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid # This can happen if the database disappears mid
# transaction. # transaction.
@@ -396,6 +383,7 @@ class SQLBaseStore(object):
name, e1, name, e1,
) )
continue continue
raise
except self.database_engine.module.DatabaseError as e: except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e): if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N) logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
@@ -410,6 +398,17 @@ class SQLBaseStore(object):
) )
continue continue
raise raise
try:
yield txn
conn.commit()
return
except:
try:
conn.rollback()
except:
pass
raise
except Exception as e: except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e) logger.debug("[TXN FAIL] {%s} %s", name, e)
raise raise
@@ -423,6 +422,27 @@ class SQLBaseStore(object):
self._txn_perf_counters.update(desc, start, end) self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc) sql_txn_timer.inc_by(duration, desc)
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
after_callbacks = []
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
with self._new_transaction(conn, desc, after_callbacks) as txn:
return func(txn, *args, **kwargs)
result = yield preserve_context_over_fn( result = yield preserve_context_over_fn(
self._db_pool.runWithConnection, self._db_pool.runWithConnection,
inner_func, *args, **kwargs inner_func, *args, **kwargs
@@ -432,6 +452,32 @@ class SQLBaseStore(object):
after_callback(*after_args) after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
return func(conn, *args, **kwargs)
result = yield preserve_context_over_fn(
self._db_pool.runWithConnection,
inner_func, *args, **kwargs
)
defer.returnValue(result)
def cursor_to_dict(self, cursor): def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts. """Converts a SQL cursor into an list of dicts.

View file

@@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup
class PostgresEngine(object): class PostgresEngine(object):
single_threaded = False
def __init__(self, database_module): def __init__(self, database_module):
self.module = database_module self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE) self.module.extensions.register_type(self.module.extensions.UNICODE)

View file

@@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
class Sqlite3Engine(object): class Sqlite3Engine(object):
single_threaded = True
def __init__(self, database_module): def __init__(self, database_module):
self.module = database_module self.module = database_module

View file

@@ -504,22 +504,25 @@ class EventsStore(SQLBaseStore):
if not events: if not events:
defer.returnValue({}) defer.returnValue({})
def do_fetch(txn): def do_fetch(conn):
event_list = [] event_list = []
while True: while True:
try: try:
with self._event_fetch_lock: with self._event_fetch_lock:
event_list = self._event_fetch_list i = 0
self._event_fetch_list = [] while not self._event_fetch_list:
if not event_list:
self._event_fetch_ongoing -= 1 self._event_fetch_ongoing -= 1
return return
event_list = self._event_fetch_list
self._event_fetch_list = []
event_id_lists = zip(*event_list)[0] event_id_lists = zip(*event_list)[0]
event_ids = [ event_ids = [
item for sublist in event_id_lists for item in sublist item for sublist in event_id_lists for item in sublist
] ]
with self._new_transaction(conn, "do_fetch", []) as txn:
rows = self._fetch_event_rows(txn, event_ids) rows = self._fetch_event_rows(txn, event_ids)
row_dict = { row_dict = {
@@ -528,22 +531,44 @@ class EventsStore(SQLBaseStore):
} }
for ids, d in event_list: for ids, d in event_list:
reactor.callFromThread( def fire():
d.callback, if not d.called:
d.callback(
[ [
row_dict[i] for i in ids row_dict[i]
for i in ids
if i in row_dict if i in row_dict
] ]
) )
reactor.callFromThread(fire)
except Exception as e: except Exception as e:
logger.exception("do_fetch")
for _, d in event_list: for _, d in event_list:
try: if not d.called:
reactor.callFromThread(d.errback, e) reactor.callFromThread(d.errback, e)
except:
pass
def cb(rows): with self._event_fetch_lock:
return defer.gatherResults([ self._event_fetch_ongoing -= 1
return
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, events_d)
)
self._event_fetch_lock.notify_all()
# if self._event_fetch_ongoing < 5:
self._event_fetch_ongoing += 1
self.runWithConnection(
do_fetch
)
rows = yield events_d
res = yield defer.gatherResults(
[
self._get_event_from_row( self._get_event_from_row(
None, None,
row["internal_metadata"], row["json"], row["redacts"], row["internal_metadata"], row["json"], row["redacts"],
@@ -552,24 +577,10 @@ class EventsStore(SQLBaseStore):
rejected_reason=row["rejects"], rejected_reason=row["rejects"],
) )
for row in rows for row in rows
]) ],
consumeErrors=True
d = defer.Deferred()
d.addCallback(cb)
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, d)
) )
if self._event_fetch_ongoing < 3:
self._event_fetch_ongoing += 1
self.runInteraction(
"do_fetch",
do_fetch
)
res = yield d
defer.returnValue({ defer.returnValue({
e.event_id: e e.event_id: e
for e in res if e for e in res if e