Support any process writing to cache invalidation stream. (#7436)

This commit is contained in:
Erik Johnston 2020-05-07 13:51:08 +01:00 committed by GitHub
parent 2929ce29d6
commit d7983b63a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 225 additions and 230 deletions

1
changelog.d/7436.misc Normal file
View file

@ -0,0 +1 @@
Support any process writing to cache invalidation stream.

View file

@ -219,10 +219,6 @@ Asks the server for the current position of all streams.
Inform the server a pusher should be removed Inform the server a pusher should be removed
#### INVALIDATE_CACHE (C)
Inform the server a cache should be invalidated
### REMOTE_SERVER_UP (S, C) ### REMOTE_SERVER_UP (S, C)
Inform other processes that a remote server may have come back online. Inform other processes that a remote server may have come back online.

View file

@ -122,7 +122,7 @@ APPEND_ONLY_TABLES = [
"presence_stream", "presence_stream",
"push_rules_stream", "push_rules_stream",
"ex_outlier_stream", "ex_outlier_stream",
"cache_invalidation_stream", "cache_invalidation_stream_by_instance",
"public_room_list_stream", "public_room_list_stream",
"state_group_edges", "state_group_edges",
"stream_ordering_to_exterm", "stream_ordering_to_exterm",
@ -188,7 +188,7 @@ class MockHomeserver:
self.clock = Clock(reactor) self.clock = Clock(reactor)
self.config = config self.config = config
self.hostname = config.server_name self.hostname = config.server_name
self.version_string = "Synapse/"+get_version_string(synapse) self.version_string = "Synapse/" + get_version_string(synapse)
def get_clock(self): def get_clock(self):
return self.clock return self.clock

View file

@ -18,14 +18,10 @@ from typing import Optional
import six import six
from synapse.storage.data_stores.main.cache import ( from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
CURRENT_STATE_CACHE_NAME,
CacheInvalidationWorkerStore,
)
from synapse.storage.database import Database from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from ._slaved_id_tracker import SlavedIdTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs) super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker( self._cache_id_gen = MultiWriterIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id" db_conn,
) # type: Optional[SlavedIdTracker] database,
instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
) # type: Optional[MultiWriterIdGenerator]
else: else:
self._cache_id_gen = None self._cache_id_gen = None
self.hs = hs self.hs = hs
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:
self._cache_id_gen.advance(token)
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
raise Exception(
"Can't send an 'invalidate all' for current state cache"
)
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
txn.call_after(self._send_invalidation_poke, cache_func, keys)
def _send_invalidation_poke(self, cache_func, keys):
self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)

View file

@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self): def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "tag_account_data": if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token) self._account_data_id_gen.advance(token)
for row in rows: for row in rows:
@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
(row.user_id, row.room_id, row.data_type) (row.user_id, row.room_id, row.data_type)
) )
self._account_data_stream_cache.entity_has_changed(row.user_id, token) self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedAccountDataStore, self).process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, token, rows
)

View file

@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000, expiry_ms=30 * 60 * 1000,
) )
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "to_device": if stream_name == "to_device":
self._device_inbox_id_gen.advance(token) self._device_inbox_id_gen.advance(token)
for row in rows: for row in rows:
@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
self._device_federation_outbox_stream_cache.entity_has_changed( self._device_federation_outbox_stream_cache.entity_has_changed(
row.entity, token row.entity, token
) )
return super(SlavedDeviceInboxStore, self).process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, token, rows
)

View file

@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max "DeviceListFederationStreamChangeCache", device_list_max
) )
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token) self._device_list_id_gen.advance(token)
self._invalidate_caches_for_devices(token, rows) self._invalidate_caches_for_devices(token, rows)
@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self._device_list_id_gen.advance(token) self._device_list_id_gen.advance(token)
for row in rows: for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token) self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, token, rows
)
def _invalidate_caches_for_devices(self, token, rows): def _invalidate_caches_for_devices(self, token, rows):
for row in rows: for row in rows:

View file

