Merge branch 'erikj/repl_notifieri' into erikj/fix_wait_for_stream

This commit is contained in:
Erik Johnston 2023-01-20 16:06:52 +00:00
commit bc8136dd81
24 changed files with 357 additions and 80 deletions

1
changelog.d/14844.misc Normal file
View file

@ -0,0 +1 @@
Add check to avoid starting duplicate partial state syncs.

1
changelog.d/14875.docker Normal file
View file

@ -0,0 +1 @@
Bump default Python version in the Dockerfile from 3.9 to 3.11.

1
changelog.d/14877.misc Normal file
View file

@ -0,0 +1 @@
Always notify replication when a stream advances automatically.

View file

@ -20,7 +20,7 @@
# `poetry export | pip install -r /dev/stdin`, but beware: we have experienced bugs in
# in `poetry export` in the past.
ARG PYTHON_VERSION=3.9
ARG PYTHON_VERSION=3.11
###
### Stage 0: generate requirements.txt
@ -34,11 +34,11 @@ FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye as requirements
# Here we use it to set up a cache for apt (and below for pip), to improve
# rebuild speeds on slow connections.
RUN \
--mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -qq && apt-get install -yqq \
build-essential git libffi-dev libssl-dev \
&& rm -rf /var/lib/apt/lists/*
--mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -qq && apt-get install -yqq \
build-essential git libffi-dev libssl-dev \
&& rm -rf /var/lib/apt/lists/*
# We install poetry in its own build stage to avoid its dependencies conflicting with
# synapse's dependencies.
@ -64,9 +64,9 @@ ARG TEST_ONLY_IGNORE_POETRY_LOCKFILE
# Otherwise, just create an empty requirements file so that the Dockerfile can
# proceed.
RUN if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
/root/.local/bin/poetry export --extras all -o /synapse/requirements.txt ${TEST_ONLY_SKIP_DEP_HASH_VERIFICATION:+--without-hashes}; \
/root/.local/bin/poetry export --extras all -o /synapse/requirements.txt ${TEST_ONLY_SKIP_DEP_HASH_VERIFICATION:+--without-hashes}; \
else \
touch /synapse/requirements.txt; \
touch /synapse/requirements.txt; \
fi
###
@ -76,24 +76,24 @@ FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye as builder
# install the OS build deps
RUN \
--mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -qq && apt-get install -yqq \
build-essential \
libffi-dev \
libjpeg-dev \
libpq-dev \
libssl-dev \
libwebp-dev \
libxml++2.6-dev \
libxslt1-dev \
openssl \
zlib1g-dev \
git \
curl \
libicu-dev \
pkg-config \
&& rm -rf /var/lib/apt/lists/*
--mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -qq && apt-get install -yqq \
build-essential \
libffi-dev \
libjpeg-dev \
libpq-dev \
libssl-dev \
libwebp-dev \
libxml++2.6-dev \
libxslt1-dev \
openssl \
zlib1g-dev \
git \
curl \
libicu-dev \
pkg-config \
&& rm -rf /var/lib/apt/lists/*
# Install rust and ensure its in the PATH
@ -134,9 +134,9 @@ ARG TEST_ONLY_IGNORE_POETRY_LOCKFILE
RUN --mount=type=cache,target=/synapse/target,sharing=locked \
--mount=type=cache,target=${CARGO_HOME}/registry,sharing=locked \
if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
pip install --prefix="/install" --no-deps --no-warn-script-location /synapse[all]; \
pip install --prefix="/install" --no-deps --no-warn-script-location /synapse[all]; \
else \
pip install --prefix="/install" --no-warn-script-location /synapse[all]; \
pip install --prefix="/install" --no-warn-script-location /synapse[all]; \
fi
###
@ -151,20 +151,20 @@ LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git
LABEL org.opencontainers.image.licenses='Apache-2.0'
RUN \
--mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
--mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -qq && apt-get install -yqq \
curl \
gosu \
libjpeg62-turbo \
libpq5 \
libwebp6 \
xmlsec1 \
libjemalloc2 \
libicu67 \
libssl-dev \
openssl \
&& rm -rf /var/lib/apt/lists/*
curl \
gosu \
libjpeg62-turbo \
libpq5 \
libwebp6 \
xmlsec1 \
libjemalloc2 \
libicu67 \
libssl-dev \
openssl \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /install /usr/local
COPY ./docker/start.py /start.py
@ -175,4 +175,4 @@ EXPOSE 8008/tcp 8009/tcp 8448/tcp
ENTRYPOINT ["/start.py"]
HEALTHCHECK --start-period=5s --interval=15s --timeout=5s \
CMD curl -fSs http://localhost:8008/health || exit 1
CMD curl -fSs http://localhost:8008/health || exit 1

View file

@ -51,6 +51,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
from synapse.notifier import ReplicationNotifier
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@ -260,6 +261,9 @@ class MockHomeserver:
def should_send_federation(self) -> bool:
return False
def get_replication_notifier(self) -> ReplicationNotifier:
return ReplicationNotifier()
class Porter:
def __init__(

View file

@ -27,6 +27,7 @@ from typing import (
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
@ -171,12 +172,23 @@ class FederationHandler:
self.third_party_event_rules = hs.get_third_party_event_rules()
# Tracks running partial state syncs by room ID.
# Partial state syncs currently only run on the main process, so it's okay to
# track them in-memory for now.
self._active_partial_state_syncs: Set[str] = set()
# Tracks partial state syncs we may want to restart.
# A dictionary mapping room IDs to (initial destination, other destinations)
# tuples.
self._partial_state_syncs_maybe_needing_restart: Dict[
str, Tuple[Optional[str], Collection[str]]
] = {}
# if this is the main process, fire off a background process to resume
# any partial-state-resync operations which were in flight when we
# were shut down.
if not hs.config.worker.worker_app:
run_as_background_process(
"resume_sync_partial_state_room", self._resume_sync_partial_state_room
"resume_sync_partial_state_room", self._resume_partial_state_room_sync
)
@trace
@ -679,9 +691,7 @@ class FederationHandler:
if ret.partial_state:
# Kick off the process of asynchronously fetching the state for this
# room.
run_as_background_process(
desc="sync_partial_state_room",
func=self._sync_partial_state_room,
self._start_partial_state_room_sync(
initial_destination=origin,
other_destinations=ret.servers_in_room,
room_id=room_id,
@ -1660,20 +1670,100 @@ class FederationHandler:
# well.
return None
async def _resume_sync_partial_state_room(self) -> None:
async def _resume_partial_state_room_sync(self) -> None:
"""Resumes resyncing of all partial-state rooms after a restart."""
assert not self.config.worker.worker_app
partial_state_rooms = await self.store.get_partial_state_room_resync_info()
for room_id, resync_info in partial_state_rooms.items():
run_as_background_process(
desc="sync_partial_state_room",
func=self._sync_partial_state_room,
self._start_partial_state_room_sync(
initial_destination=resync_info.joined_via,
other_destinations=resync_info.servers_in_room,
room_id=room_id,
)
def _start_partial_state_room_sync(
self,
initial_destination: Optional[str],
other_destinations: Collection[str],
room_id: str,
) -> None:
"""Starts the background process to resync the state of a partial state room,
if it is not already running.
Args:
initial_destination: the initial homeserver to pull the state from
other_destinations: other homeservers to try to pull the state from, if
`initial_destination` is unavailable
room_id: room to be resynced
"""
async def _sync_partial_state_room_wrapper() -> None:
if room_id in self._active_partial_state_syncs:
# Another local user has joined the room while there is already a
# partial state sync running. This implies that there is a new join
# event to un-partial state. We might find ourselves in one of a few
# scenarios:
# 1. There is an existing partial state sync. The partial state sync
# un-partial states the new join event before completing and all is
# well.
# 2. Before the latest join, the homeserver was no longer in the room
# and there is an existing partial state sync from our previous
# membership of the room. The partial state sync may have:
# a) succeeded, but not yet terminated. The room will not be
# un-partial stated again unless we restart the partial state
# sync.
# b) failed, because we were no longer in the room and remote
# homeservers were refusing our requests, but not yet
# terminated. After the latest join, remote homeservers may
# start answering our requests again, so we should restart the
# partial state sync.
# In the cases where we would want to restart the partial state sync,
# the room would have the partial state flag when the partial state sync
# terminates.
self._partial_state_syncs_maybe_needing_restart[room_id] = (
initial_destination,
other_destinations,
)
return
self._active_partial_state_syncs.add(room_id)
try:
await self._sync_partial_state_room(
initial_destination=initial_destination,
other_destinations=other_destinations,
room_id=room_id,
)
finally:
# Read the room's partial state flag while we still hold the claim to
# being the active partial state sync (so that another partial state
# sync can't come along and mess with it under us).
# Normally, the partial state flag will be gone. If it isn't, then we
# may find ourselves in scenario 2a or 2b as described in the comment
# above, where we want to restart the partial state sync.
is_still_partial_state_room = await self.store.is_partial_state_room(
room_id
)
self._active_partial_state_syncs.remove(room_id)
if room_id in self._partial_state_syncs_maybe_needing_restart:
(
restart_initial_destination,
restart_other_destinations,
) = self._partial_state_syncs_maybe_needing_restart.pop(room_id)
if is_still_partial_state_room:
self._start_partial_state_room_sync(
initial_destination=restart_initial_destination,
other_destinations=restart_other_destinations,
room_id=room_id,
)
run_as_background_process(
desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper
)
async def _sync_partial_state_room(
self,
initial_destination: Optional[str],

View file

@ -226,8 +226,7 @@ class Notifier:
self.store = hs.get_datastores().main
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
# Called when there are new things to stream over replication
self.replication_callbacks: List[Callable[[], None]] = []
self._replication_notifier = hs.get_replication_notifier()
self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
self._federation_client = hs.get_federation_http_client()
@ -279,7 +278,7 @@ class Notifier:
it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process.
"""
self.replication_callbacks.append(cb)
self._replication_notifier.add_replication_callback(cb)
def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
"""Add a callback that will be called when a user joins a room.
@ -741,8 +740,7 @@ class Notifier:
def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks:
cb()
self._replication_notifier.notify_replication()
def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
for cb in self._new_join_in_room_callbacks:
@ -759,3 +757,26 @@ class Notifier:
# Tell the federation client about the fact the server is back up, so
# that any in flight requests can be immediately retried.
self._federation_client.wake_destination(server)
@attr.s(auto_attribs=True)
class ReplicationNotifier:
"""Tracks callbacks for things that need to know about stream changes.
This is separate from the notifier to avoid circular dependencies.
"""
_replication_callbacks: List[Callable[[], None]] = attr.Factory(list)
def add_replication_callback(self, cb: Callable[[], None]) -> None:
"""Add a callback that will be called when some new data is available.
Callback is not given any arguments. It should *not* return a Deferred - if
it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process.
"""
self._replication_callbacks.append(cb)
def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event"""
for cb in self._replication_callbacks:
cb()

View file

@ -107,7 +107,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
from synapse.module_api import ModuleApi
from synapse.notifier import Notifier
from synapse.notifier import Notifier, ReplicationNotifier
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
@ -389,6 +389,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_notifier(self) -> Notifier:
return Notifier(self)
@cache_in_self
def get_replication_notifier(self) -> ReplicationNotifier:
return ReplicationNotifier()
@cache_in_self
def get_auth(self) -> Auth:
return Auth(self)

View file

@ -75,6 +75,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="account_data",
instance_name=self._instance_name,
tables=[
@ -95,6 +96,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# SQLite).
self._account_data_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],

View file

@ -75,6 +75,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
notifier=hs.get_replication_notifier(),
stream_name="caches",
instance_name=hs.get_instance_name(),
tables=[

View file

@ -91,6 +91,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
@ -101,7 +102,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
else:
self._can_write_to_device = True
self._device_inbox_id_gen = StreamIdGenerator(
db_conn, "device_inbox", "stream_id"
db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()

View file

@ -92,6 +92,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# class below that is used on the main process.
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"device_lists_stream",
"stream_id",
extra_tables=[

View file

@ -1181,7 +1181,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
super().__init__(database, db_conn, hs)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
db_conn,
hs.get_replication_notifier(),
"e2e_cross_signing_keys",
"stream_id",
)
async def set_e2e_device_keys(

View file

@ -191,6 +191,7 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="events",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
@ -200,6 +201,7 @@ class EventsWorkerStore(SQLBaseStore):
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="backfill",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
@ -217,12 +219,14 @@ class EventsWorkerStore(SQLBaseStore):
# SQLite).
self._stream_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"events",
"stream_ordering",
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"events",
"stream_ordering",
step=-1,
@ -300,6 +304,7 @@ class EventsWorkerStore(SQLBaseStore):
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_event_stream",
instance_name=hs.get_instance_name(),
tables=[
@ -311,7 +316,10 @@ class EventsWorkerStore(SQLBaseStore):
)
else:
self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
db_conn, "un_partial_stated_event_stream", "stream_id"
db_conn,
hs.get_replication_notifier(),
"un_partial_stated_event_stream",
"stream_id",
)
def get_un_partial_stated_events_token(self) -> int:

View file

@ -77,6 +77,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._presence_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="presence_stream",
instance_name=self._instance_name,
tables=[("presence_stream", "instance_name", "stream_id")],
@ -85,7 +86,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
)
else:
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
)
self.hs = hs

View file

@ -118,6 +118,7 @@ class PushRulesWorkerStore(
# class below that is used on the main process.
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"push_rules_stream",
"stream_id",
is_writer=hs.config.worker.worker_app is None,

View file

@ -62,6 +62,7 @@ class PusherWorkerStore(SQLBaseStore):
# class below that is used on the main process.
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],

View file

@ -73,6 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="receipts",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
@ -91,6 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# SQLite).
self._receipts_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"receipts_linearized",
"stream_id",
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,

View file

@ -126,6 +126,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_room_stream",
instance_name=self._instance_name,
tables=[
@ -137,7 +138,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
else:
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
db_conn, "un_partial_stated_room_stream", "stream_id"
db_conn,
hs.get_replication_notifier(),
"un_partial_stated_room_stream",
"stream_id",
)
async def store_room(

View file

@ -20,6 +20,7 @@ from collections import OrderedDict
from contextlib import contextmanager
from types import TracebackType
from typing import (
TYPE_CHECKING,
AsyncContextManager,
ContextManager,
Dict,
@ -49,6 +50,9 @@ from synapse.storage.database import (
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator
if TYPE_CHECKING:
from synapse.notifier import ReplicationNotifier
logger = logging.getLogger(__name__)
@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def __init__(
self,
db_conn: LoggingDatabaseConnection,
notifier: "ReplicationNotifier",
table: str,
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
@ -205,6 +210,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
self._notifier = notifier
def advance(self, instance_name: str, new_id: int) -> None:
# Advance should never be called on a writer instance, only over replication
if self._is_writer:
@ -227,6 +234,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
with self._lock:
self._unfinished_ids.pop(next_id)
self._notifier.notify_replication()
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
@ -250,6 +259,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
for next_id in next_ids:
self._unfinished_ids.pop(next_id)
self._notifier.notify_replication()
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
@ -296,6 +307,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self,
db_conn: LoggingDatabaseConnection,
db: DatabasePool,
notifier: "ReplicationNotifier",
stream_name: str,
instance_name: str,
tables: List[Tuple[str, str, str]],
@ -304,6 +316,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
positive: bool = True,
) -> None:
self._db = db
self._notifier = notifier
self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
@ -535,7 +548,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
# controls the return type. If `None` or omitted, the context manager yields
# a single integer stream_id; otherwise it yields a list of stream_ids.
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
return cast(
AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier)
)
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
# If we have a list of instances that are allowed to write to this
@ -544,7 +559,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
raise Exception("Tried to allocate stream ID on non-writer")
# Cast safety: see get_next.
return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
return cast(
AsyncContextManager[List[int]],
_MultiWriterCtxManager(self, self._notifier, n),
)
def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
@ -563,6 +581,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
@ -787,6 +806,7 @@ class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator"""
id_gen: MultiWriterIdGenerator
notifier: "ReplicationNotifier"
multiple_ids: Optional[int] = None
stream_ids: List[int] = attr.Factory(list)
@ -814,6 +834,8 @@ class _MultiWriterCtxManager:
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
self.notifier.notify_replication()
if exc_type is not None:
return False

View file

@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import cast
from typing import Collection, Optional, cast
from unittest import TestCase
from unittest.mock import Mock, patch
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
f"Stale partial-stated room flag left over for {room_id} after a"
f" failed do_invite_join!",
)
def test_duplicate_partial_state_room_syncs(self) -> None:
"""
Tests that concurrent partial state syncs are not started for the same room.
"""
is_partial_state = True
end_sync: "Deferred[None]" = Deferred()
async def is_partial_state_room(room_id: str) -> bool:
return is_partial_state
async def sync_partial_state_room(
initial_destination: Optional[str],
other_destinations: Collection[str],
room_id: str,
) -> None:
nonlocal end_sync
try:
await end_sync
finally:
end_sync = Deferred()
mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
fed_handler = self.hs.get_federation_handler()
store = self.hs.get_datastores().main
with patch.object(
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Try to start another partial state sync.
# Nothing should happen.
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# End the partial state sync
is_partial_state = False
end_sync.callback(None)
# The partial state sync should not be restarted.
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# The next attempt to start the partial state sync should work.
is_partial_state = True
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
def test_partial_state_room_sync_restart(self) -> None:
"""
Tests that partial state syncs are restarted when a second partial state sync
was deduplicated and the first partial state sync fails.
"""
is_partial_state = True
end_sync: "Deferred[None]" = Deferred()
async def is_partial_state_room(room_id: str) -> bool:
return is_partial_state
async def sync_partial_state_room(
initial_destination: Optional[str],
other_destinations: Collection[str],
room_id: str,
) -> None:
nonlocal end_sync
try:
await end_sync
finally:
end_sync = Deferred()
mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
fed_handler = self.hs.get_federation_handler()
store = self.hs.get_datastores().main
with patch.object(
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Fail the partial state sync.
# The partial state sync should not be restarted.
end_sync.errback(Exception("Failed to request /state_ids"))
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Start the partial state sync again.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Deduplicate another partial state sync.
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Fail the partial state sync.
# It should restart with the latest parameters.
end_sync.errback(Exception("Failed to request /state_ids"))
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
mock_sync_partial_state_room.assert_called_with(
initial_destination="hs3",
other_destinations=["hs2"],
room_id="room_id",
)

View file

@ -404,6 +404,9 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.send_local_online_presence_to([remote_user_id])
)
# We don't always send out federation immediately, so we advance the clock.
self.reactor.advance(1000)
# Check that a presence update was sent as part of a federation transaction
found_update = False
calls = (

View file

@ -14,7 +14,7 @@
from twisted.internet import defer
from synapse.replication.tcp.commands import PositionCommand, RdataCommand
from synapse.replication.tcp.commands import PositionCommand
from tests.replication._base import BaseMultiWorkerStreamTestCase
@ -111,20 +111,14 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
next_token = self.get_success(ctx.__aenter__())
self.get_success(ctx.__aexit__(None, None, None))
cmd_handler.send_command(
RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
)
self.replicate()
self.get_success(
data_handler.wait_for_stream_position("worker1", "caches", next_token)
)
# `wait_for_stream_position` should only return once master receives an
# RDATA from the worker
ctx = cache_id_gen.get_next()
next_token = self.get_success(ctx.__aenter__())
self.get_success(ctx.__aexit__(None, None, None))
# `wait_for_stream_position` should only return once master receives a
# notification that `next_token` has persisted.
ctx_worker1 = cache_id_gen.get_next()
next_token = self.get_success(ctx_worker1.__aenter__())
d = defer.ensureDeferred(
data_handler.wait_for_stream_position("worker1", "caches", next_token)
@ -142,10 +136,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
)
self.assertFalse(d.called)
# ... but receiving the RDATA should
cmd_handler.send_command(
RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
)
self.replicate()
# ... but worker1 finishing (and so sending an update) should.
self.get_success(ctx_worker1.__aexit__(None, None, None))
self.assertTrue(d.called)

View file

@ -52,6 +52,7 @@ class StreamIdGeneratorTestCase(HomeserverTestCase):
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
return StreamIdGenerator(
db_conn=conn,
notifier=self.hs.get_replication_notifier(),
table="foobar",
column="stream_id",
)
@ -196,6 +197,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator(
conn,
self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
@ -630,6 +632,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator(
conn,
self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
@ -766,6 +769,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator(
conn,
self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[