Remove concept of a non-limited stream. (#7011)

This commit is contained in:
Erik Johnston 2020-03-20 14:40:47 +00:00 committed by GitHub
parent caec7d4fa0
commit fdb1344716
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 72 additions and 68 deletions

1
changelog.d/7011.misc Normal file
View file

@ -0,0 +1 @@
Remove concept of a non-limited stream.

View file

@ -747,7 +747,7 @@ class PresenceHandler(object):
return False
async def get_all_presence_updates(self, last_id, current_id):
async def get_all_presence_updates(self, last_id, current_id, limit):
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@ -762,7 +762,7 @@ class PresenceHandler(object):
"""
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
rows = await self.store.get_all_presence_updates(last_id, current_id)
rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
return rows
def notify_new_event(self):

View file

@ -15,6 +15,7 @@
import logging
from collections import namedtuple
from typing import List
from twisted.internet import defer
@ -257,7 +258,13 @@ class TypingHandler(object):
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
async def get_all_typing_updates(self, last_id, current_id):
async def get_all_typing_updates(
self, last_id: int, current_id: int, limit: int
) -> List[dict]:
"""Get up to `limit` typing updates between the given tokens, earliest
updates first.
"""
if last_id == current_id:
return []
@ -275,7 +282,7 @@ class TypingHandler(object):
typing = self._room_typing[room_id]
rows.append((serial, room_id, list(typing)))
rows.sort()
return rows
return rows[:limit]
def get_current_token(self):
return self._latest_room_serial

View file

@ -166,11 +166,6 @@ class ReplicationStreamer(object):
self.pending_updates = False
with Measure(self.clock, "repl.stream.get_updates"):
# First we tell the streams that they should update their
# current tokens.
for stream in self.streams:
stream.advance_current_token()
all_streams = self.streams
if self._replication_torture_level is not None:
@ -180,7 +175,7 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
if stream.last_token == stream.upto_token:
if stream.last_token == stream.current_token():
continue
if self._replication_torture_level:
@ -192,7 +187,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
stream.upto_token,
stream.current_token(),
)
try:
updates, current_token = await stream.get_updates()

View file

@ -17,10 +17,12 @@
import itertools
import logging
from collections import namedtuple
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple
import attr
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@ -119,13 +121,12 @@ class Stream(object):
"""Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last
time it was called up until the point `advance_current_token` was called.
time it was called.
"""
NAME = None # type: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod
def parse_row(cls, row):
@ -146,26 +147,15 @@ class Stream(object):
# The token from which we last asked for updates
self.last_token = self.current_token()
# The token that we will get updates up to
self.upto_token = self.current_token()
def advance_current_token(self):
"""Updates `upto_token` to "now", which updates up until which point
get_updates[_since] will fetch rows till.
"""
self.upto_token = self.current_token()
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
self.upto_token = self.current_token()
self.last_token = self.upto_token
self.last_token = self.current_token()
async def get_updates(self):
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before),
until the `upto_token`
since the stream was constructed if it hadn't been called before).
Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]:
@ -178,44 +168,45 @@ class Stream(object):
return updates, current_token
async def get_updates_since(self, from_token):
async def get_updates_since(
self, from_token: int
) -> Tuple[List[Tuple[int, JsonDict]], int]:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]:
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication steam.
Resolves to a pair `(updates, new_last_token)`, where `updates` is
a list of `(token, row)` entries and `new_last_token` is the new
position in stream.
"""
if from_token in ("NOW", "now"):
return [], self.upto_token
current_token = self.upto_token
if from_token in ("NOW", "now"):
return [], self.current_token()
current_token = self.current_token()
from_token = int(from_token)
if from_token == current_token:
return [], current_token
logger.info("get_updates_since: %s", self.__class__)
if self._LIMITED:
rows = await self.update_function(
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
rows = await self.update_function(
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
rows = await self.update_function(from_token, current_token)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
updates = [(row[0], row[1:]) for row in rows]
# check we didn't get more rows than the limit.
# doing it like this allows the update_function to be a generator.
if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
if len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
# The update function didn't hit the limit, so we must have got all
# the updates to `current_token`, and can return that as our new
# stream position.
return updates, current_token
def current_token(self):
@ -227,9 +218,8 @@ class Stream(object):
"""
raise NotImplementedError()
def update_function(self, from_token, current_token, limit=None):
"""Get updates between from_token and to_token. If Stream._LIMITED is
True then limit is provided, otherwise it's not.
def update_function(self, from_token, current_token, limit):
"""Get updates between from_token and to_token.
Returns:
Deferred(list(tuple)): the first entry in the tuple is the token for
@ -257,7 +247,6 @@ class BackfillStream(Stream):
class PresenceStream(Stream):
NAME = "presence"
_LIMITED = False
ROW_TYPE = PresenceStreamRow
def __init__(self, hs):
@ -272,7 +261,6 @@ class PresenceStream(Stream):
class TypingStream(Stream):
NAME = "typing"
_LIMITED = False
ROW_TYPE = TypingStreamRow
def __init__(self, hs):
@ -372,7 +360,6 @@ class DeviceListsStream(Stream):
"""
NAME = "device_lists"
_LIMITED = False
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs):
@ -462,7 +449,6 @@ class UserSignatureStream(Stream):
"""
NAME = "user_signature"
_LIMITED = False
ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs):

View file

@ -576,7 +576,7 @@ class DeviceWorkerStore(SQLBaseStore):
return set()
async def get_all_device_list_changes_for_remotes(
self, from_key: int, to_key: int
self, from_key: int, to_key: int, limit: int,
) -> List[Tuple[int, str]]:
"""Return a list of `(stream_id, entity)` which is the combined list of
changes to devices and which destinations need to be poked. Entity is
@ -592,10 +592,16 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
return await self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
"get_all_device_list_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
)
@cached(max_entries=10000)

View file

@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
"""
sql = """
SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id
ORDER BY stream_id ASC
LIMIT ?
"""
return self.db.execute(
"get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
"get_all_user_signature_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
)

View file

@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore):
"status_msg": state.status_msg,
"currently_active": state.currently_active,
}
for state in presence_states
for stream_id, state in zip(stream_orderings, presence_states)
],
)
@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore):
)
txn.execute(sql + clause, [stream_id] + list(args))
def get_all_presence_updates(self, last_id, current_id):
def get_all_presence_updates(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_presence_updates_txn(txn):
sql = (
"SELECT stream_id, user_id, state, last_active_ts,"
" last_federation_update_ts, last_user_sync_ts, status_msg,"
" currently_active"
" FROM presence_stream"
" WHERE ? < stream_id AND stream_id <= ?"
)
txn.execute(sql, (last_id, current_id))
sql = """
SELECT stream_id, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts,
status_msg,
currently_active
FROM presence_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db.runInteraction(