@ -93,7 +93,7 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self): def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token() return self._backfill_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "events": if stream_name == "events":
self._stream_id_gen.advance(token) self._stream_id_gen.advance(token)
for row in rows: for row in rows:
@ -111,9 +111,7 @@ class SlavedEventStore(
row.relates_to, row.relates_to,
backfilled=True, backfilled=True,
) )
return super(SlavedEventStore, self).process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, token, rows
)
def _process_event_stream_row(self, token, row): def _process_event_stream_row(self, token, row):
data = row.data data = row.data

View file

@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self): def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token() return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "groups": if stream_name == "groups":
self._group_updates_id_gen.advance(token) self._group_updates_id_gen.advance(token)
for row in rows: for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token) self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedGroupServerStore, self).process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, token, rows
)

View file

@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore):
def get_current_presence_token(self): def get_current_presence_token(self):
return self._presence_id_gen.get_current_token() return self._presence_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "presence": if stream_name == "presence":
self._presence_id_gen.advance(token) self._presence_id_gen.advance(token)
for row in rows: for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token) self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,)) self._get_presence_for_user.invalidate((row.user_id,))
return super(SlavedPresenceStore, self).process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, token, rows
)

View file

@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self): def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token() return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "push_rules": if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token) self._push_rules_stream_id_gen.advance(token)
for row in rows: for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token) self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedPushRuleStore, self).process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, token, rows
)

View file

@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def get_pushers_stream_token(self): def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "pushers": if stream_name == "pushers":
self._pushers_id_gen.advance(token) self._pushers_id_gen.advance(token)
return super(SlavedPusherStore, self).process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, token, rows
)

View file

@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type)) self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "receipts": if stream_name == "receipts":
self._receipts_id_gen.advance(token) self._receipts_id_gen.advance(token)
for row in rows: for row in rows:
@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
) )
self._receipts_stream_cache.entity_has_changed(row.room_id, token) self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return super(SlavedReceiptsStore, self).process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows)
stream_name, token, rows
)

View file

