Merge pull request #684 from matrix-org/markjh/backfill_id_gen

Use a stream id generator for backfilled ids
Mark Haines · 2016-04-01 15:13:14 +01:00 · commit f2b916534b
11 changed files with 71 additions and 61 deletions
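
The heart of this change is replacing hand-rolled bookkeeping of a descending
min_stream_token with a second StreamIdGenerator constructed with step=-1. A
minimal sketch of that pattern (the class name SteppedIdGenerator and the seed
values are invented for illustration; this is not the committed code):

    import threading

    class SteppedIdGenerator(object):
        """Hands out stream ids in a configurable direction.

        step=+1 yields ascending ids for the live event stream;
        step=-1 yields descending (negative) ids for backfilled events.
        """
        def __init__(self, current, step=1):
            assert step != 0
            self._lock = threading.Lock()
            self._current = current
            self._step = step

        def get_next(self):
            with self._lock:
                self._current += self._step
                return self._current

    forward = SteppedIdGenerator(current=0, step=1)
    backfill = SteppedIdGenerator(current=-1, step=-1)
    print(forward.get_next())   # 1
    print(forward.get_next())   # 2
    print(backfill.get_next())  # -2
    print(backfill.get_next())  # -3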

View file

@@ -88,15 +88,6 @@ class DataStore(RoomMemberStore, RoomStore,
         self.hs = hs
         self.database_engine = hs.database_engine
 
-        cur = db_conn.cursor()
-        try:
-            cur.execute("SELECT MIN(stream_ordering) FROM events",)
-            rows = cur.fetchall()
-            self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
-            self.min_stream_token = min(self.min_stream_token, -1)
-        finally:
-            cur.close()
-
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
@@ -105,6 +96,9 @@ class DataStore(RoomMemberStore, RoomStore,
         self._stream_id_gen = StreamIdGenerator(
             db_conn, "events", "stream_ordering"
         )
+        self._backfill_id_gen = StreamIdGenerator(
+            db_conn, "events", "stream_ordering", step=-1
+        )
         self._receipts_id_gen = StreamIdGenerator(
             db_conn, "receipts_linearized", "stream_id"
         )
@@ -129,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore,
             extra_tables=[("deleted_pushers", "stream_id")],
         )
 
-        events_max = self._stream_id_gen.get_max_token()
+        events_max = self._stream_id_gen.get_current_token()
         event_cache_prefill, min_event_val = self._get_cache_dict(
             db_conn, "events",
             entity_column="room_id",
@@ -145,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore,
             "MembershipStreamChangeCache", events_max,
         )
 
-        account_max = self._account_data_id_gen.get_max_token()
+        account_max = self._account_data_id_gen.get_current_token()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max,
         )
@@ -156,7 +150,7 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "presence_stream",
             entity_column="user_id",
             stream_column="stream_id",
-            max_value=self._presence_id_gen.get_max_token(),
+            max_value=self._presence_id_gen.get_current_token(),
         )
         self.presence_stream_cache = StreamChangeCache(
             "PresenceStreamChangeCache", min_presence_val,
@@ -167,7 +161,7 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "push_rules_stream",
             entity_column="user_id",
             stream_column="stream_id",
-            max_value=self._push_rules_stream_id_gen.get_max_token()[0],
+            max_value=self._push_rules_stream_id_gen.get_current_token()[0],
         )
         self.push_rules_stream_cache = StreamChangeCache(

View file

@@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore):
             "add_room_account_data", add_account_data_txn, next_id
         )
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore):
             "add_user_account_data", add_account_data_txn, next_id
         )
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     def _update_max_stream_id(self, txn, next_id):

View file

@@ -24,7 +24,6 @@ from synapse.util.logutils import log_function
 from synapse.api.constants import EventTypes
 from canonicaljson import encode_canonical_json
 
-from contextlib import contextmanager
 from collections import namedtuple
 
 import logging
@@ -66,14 +65,9 @@ class EventsStore(SQLBaseStore):
             return
 
         if backfilled:
-            start = self.min_stream_token - 1
-            self.min_stream_token -= len(events_and_contexts) + 1
-            stream_orderings = range(start, self.min_stream_token, -1)
-
-            @contextmanager
-            def stream_ordering_manager():
-                yield stream_orderings
-            stream_ordering_manager = stream_ordering_manager()
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+                len(events_and_contexts)
+            )
         else:
             stream_ordering_manager = self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
@@ -130,7 +124,7 @@ class EventsStore(SQLBaseStore):
         except _RollbackButIsFineException:
             pass
 
-        max_persisted_id = yield self._stream_id_gen.get_max_token()
+        max_persisted_id = yield self._stream_id_gen.get_current_token()
         defer.returnValue((stream_ordering, max_persisted_id))
 
     @defer.inlineCallbacks
@@ -1117,10 +1111,7 @@ class EventsStore(SQLBaseStore):
     def get_current_backfill_token(self):
         """The current minimum token that backfilled events have reached"""
-
-        # TODO: Fix race with the persit_event txn by using one of the
-        # stream id managers
-        return -self.min_stream_token
+        return -self._backfill_id_gen.get_current_token()
 
     def get_all_new_events(self, last_backfill_id, last_forward_id,
                            current_backfill_id, current_forward_id, limit):
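
In the hunks above, the backfilled branch now draws its ids from
_backfill_id_gen.get_next_mult(n), a context manager yielding n descending
stream orderings. A rough sketch of that usage pattern (the standalone helper
and the event list below are invented for illustration, not the committed code):

    from contextlib import contextmanager

    @contextmanager
    def get_next_mult(current, n, step):
        # Reserve n ids in the stream's direction; a real generator would
        # also mark them finished when the block exits.
        yield [current + step * i for i in range(1, n + 1)]

    events_and_contexts = ["event_a", "event_b", "event_c"]  # hypothetical

    with get_next_mult(current=-5, n=len(events_and_contexts), step=-1) as stream_orderings:
        for event, stream_ordering in zip(events_and_contexts, stream_orderings):
            # ... persist each backfilled event at its negative ordering ...
            print("%s -> %d" % (event, stream_ordering))  # -6, then -7, then -8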

View file

@@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore):
             self._update_presence_txn, stream_orderings, presence_states,
         )
 
-        defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token()))
+        defer.returnValue((
+            stream_orderings[-1], self._presence_id_gen.get_current_token()
+        ))
 
     def _update_presence_txn(self, txn, stream_orderings, presence_states):
         for stream_id, state in zip(stream_orderings, presence_states):
@@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore):
         defer.returnValue([UserPresenceState(**row) for row in rows])
 
     def get_current_presence_token(self):
-        return self._presence_id_gen.get_max_token()
+        return self._presence_id_gen.get_current_token()
 
     def allow_presence_visible(self, observed_localpart, observer_userid):
         return self._simple_insert(

View file

@@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore):
         """Get the position of the push rules stream.
         Returns a pair of a stream id for the push_rules stream and the
         room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_max_token()
+        return self._push_rules_stream_id_gen.get_current_token()
 
     def have_push_rules_changed_for_user(self, user_id, last_id):
         if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):

View file

@@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore):
         defer.returnValue(rows)
 
     def get_pushers_stream_token(self):
-        return self._pushers_id_gen.get_max_token()
+        return self._pushers_id_gen.get_current_token()
 
     def get_all_updated_pushers(self, last_id, current_id, limit):
         def get_all_updated_pushers_txn(txn):

View file

@@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
         super(ReceiptsStore, self).__init__(hs)
 
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
+            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
         )
 
     @cached(num_args=2)
@@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore):
         defer.returnValue(results)
 
     def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_max_token()
+        return self._receipts_id_gen.get_current_token()
 
     def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                       user_id, event_id, data, stream_id):
@@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore):
             room_id, receipt_type, user_id, event_ids, data
         )
 
-        max_persisted_id = self._stream_id_gen.get_max_token()
+        max_persisted_id = self._stream_id_gen.get_current_token()
 
         defer.returnValue((stream_id, max_persisted_id))

View file

@@ -458,4 +458,4 @@ class StateStore(SQLBaseStore):
         )
 
     def get_state_stream_token(self):
-        return self._state_groups_id_gen.get_max_token()
+        return self._state_groups_id_gen.get_current_token()

View file

@@ -539,7 +539,7 @@ class StreamStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_room_events_max_id(self, direction='f'):
-        token = yield self._stream_id_gen.get_max_token()
+        token = yield self._stream_id_gen.get_current_token()
         if direction != 'b':
             defer.returnValue("s%d" % (token,))
         else:

View file

@@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
         Returns:
             A deferred int.
         """
-        return self._account_data_id_gen.get_max_token()
+        return self._account_data_id_gen.get_current_token()
 
     @cached()
     def get_tags_for_user(self, user_id):
@@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore):
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore):
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     def _update_revision_txn(self, txn, user_id, room_id, next_id):

View file

@@ -21,7 +21,7 @@ import threading
 class IdGenerator(object):
     def __init__(self, db_conn, table, column):
         self._lock = threading.Lock()
-        self._next_id = _load_max_id(db_conn, table, column)
+        self._next_id = _load_current_id(db_conn, table, column)
 
     def get_next(self):
         with self._lock:
@@ -29,12 +29,16 @@ class IdGenerator(object):
             return self._next_id
 
-def _load_max_id(db_conn, table, column):
+def _load_current_id(db_conn, table, column, step=1):
     cur = db_conn.cursor()
-    cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    if step == 1:
+        cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    else:
+        cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
     val, = cur.fetchone()
     cur.close()
-    return int(val) if val else 1
+    current_id = int(val) if val else step
+    return (max if step > 0 else min)(current_id, step)
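
The new _load_current_id seeds a generator from its table: MAX(column) when the
stream grows upwards, MIN(column) when it grows downwards, clamped so that an
empty table starts at the step value. A runnable sketch of that behaviour,
assuming an in-memory SQLite table with invented contents:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (stream_ordering INTEGER)")
    conn.executemany("INSERT INTO events VALUES (?)", [(-3,), (1,), (7,)])

    def load_current_id(conn, step):
        # Mirrors the diff's _load_current_id: MAX for step > 0, MIN otherwise.
        agg = "MAX" if step > 0 else "MIN"
        val, = conn.execute(
            "SELECT %s(stream_ordering) FROM events" % agg
        ).fetchone()
        current_id = int(val) if val else step
        return (max if step > 0 else min)(current_id, step)

    print(load_current_id(conn, 1))   # 7: the forward stream resumes above 7
    print(load_current_id(conn, -1))  # -3: the backfill stream resumes below -3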
 class StreamIdGenerator(object):
@@ -45,17 +49,32 @@ class StreamIdGenerator(object):
     all ids less than or equal to it have completed. This handles the fact that
     persistence of events can complete out of order.
 
+    Args:
+        db_conn(connection): A database connection to use to fetch the
+            initial value of the generator from.
+        table(str): A database table to read the initial value of the id
+            generator from.
+        column(str): The column of the database table to read the initial
+            value of the id generator from.
+        extra_tables(list): List of pairs of database tables and columns to
+            use to source the initial value of the generator from. The value
+            with the largest magnitude is used.
+        step(int): which direction the stream ids grow in. +1 to grow
+            upwards, -1 to grow downwards.
+
     Usage:
         with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
 
-    def __init__(self, db_conn, table, column, extra_tables=[]):
+    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+        assert step != 0
         self._lock = threading.Lock()
-        self._current_max = _load_max_id(db_conn, table, column)
+        self._step = step
+        self._current = _load_current_id(db_conn, table, column, step)
         for table, column in extra_tables:
-            self._current_max = max(
-                self._current_max,
-                _load_max_id(db_conn, table, column)
+            self._current = (max if step > 0 else min)(
+                self._current,
+                _load_current_id(db_conn, table, column, step)
             )
 
         self._unfinished_ids = deque()
@@ -66,8 +85,8 @@ class StreamIdGenerator(object):
             # ... persist event ...
         """
         with self._lock:
-            self._current_max += 1
-            next_id = self._current_max
+            self._current += self._step
+            next_id = self._current
 
             self._unfinished_ids.append(next_id)
@@ -88,8 +107,12 @@ class StreamIdGenerator(object):
             # ... persist events ...
         """
         with self._lock:
-            next_ids = range(self._current_max + 1, self._current_max + n + 1)
-            self._current_max += n
+            next_ids = range(
+                self._current + self._step,
+                self._current + self._step * (n + 1),
+                self._step
+            )
+            self._current += n
 
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
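
A worked example of the new range arithmetic (values invented): for a backfill
generator with _current == -5, step == -1 and n == 3, the reserved ids are:

    current, step, n = -5, -1, 3  # hypothetical generator state
    next_ids = range(current + step, current + step * (n + 1), step)
    print(list(next_ids))  # [-6, -7, -8]: three fresh ids, growing downwards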
@@ -105,15 +128,15 @@ class StreamIdGenerator(object):
 
         return manager()
 
-    def get_max_token(self):
+    def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
         with self._lock:
             if self._unfinished_ids:
-                return self._unfinished_ids[0] - 1
+                return self._unfinished_ids[0] - self._step
 
-            return self._current_max
+            return self._current
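
The `- self._step` adjustment keeps the same meaning in both directions. Worked
through for backfill (in-flight ids invented): if -6, -7 and -8 are reserved but
not yet persisted, the oldest unfinished id is -6, so the token sits one step
behind it:

    unfinished_ids, step = [-6, -7, -8], -1  # hypothetical in-flight ids
    print(unfinished_ids[0] - step)  # -5: ids -1 through -5 are fully persisted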
 class ChainedIdGenerator(object):
@@ -125,7 +148,7 @@ class ChainedIdGenerator(object):
     def __init__(self, chained_generator, db_conn, table, column):
         self.chained_generator = chained_generator
         self._lock = threading.Lock()
-        self._current_max = _load_max_id(db_conn, table, column)
+        self._current_max = _load_current_id(db_conn, table, column)
         self._unfinished_ids = deque()
 
     def get_next(self):
@@ -137,7 +160,7 @@ class ChainedIdGenerator(object):
         with self._lock:
             self._current_max += 1
             next_id = self._current_max
-            chained_id = self.chained_generator.get_max_token()
+            chained_id = self.chained_generator.get_current_token()
 
             self._unfinished_ids.append((next_id, chained_id))
 
@@ -151,7 +174,7 @@ class ChainedIdGenerator(object):
 
         return manager()
 
-    def get_max_token(self):
+    def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
@@ -160,4 +183,4 @@ class ChainedIdGenerator(object):
             stream_id, chained_id = self._unfinished_ids[0]
             return (stream_id - 1, chained_id)
-        return (self._current_max, self.chained_generator.get_max_token())
+        return (self._current_max, self.chained_generator.get_current_token())
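
ChainedIdGenerator, used for the push rules stream, pairs each stream id with
the current token of the generator it chains off, which is why the DataStore
hunk near the top of this diff reads get_current_token()[0]. A sketch of the
pair semantics (the values are invented for illustration):

    # The push rules stream position is a (stream_id, room_stream_ordering)
    # pair: the rule change's own id plus the event stream token at that time.
    push_rules_token = (12, 1523)  # e.g. what get_current_token() might return
    stream_id, room_stream_ordering = push_rules_token
    print(stream_id)  # 12: the [0] indexing seen in the DataStore hunk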