@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self): def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token() return self._public_room_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "public_rooms": if stream_name == "public_rooms":
self._public_room_id_gen.advance(token) self._public_room_id_gen.advance(token)
return super(RoomStore, self).process_replication_rows(stream_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@ -100,10 +100,10 @@ class ReplicationDataHandler:
token: stream token for this batch of rows token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
""" """
self.store.process_replication_rows(stream_name, token, rows) self.store.process_replication_rows(stream_name, instance_name, token, rows)
async def on_position(self, stream_name: str, token: int): async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, token, []) self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str): def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command.""" """Called when get a new REMOTE_SERVER_UP command."""

View file

@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
return " ".join((self.app_id, self.push_key, self.user_id)) return " ".join((self.app_id, self.push_key, self.user_id))
class InvalidateCacheCommand(Command):
"""Sent by the client to invalidate an upstream cache.
THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
NOT DISASTROUS IF WE DROP ON THE FLOOR.
Mainly used to invalidate destination retry timing caches.
Format::
INVALIDATE_CACHE <cache_func> <keys_json>
Where <keys_json> is a json list.
"""
NAME = "INVALIDATE_CACHE"
def __init__(self, cache_func, keys):
self.cache_func = cache_func
self.keys = keys
@classmethod
def from_line(cls, line):
cache_func, keys_json = line.split(" ", 1)
return cls(cache_func, json.loads(keys_json))
def to_line(self):
return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
class UserIpCommand(Command): class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client. """Sent periodically when a worker sees activity from a client.
@ -439,7 +408,6 @@ _COMMANDS = (
UserSyncCommand, UserSyncCommand,
FederationAckCommand, FederationAckCommand,
RemovePusherCommand, RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand, UserIpCommand,
RemoteServerUpCommand, RemoteServerUpCommand,
ClearUserSyncsCommand, ClearUserSyncsCommand,
@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
ClearUserSyncsCommand.NAME, ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME, FederationAckCommand.NAME,
RemovePusherCommand.NAME, RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME,
UserIpCommand.NAME, UserIpCommand.NAME,
ErrorCommand.NAME, ErrorCommand.NAME,
RemoteServerUpCommand.NAME, RemoteServerUpCommand.NAME,

View file

@ -15,18 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import ( from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
)
from prometheus_client import Counter from prometheus_client import Counter
@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand, ClearUserSyncsCommand,
Command, Command,
FederationAckCommand, FederationAckCommand,
InvalidateCacheCommand,
PositionCommand, PositionCommand,
RdataCommand, RdataCommand,
RemoteServerUpCommand, RemoteServerUpCommand,
@ -171,7 +159,7 @@ class ReplicationCommandHandler:
return return
for stream_name, stream in self._streams.items(): for stream_name, stream in self._streams.items():
current_token = stream.current_token() current_token = stream.current_token(self._instance_name)
self.send_command( self.send_command(
PositionCommand(stream_name, self._instance_name, current_token) PositionCommand(stream_name, self._instance_name, current_token)
) )
@ -210,18 +198,6 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data() self._notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
):
invalidate_cache_counter.inc()
if self._is_master:
# We invalidate the cache locally, but then also stream that to other
# workers.
await self._store.invalidate_cache_and_stream(
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc() user_ip_cache_counter.inc()
@ -295,7 +271,7 @@ class ReplicationCommandHandler:
rows: a list of Stream.ROW_TYPE objects as returned by rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row. Stream.parse_row.
""" """
logger.debug("Received rdata %s -> %s", stream_name, token) logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
await self._replication_data_handler.on_rdata( await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows stream_name, instance_name, token, rows
) )
@ -326,7 +302,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(stream_name, []) self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to. # Find where we previously streamed up to.
current_token = stream.current_token() current_token = stream.current_token(cmd.instance_name)
# If the position token matches our current token then we're up to # If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates # date and there's nothing to do. Otherwise, fetch all updates
@ -363,7 +339,9 @@ class ReplicationCommandHandler:
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler. # We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(stream_name, cmd.token) await self._replication_data_handler.on_position(
cmd.stream_name, cmd.instance_name, cmd.token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name) self._streams_by_connection.setdefault(conn, set()).add(stream_name)
@ -491,12 +469,6 @@ class ReplicationCommandHandler:
cmd = RemovePusherCommand(app_id, push_key, user_id) cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd) self.send_command(cmd)
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip( def send_user_ip(
self, self,
user_id: str, user_id: str,

View file

@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream from synapse.replication.tcp.streams import (
STREAMS_MAP,
CachesStream,
FederationStream,
Stream,
)
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
stream_updates_counter = Counter( stream_updates_counter = Counter(
@ -71,11 +76,16 @@ class ReplicationStreamer(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self._instance_name = hs.get_instance_name()
self._replication_torture_level = hs.config.replication_torture_level self._replication_torture_level = hs.config.replication_torture_level
# Work out list of streams that this instance is the source of. # Work out list of streams that this instance is the source of.
self.streams = [] # type: List[Stream] self.streams = [] # type: List[Stream]
# All workers can write to the cache invalidation stream.
self.streams.append(CachesStream(hs))
if hs.config.worker_app is None: if hs.config.worker_app is None:
for stream in STREAMS_MAP.values(): for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation: if stream == FederationStream and hs.config.send_federation:
@ -83,6 +93,10 @@ class ReplicationStreamer(object):
# has been disabled on the master. # has been disabled on the master.
continue continue
if stream == CachesStream:
# We've already added it above.
continue
self.streams.append(stream(hs)) self.streams.append(stream(hs))
self.streams_by_name = {stream.NAME: stream for stream in self.streams} self.streams_by_name = {stream.NAME: stream for stream in self.streams}
@ -145,7 +159,9 @@ class ReplicationStreamer(object):
random.shuffle(all_streams) random.shuffle(all_streams)
for stream in all_streams: for stream in all_streams:
if stream.last_token == stream.current_token(): if stream.last_token == stream.current_token(
self._instance_name
):
continue continue
if self._replication_torture_level: if self._replication_torture_level:
@ -157,7 +173,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s", "Getting stream: %s: %s -> %s",
stream.NAME, stream.NAME,
stream.last_token, stream.last_token,
stream.current_token(), stream.current_token(self._instance_name),
) )
try: try:
updates, current_token, limited = await stream.get_updates() updates, current_token, limited = await stream.get_updates()

View file

@ -95,20 +95,25 @@ class Stream(object):
def __init__( def __init__(
self, self,
local_instance_name: str, local_instance_name: str,
current_token_function: Callable[[], Token], current_token_function: Callable[[str], Token],
update_function: UpdateFunction, update_function: UpdateFunction,
): ):
"""Instantiate a Stream """Instantiate a Stream
current_token_function and update_function are callbacks which should be `current_token_function` and `update_function` are callbacks which
implemented by subclasses. should be implemented by subclasses.
current_token_function is called to get the current token of the underlying `current_token_function` takes an instance name, which is a writer to
stream. It is only meaningful on the process that is the source of the the stream, and returns the position in the stream of the writer (as
replication stream (ie, usually the master). viewed from the current process). On the writer process this is where
the writer has successfully written up to, whereas on other processes
this is the position which we have received updates up to over
replication. (Note that most streams have a single writer and so their
implementations ignore the instance name passed in).
update_function is called to get updates for this stream between a pair of `update_function` is called to get updates for this stream between a
stream tokens. See the UpdateFunction type definition for more info. pair of stream tokens. See the `UpdateFunction` type definition for more
info.
Args: Args:
local_instance_name: The instance name of the current process local_instance_name: The instance name of the current process
@ -120,13 +125,13 @@ class Stream(object):
self.update_function = update_function self.update_function = update_function
# The token from which we last asked for updates # The token from which we last asked for updates
self.last_token = self.current_token() self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self): def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded, """Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers. e.g. when there are no currently connected workers.
""" """
self.last_token = self.current_token() self.last_token = self.current_token(self.local_instance_name)
async def get_updates(self) -> StreamUpdateResult: async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or """Gets all updates since the last time this function was called (or
@ -138,7 +143,7 @@ class Stream(object):
position in stream, and `limited` is whether there are more updates position in stream, and `limited` is whether there are more updates
to fetch. to fetch.
""" """
current_token = self.current_token() current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since( updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token self.local_instance_name, self.last_token, current_token
) )
@ -170,6 +175,16 @@ class Stream(object):
return updates, upto_token, limited return updates, upto_token, limited
def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
"""Takes a current token callback function for a single writer stream
that doesn't take an instance name parameter and wraps it in a function that
does accept an instance name parameter but ignores it.
"""
return lambda instance_name: current_token()
def db_query_to_update_function( def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction: ) -> UpdateFunction:
@ -235,7 +250,7 @@ class BackfillStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_current_backfill_token, current_token_without_instance(store.get_current_backfill_token),
db_query_to_update_function(store.get_all_new_backfill_event_rows), db_query_to_update_function(store.get_all_new_backfill_event_rows),
) )
@ -271,7 +286,9 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
super().__init__( super().__init__(
hs.get_instance_name(), store.get_current_presence_token, update_function hs.get_instance_name(),
current_token_without_instance(store.get_current_presence_token),
update_function,
) )
@ -296,7 +313,9 @@ class TypingStream(Stream):
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
super().__init__( super().__init__(
hs.get_instance_name(), typing_handler.get_current_token, update_function hs.get_instance_name(),
current_token_without_instance(typing_handler.get_current_token),
update_function,
) )
@ -319,7 +338,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_max_receipt_stream_id, current_token_without_instance(store.get_max_receipt_stream_id),
db_query_to_update_function(store.get_all_updated_receipts), db_query_to_update_function(store.get_all_updated_receipts),
) )
@ -339,7 +358,7 @@ class PushRulesStream(Stream):
hs.get_instance_name(), self._current_token, self._update_function hs.get_instance_name(), self._current_token, self._update_function
) )
def _current_token(self) -> int: def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token() push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token return push_rules_token
@ -373,7 +392,7 @@ class PushersStream(Stream):
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_pushers_stream_token, current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows), db_query_to_update_function(store.get_all_updated_pushers_rows),
) )
@ -402,13 +421,27 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow ROW_TYPE = CachesStreamRow
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() self.store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_cache_stream_token, self.store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches), self._update_function,
) )
async def _update_function(
self, instance_name: str, from_token: int, upto_token: int, limit: int
):
rows = await self.store.get_all_updated_caches(
instance_name, from_token, upto_token, limit
)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
class PublicRoomsStream(Stream): class PublicRoomsStream(Stream):
"""The public rooms list changed """The public rooms list changed
@ -431,7 +464,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_current_public_room_stream_id, current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms), db_query_to_update_function(store.get_all_new_public_rooms),
) )
@ -452,7 +485,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_device_stream_token, current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes), db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
) )
@ -470,7 +503,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_to_device_stream_token, current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages), db_query_to_update_function(store.get_all_new_device_messages),
) )
@ -490,7 +523,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_max_account_data_stream_id, current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags), db_query_to_update_function(store.get_all_updated_tags),
) )
@ -510,7 +543,7 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore() self.store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
self.store.get_max_account_data_stream_id, current_token_without_instance(self.store.get_max_account_data_stream_id),
db_query_to_update_function(self._update_function), db_query_to_update_function(self._update_function),
) )
@ -541,7 +574,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_group_stream_token, current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes), db_query_to_update_function(store.get_all_groups_changes),
) )
@ -559,7 +592,7 @@ class UserSignatureStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_device_stream_token, current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function( db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes store.get_all_user_signature_changes_for_remotes
), ),

View file

@ -20,7 +20,7 @@ from typing import List, Tuple, Type
import attr import attr
from ._base import Stream, StreamUpdateResult, Token from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream """Handling of the 'events' replication stream
@ -119,7 +119,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore() self._store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
self._store.get_current_events_token, current_token_without_instance(self._store.get_current_events_token),
self._update_function, self._update_function,
) )

View file

@ -15,7 +15,11 @@
# limitations under the License. # limitations under the License.
from collections import namedtuple from collections import namedtuple
from synapse.replication.tcp.streams._base import Stream, make_http_update_function from synapse.replication.tcp.streams._base import (
Stream,
current_token_without_instance,
make_http_update_function,
)
class FederationStream(Stream): class FederationStream(Stream):
@ -41,7 +45,9 @@ class FederationStream(Stream):
# will be a real FederationSender, which has stubs for current_token and # will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.) # get_replication_rows.)
federation_sender = hs.get_federation_sender() federation_sender = hs.get_federation_sender()
current_token = federation_sender.get_current_token current_token = current_token_without_instance(
federation_sender.get_current_token
)
update_function = federation_sender.get_replication_rows update_function = federation_sender.get_replication_rows
elif hs.should_send_federation(): elif hs.should_send_federation():
@ -58,7 +64,7 @@ class FederationStream(Stream):
super().__init__(hs.get_instance_name(), current_token, update_function) super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod @staticmethod
def _stub_current_token(): def _stub_current_token(instance_name: str) -> int:
# dummy current-token method for use on workers # dummy current-token method for use on workers
return 0 return 0

View file

@ -47,6 +47,9 @@ class SQLBaseStore(metaclass=ABCMeta):
self.db = database self.db = database
self.rand = random.SystemRandom() self.rand = random.SystemRandom()
def process_replication_rows(self, stream_name, instance_name, token, rows):
pass
def _invalidate_state_caches(self, room_id, members_changed): def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does """Invalidates caches that are based on the current state, but does
not stream invalidations down replication. not stream invalidations down replication.

View file

@ -26,13 +26,14 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import (
ChainedIdGenerator, ChainedIdGenerator,
IdGenerator, IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator, StreamIdGenerator,
) )
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
from .cache import CacheInvalidationStore from .cache import CacheInvalidationWorkerStore
from .client_ips import ClientIpStore from .client_ips import ClientIpStore
from .deviceinbox import DeviceInboxStore from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore from .devices import DeviceStore
@ -112,8 +113,8 @@ class DataStore(
MonthlyActiveUsersStore, MonthlyActiveUsersStore,
StatsStore, StatsStore,
RelationsStore, RelationsStore,
CacheInvalidationStore,
UIAuthStore, UIAuthStore,
CacheInvalidationWorkerStore,
): ):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: Database, db_conn, hs):
self.hs = hs self.hs = hs
@ -170,8 +171,14 @@ class DataStore(
) )
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator( self._cache_id_gen = MultiWriterIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id" db_conn,
database,
instance_name="master",
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
) )
else: else:
self._cache_id_gen = None self._cache_id_gen = None

View file

@ -16,11 +16,10 @@
import itertools import itertools
import logging import logging
from typing import Any, Iterable, Optional, Tuple from typing import Any, Iterable, Optional
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore): class CacheInvalidationWorkerStore(SQLBaseStore):
def get_all_updated_caches(self, last_id, current_id, limit): def __init__(self, database: Database, db_conn, hs):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
):
"""Fetches cache invalidation rows between the two given IDs written
by the given instance. Returns at most `limit` rows.
"""
if last_id == current_id: if last_id == current_id:
return defer.succeed([]) return []
def get_all_updated_caches_txn(txn): def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to # We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache # send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine. # invalidations are idempotent, so duplicates are fine.
sql = ( sql = """
"SELECT stream_id, cache_func, keys, invalidation_ts" SELECT stream_id, cache_func, keys, invalidation_ts
" FROM cache_invalidation_stream" FROM cache_invalidation_stream_by_instance
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" WHERE stream_id > ? AND instance_name = ?
) ORDER BY stream_id ASC
txn.execute(sql, (last_id, limit)) LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
return txn.fetchall() return txn.fetchall()
return self.db.runInteraction( return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn "get_all_updated_caches", get_all_updated_caches_txn
) )
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
class CacheInvalidationStore(CacheInvalidationWorkerStore): for row in rows:
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): if row.cache_func == CURRENT_STATE_CACHE_NAME:
"""Invalidates the cache and adds it to the cache stream so slaves if row.keys is None:
will know to invalidate their caches. raise Exception(
"Can't send an 'invalidate all' for current state cache"
)
This should only be used to invalidate caches where slaves won't room_id = row.keys[0]
otherwise know from other replication streams that the cache should members_changed = set(row.keys[1:])
be invalidated. self._invalidate_state_caches(room_id, members_changed)
""" else:
cache_func = getattr(self, cache_name, None) self._attempt_to_invalidate_cache(row.cache_func, row.keys)
if not cache_func:
return
cache_func.invalidate(keys) super().process_replication_rows(stream_name, instance_name, token, rows)
await self.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
cache_func.__name__,
keys,
)
def _invalidate_cache_and_stream(self, txn, cache_func, keys): def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves """Invalidates the cache and adds it to the cache stream so slaves
@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
# the transaction. However, we want to only get an ID when we want # the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have # to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes. # __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next() stream_id = self._cache_id_gen.get_next_txn(txn)
stream_id = ctx.__enter__()
txn.call_on_exception(ctx.__exit__, None, None, None)
txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data) txn.call_after(self.hs.get_notifier().on_new_replication_data)
if keys is not None: if keys is not None:
@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
self.db.simple_insert_txn( self.db.simple_insert_txn(
txn, txn,
table="cache_invalidation_stream", table="cache_invalidation_stream_by_instance",
values={ values={
"stream_id": stream_id, "stream_id": stream_id,
"instance_name": self._instance_name,
"cache_func": cache_name, "cache_func": cache_name,
"keys": keys, "keys": keys,
"invalidation_ts": self.clock.time_msec(), "invalidation_ts": self.clock.time_msec(),
}, },
) )
def get_cache_stream_token(self): def get_cache_stream_token(self, instance_name):
if self._cache_id_gen: if self._cache_id_gen:
return self._cache_id_gen.get_current_token() return self._cache_id_gen.get_current_token(instance_name)
else: else:
return 0 return 0

View file

@ -0,0 +1,30 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- We keep the old table here to enable us to roll back. It doesn't matter
-- that we have dropped all the data here.
TRUNCATE cache_invalidation_stream;
CREATE TABLE cache_invalidation_stream_by_instance (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
cache_func TEXT NOT NULL,
keys TEXT[],
invalidation_ts BIGINT
);
CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
CREATE SEQUENCE cache_invalidation_stream_seq;

View file

@ -29,6 +29,8 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
# XXX: If you're about to bump this to 59 (or higher) please create an update
# that drops the unused `cache_invalidation_stream` table, as per #7436!
SCHEMA_VERSION = 58 SCHEMA_VERSION = 58
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))