mirror of
https://github.com/element-hq/synapse
synced 2024-07-15 10:34:05 +00:00
Merge branch 'erikj/repl_notifieri' into erikj/fix_wait_for_stream
This commit is contained in:
commit
bc8136dd81
1
changelog.d/14844.misc
Normal file
1
changelog.d/14844.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add check to avoid starting duplicate partial state syncs.
|
1
changelog.d/14875.docker
Normal file
1
changelog.d/14875.docker
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Bump default Python version in the Dockerfile from 3.9 to 3.11.
|
1
changelog.d/14877.misc
Normal file
1
changelog.d/14877.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Always notify replication when a stream advances automatically.
|
|
@ -20,7 +20,7 @@
|
||||||
# `poetry export | pip install -r /dev/stdin`, but beware: we have experienced bugs in
|
# `poetry export | pip install -r /dev/stdin`, but beware: we have experienced bugs in
|
||||||
# in `poetry export` in the past.
|
# in `poetry export` in the past.
|
||||||
|
|
||||||
ARG PYTHON_VERSION=3.9
|
ARG PYTHON_VERSION=3.11
|
||||||
|
|
||||||
###
|
###
|
||||||
### Stage 0: generate requirements.txt
|
### Stage 0: generate requirements.txt
|
||||||
|
@ -34,11 +34,11 @@ FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye as requirements
|
||||||
# Here we use it to set up a cache for apt (and below for pip), to improve
|
# Here we use it to set up a cache for apt (and below for pip), to improve
|
||||||
# rebuild speeds on slow connections.
|
# rebuild speeds on slow connections.
|
||||||
RUN \
|
RUN \
|
||||||
--mount=type=cache,target=/var/cache/apt,sharing=locked \
|
--mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||||
apt-get update -qq && apt-get install -yqq \
|
apt-get update -qq && apt-get install -yqq \
|
||||||
build-essential git libffi-dev libssl-dev \
|
build-essential git libffi-dev libssl-dev \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# We install poetry in its own build stage to avoid its dependencies conflicting with
|
# We install poetry in its own build stage to avoid its dependencies conflicting with
|
||||||
# synapse's dependencies.
|
# synapse's dependencies.
|
||||||
|
@ -64,9 +64,9 @@ ARG TEST_ONLY_IGNORE_POETRY_LOCKFILE
|
||||||
# Otherwise, just create an empty requirements file so that the Dockerfile can
|
# Otherwise, just create an empty requirements file so that the Dockerfile can
|
||||||
# proceed.
|
# proceed.
|
||||||
RUN if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
|
RUN if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
|
||||||
/root/.local/bin/poetry export --extras all -o /synapse/requirements.txt ${TEST_ONLY_SKIP_DEP_HASH_VERIFICATION:+--without-hashes}; \
|
/root/.local/bin/poetry export --extras all -o /synapse/requirements.txt ${TEST_ONLY_SKIP_DEP_HASH_VERIFICATION:+--without-hashes}; \
|
||||||
else \
|
else \
|
||||||
touch /synapse/requirements.txt; \
|
touch /synapse/requirements.txt; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
###
|
###
|
||||||
|
@ -76,24 +76,24 @@ FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye as builder
|
||||||
|
|
||||||
# install the OS build deps
|
# install the OS build deps
|
||||||
RUN \
|
RUN \
|
||||||
--mount=type=cache,target=/var/cache/apt,sharing=locked \
|
--mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||||
apt-get update -qq && apt-get install -yqq \
|
apt-get update -qq && apt-get install -yqq \
|
||||||
build-essential \
|
build-essential \
|
||||||
libffi-dev \
|
libffi-dev \
|
||||||
libjpeg-dev \
|
libjpeg-dev \
|
||||||
libpq-dev \
|
libpq-dev \
|
||||||
libssl-dev \
|
libssl-dev \
|
||||||
libwebp-dev \
|
libwebp-dev \
|
||||||
libxml++2.6-dev \
|
libxml++2.6-dev \
|
||||||
libxslt1-dev \
|
libxslt1-dev \
|
||||||
openssl \
|
openssl \
|
||||||
zlib1g-dev \
|
zlib1g-dev \
|
||||||
git \
|
git \
|
||||||
curl \
|
curl \
|
||||||
libicu-dev \
|
libicu-dev \
|
||||||
pkg-config \
|
pkg-config \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
|
||||||
# Install rust and ensure its in the PATH
|
# Install rust and ensure its in the PATH
|
||||||
|
@ -134,9 +134,9 @@ ARG TEST_ONLY_IGNORE_POETRY_LOCKFILE
|
||||||
RUN --mount=type=cache,target=/synapse/target,sharing=locked \
|
RUN --mount=type=cache,target=/synapse/target,sharing=locked \
|
||||||
--mount=type=cache,target=${CARGO_HOME}/registry,sharing=locked \
|
--mount=type=cache,target=${CARGO_HOME}/registry,sharing=locked \
|
||||||
if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
|
if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
|
||||||
pip install --prefix="/install" --no-deps --no-warn-script-location /synapse[all]; \
|
pip install --prefix="/install" --no-deps --no-warn-script-location /synapse[all]; \
|
||||||
else \
|
else \
|
||||||
pip install --prefix="/install" --no-warn-script-location /synapse[all]; \
|
pip install --prefix="/install" --no-warn-script-location /synapse[all]; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
###
|
###
|
||||||
|
@ -151,20 +151,20 @@ LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git
|
||||||
LABEL org.opencontainers.image.licenses='Apache-2.0'
|
LABEL org.opencontainers.image.licenses='Apache-2.0'
|
||||||
|
|
||||||
RUN \
|
RUN \
|
||||||
--mount=type=cache,target=/var/cache/apt,sharing=locked \
|
--mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||||
apt-get update -qq && apt-get install -yqq \
|
apt-get update -qq && apt-get install -yqq \
|
||||||
curl \
|
curl \
|
||||||
gosu \
|
gosu \
|
||||||
libjpeg62-turbo \
|
libjpeg62-turbo \
|
||||||
libpq5 \
|
libpq5 \
|
||||||
libwebp6 \
|
libwebp6 \
|
||||||
xmlsec1 \
|
xmlsec1 \
|
||||||
libjemalloc2 \
|
libjemalloc2 \
|
||||||
libicu67 \
|
libicu67 \
|
||||||
libssl-dev \
|
libssl-dev \
|
||||||
openssl \
|
openssl \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
COPY --from=builder /install /usr/local
|
COPY --from=builder /install /usr/local
|
||||||
COPY ./docker/start.py /start.py
|
COPY ./docker/start.py /start.py
|
||||||
|
@ -175,4 +175,4 @@ EXPOSE 8008/tcp 8009/tcp 8448/tcp
|
||||||
ENTRYPOINT ["/start.py"]
|
ENTRYPOINT ["/start.py"]
|
||||||
|
|
||||||
HEALTHCHECK --start-period=5s --interval=15s --timeout=5s \
|
HEALTHCHECK --start-period=5s --interval=15s --timeout=5s \
|
||||||
CMD curl -fSs http://localhost:8008/health || exit 1
|
CMD curl -fSs http://localhost:8008/health || exit 1
|
||||||
|
|
|
@ -51,6 +51,7 @@ from synapse.logging.context import (
|
||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
run_in_background,
|
run_in_background,
|
||||||
)
|
)
|
||||||
|
from synapse.notifier import ReplicationNotifier
|
||||||
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
|
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
|
||||||
from synapse.storage.databases.main import PushRuleStore
|
from synapse.storage.databases.main import PushRuleStore
|
||||||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
||||||
|
@ -260,6 +261,9 @@ class MockHomeserver:
|
||||||
def should_send_federation(self) -> bool:
|
def should_send_federation(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def get_replication_notifier(self) -> ReplicationNotifier:
|
||||||
|
return ReplicationNotifier()
|
||||||
|
|
||||||
|
|
||||||
class Porter:
|
class Porter:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -27,6 +27,7 @@ from typing import (
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
@ -171,12 +172,23 @@ class FederationHandler:
|
||||||
|
|
||||||
self.third_party_event_rules = hs.get_third_party_event_rules()
|
self.third_party_event_rules = hs.get_third_party_event_rules()
|
||||||
|
|
||||||
|
# Tracks running partial state syncs by room ID.
|
||||||
|
# Partial state syncs currently only run on the main process, so it's okay to
|
||||||
|
# track them in-memory for now.
|
||||||
|
self._active_partial_state_syncs: Set[str] = set()
|
||||||
|
# Tracks partial state syncs we may want to restart.
|
||||||
|
# A dictionary mapping room IDs to (initial destination, other destinations)
|
||||||
|
# tuples.
|
||||||
|
self._partial_state_syncs_maybe_needing_restart: Dict[
|
||||||
|
str, Tuple[Optional[str], Collection[str]]
|
||||||
|
] = {}
|
||||||
|
|
||||||
# if this is the main process, fire off a background process to resume
|
# if this is the main process, fire off a background process to resume
|
||||||
# any partial-state-resync operations which were in flight when we
|
# any partial-state-resync operations which were in flight when we
|
||||||
# were shut down.
|
# were shut down.
|
||||||
if not hs.config.worker.worker_app:
|
if not hs.config.worker.worker_app:
|
||||||
run_as_background_process(
|
run_as_background_process(
|
||||||
"resume_sync_partial_state_room", self._resume_sync_partial_state_room
|
"resume_sync_partial_state_room", self._resume_partial_state_room_sync
|
||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
|
@ -679,9 +691,7 @@ class FederationHandler:
|
||||||
if ret.partial_state:
|
if ret.partial_state:
|
||||||
# Kick off the process of asynchronously fetching the state for this
|
# Kick off the process of asynchronously fetching the state for this
|
||||||
# room.
|
# room.
|
||||||
run_as_background_process(
|
self._start_partial_state_room_sync(
|
||||||
desc="sync_partial_state_room",
|
|
||||||
func=self._sync_partial_state_room,
|
|
||||||
initial_destination=origin,
|
initial_destination=origin,
|
||||||
other_destinations=ret.servers_in_room,
|
other_destinations=ret.servers_in_room,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
@ -1660,20 +1670,100 @@ class FederationHandler:
|
||||||
# well.
|
# well.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _resume_sync_partial_state_room(self) -> None:
|
async def _resume_partial_state_room_sync(self) -> None:
|
||||||
"""Resumes resyncing of all partial-state rooms after a restart."""
|
"""Resumes resyncing of all partial-state rooms after a restart."""
|
||||||
assert not self.config.worker.worker_app
|
assert not self.config.worker.worker_app
|
||||||
|
|
||||||
partial_state_rooms = await self.store.get_partial_state_room_resync_info()
|
partial_state_rooms = await self.store.get_partial_state_room_resync_info()
|
||||||
for room_id, resync_info in partial_state_rooms.items():
|
for room_id, resync_info in partial_state_rooms.items():
|
||||||
run_as_background_process(
|
self._start_partial_state_room_sync(
|
||||||
desc="sync_partial_state_room",
|
|
||||||
func=self._sync_partial_state_room,
|
|
||||||
initial_destination=resync_info.joined_via,
|
initial_destination=resync_info.joined_via,
|
||||||
other_destinations=resync_info.servers_in_room,
|
other_destinations=resync_info.servers_in_room,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _start_partial_state_room_sync(
|
||||||
|
self,
|
||||||
|
initial_destination: Optional[str],
|
||||||
|
other_destinations: Collection[str],
|
||||||
|
room_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Starts the background process to resync the state of a partial state room,
|
||||||
|
if it is not already running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_destination: the initial homeserver to pull the state from
|
||||||
|
other_destinations: other homeservers to try to pull the state from, if
|
||||||
|
`initial_destination` is unavailable
|
||||||
|
room_id: room to be resynced
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _sync_partial_state_room_wrapper() -> None:
|
||||||
|
if room_id in self._active_partial_state_syncs:
|
||||||
|
# Another local user has joined the room while there is already a
|
||||||
|
# partial state sync running. This implies that there is a new join
|
||||||
|
# event to un-partial state. We might find ourselves in one of a few
|
||||||
|
# scenarios:
|
||||||
|
# 1. There is an existing partial state sync. The partial state sync
|
||||||
|
# un-partial states the new join event before completing and all is
|
||||||
|
# well.
|
||||||
|
# 2. Before the latest join, the homeserver was no longer in the room
|
||||||
|
# and there is an existing partial state sync from our previous
|
||||||
|
# membership of the room. The partial state sync may have:
|
||||||
|
# a) succeeded, but not yet terminated. The room will not be
|
||||||
|
# un-partial stated again unless we restart the partial state
|
||||||
|
# sync.
|
||||||
|
# b) failed, because we were no longer in the room and remote
|
||||||
|
# homeservers were refusing our requests, but not yet
|
||||||
|
# terminated. After the latest join, remote homeservers may
|
||||||
|
# start answering our requests again, so we should restart the
|
||||||
|
# partial state sync.
|
||||||
|
# In the cases where we would want to restart the partial state sync,
|
||||||
|
# the room would have the partial state flag when the partial state sync
|
||||||
|
# terminates.
|
||||||
|
self._partial_state_syncs_maybe_needing_restart[room_id] = (
|
||||||
|
initial_destination,
|
||||||
|
other_destinations,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._active_partial_state_syncs.add(room_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._sync_partial_state_room(
|
||||||
|
initial_destination=initial_destination,
|
||||||
|
other_destinations=other_destinations,
|
||||||
|
room_id=room_id,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# Read the room's partial state flag while we still hold the claim to
|
||||||
|
# being the active partial state sync (so that another partial state
|
||||||
|
# sync can't come along and mess with it under us).
|
||||||
|
# Normally, the partial state flag will be gone. If it isn't, then we
|
||||||
|
# may find ourselves in scenario 2a or 2b as described in the comment
|
||||||
|
# above, where we want to restart the partial state sync.
|
||||||
|
is_still_partial_state_room = await self.store.is_partial_state_room(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
self._active_partial_state_syncs.remove(room_id)
|
||||||
|
|
||||||
|
if room_id in self._partial_state_syncs_maybe_needing_restart:
|
||||||
|
(
|
||||||
|
restart_initial_destination,
|
||||||
|
restart_other_destinations,
|
||||||
|
) = self._partial_state_syncs_maybe_needing_restart.pop(room_id)
|
||||||
|
|
||||||
|
if is_still_partial_state_room:
|
||||||
|
self._start_partial_state_room_sync(
|
||||||
|
initial_destination=restart_initial_destination,
|
||||||
|
other_destinations=restart_other_destinations,
|
||||||
|
room_id=room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
run_as_background_process(
|
||||||
|
desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper
|
||||||
|
)
|
||||||
|
|
||||||
async def _sync_partial_state_room(
|
async def _sync_partial_state_room(
|
||||||
self,
|
self,
|
||||||
initial_destination: Optional[str],
|
initial_destination: Optional[str],
|
||||||
|
|
|
@ -226,8 +226,7 @@ class Notifier:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
|
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
|
||||||
|
|
||||||
# Called when there are new things to stream over replication
|
self._replication_notifier = hs.get_replication_notifier()
|
||||||
self.replication_callbacks: List[Callable[[], None]] = []
|
|
||||||
self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
|
self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
|
||||||
|
|
||||||
self._federation_client = hs.get_federation_http_client()
|
self._federation_client = hs.get_federation_http_client()
|
||||||
|
@ -279,7 +278,7 @@ class Notifier:
|
||||||
it needs to do any asynchronous work, a background thread should be started and
|
it needs to do any asynchronous work, a background thread should be started and
|
||||||
wrapped with run_as_background_process.
|
wrapped with run_as_background_process.
|
||||||
"""
|
"""
|
||||||
self.replication_callbacks.append(cb)
|
self._replication_notifier.add_replication_callback(cb)
|
||||||
|
|
||||||
def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
|
def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
|
||||||
"""Add a callback that will be called when a user joins a room.
|
"""Add a callback that will be called when a user joins a room.
|
||||||
|
@ -741,8 +740,7 @@ class Notifier:
|
||||||
|
|
||||||
def notify_replication(self) -> None:
|
def notify_replication(self) -> None:
|
||||||
"""Notify the any replication listeners that there's a new event"""
|
"""Notify the any replication listeners that there's a new event"""
|
||||||
for cb in self.replication_callbacks:
|
self._replication_notifier.notify_replication()
|
||||||
cb()
|
|
||||||
|
|
||||||
def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
|
def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
|
||||||
for cb in self._new_join_in_room_callbacks:
|
for cb in self._new_join_in_room_callbacks:
|
||||||
|
@ -759,3 +757,26 @@ class Notifier:
|
||||||
# Tell the federation client about the fact the server is back up, so
|
# Tell the federation client about the fact the server is back up, so
|
||||||
# that any in flight requests can be immediately retried.
|
# that any in flight requests can be immediately retried.
|
||||||
self._federation_client.wake_destination(server)
|
self._federation_client.wake_destination(server)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True)
|
||||||
|
class ReplicationNotifier:
|
||||||
|
"""Tracks callbacks for things that need to know about stream changes.
|
||||||
|
|
||||||
|
This is separate from the notifier to avoid circular dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_replication_callbacks: List[Callable[[], None]] = attr.Factory(list)
|
||||||
|
|
||||||
|
def add_replication_callback(self, cb: Callable[[], None]) -> None:
|
||||||
|
"""Add a callback that will be called when some new data is available.
|
||||||
|
Callback is not given any arguments. It should *not* return a Deferred - if
|
||||||
|
it needs to do any asynchronous work, a background thread should be started and
|
||||||
|
wrapped with run_as_background_process.
|
||||||
|
"""
|
||||||
|
self._replication_callbacks.append(cb)
|
||||||
|
|
||||||
|
def notify_replication(self) -> None:
|
||||||
|
"""Notify the any replication listeners that there's a new event"""
|
||||||
|
for cb in self._replication_callbacks:
|
||||||
|
cb()
|
||||||
|
|
|
@ -107,7 +107,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
|
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.notifier import Notifier
|
from synapse.notifier import Notifier, ReplicationNotifier
|
||||||
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
||||||
from synapse.push.pusherpool import PusherPool
|
from synapse.push.pusherpool import PusherPool
|
||||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
from synapse.replication.tcp.client import ReplicationDataHandler
|
||||||
|
@ -389,6 +389,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
def get_notifier(self) -> Notifier:
|
def get_notifier(self) -> Notifier:
|
||||||
return Notifier(self)
|
return Notifier(self)
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
|
def get_replication_notifier(self) -> ReplicationNotifier:
|
||||||
|
return ReplicationNotifier()
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_auth(self) -> Auth:
|
def get_auth(self) -> Auth:
|
||||||
return Auth(self)
|
return Auth(self)
|
||||||
|
|
|
@ -75,6 +75,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||||
self._account_data_id_gen = MultiWriterIdGenerator(
|
self._account_data_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
db=database,
|
db=database,
|
||||||
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="account_data",
|
stream_name="account_data",
|
||||||
instance_name=self._instance_name,
|
instance_name=self._instance_name,
|
||||||
tables=[
|
tables=[
|
||||||
|
@ -95,6 +96,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
|
||||||
# SQLite).
|
# SQLite).
|
||||||
self._account_data_id_gen = StreamIdGenerator(
|
self._account_data_id_gen = StreamIdGenerator(
|
||||||
db_conn,
|
db_conn,
|
||||||
|
hs.get_replication_notifier(),
|
||||||
"room_account_data",
|
"room_account_data",
|
||||||
"stream_id",
|
"stream_id",
|
||||||
extra_tables=[("room_tags_revisions", "stream_id")],
|
extra_tables=[("room_tags_revisions", "stream_id")],
|
||||||
|
|
|
@ -75,6 +75,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
self._cache_id_gen = MultiWriterIdGenerator(
|
self._cache_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn,
|
db_conn,
|
||||||
database,
|
database,
|
||||||
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="caches",
|
stream_name="caches",
|
||||||
instance_name=hs.get_instance_name(),
|
instance_name=hs.get_instance_name(),
|
||||||
tables=[
|
tables=[
|
||||||
|
|
|
@ -91,6 +91,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
MultiWriterIdGenerator(
|
MultiWriterIdGenerator(
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
db=database,
|
db=database,
|
||||||
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="to_device",
|
stream_name="to_device",
|
||||||
instance_name=self._instance_name,
|
instance_name=self._instance_name,
|
||||||
tables=[("device_inbox", "instance_name", "stream_id")],
|
tables=[("device_inbox", "instance_name", "stream_id")],
|
||||||
|
@ -101,7 +102,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||||
else:
|
else:
|
||||||
self._can_write_to_device = True
|
self._can_write_to_device = True
|
||||||
self._device_inbox_id_gen = StreamIdGenerator(
|
self._device_inbox_id_gen = StreamIdGenerator(
|
||||||
db_conn, "device_inbox", "stream_id"
|
db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
|
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
|
||||||
|
|
|
@ -92,6 +92,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
# class below that is used on the main process.
|
# class below that is used on the main process.
|
||||||
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
|
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
|
||||||
db_conn,
|
db_conn,
|
||||||
|
hs.get_replication_notifier(),
|
||||||
"device_lists_stream",
|
"device_lists_stream",
|
||||||
"stream_id",
|
"stream_id",
|
||||||
extra_tables=[
|
extra_tables=[
|
||||||
|
|
|
@ -1181,7 +1181,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self._cross_signing_id_gen = StreamIdGenerator(
|
self._cross_signing_id_gen = StreamIdGenerator(
|
||||||
db_conn, "e2e_cross_signing_keys", "stream_id"
|
db_conn,
|
||||||
|
hs.get_replication_notifier(),
|
||||||
|
"e2e_cross_signing_keys",
|
||||||
|
"stream_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_e2e_device_keys(
|
async def set_e2e_device_keys(
|
||||||
|
|
|
@ -191,6 +191,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
self._stream_id_gen = MultiWriterIdGenerator(
|
self._stream_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
db=database,
|
db=database,
|
||||||
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="events",
|
stream_name="events",
|
||||||
instance_name=hs.get_instance_name(),
|
instance_name=hs.get_instance_name(),
|
||||||
tables=[("events", "instance_name", "stream_ordering")],
|
tables=[("events", "instance_name", "stream_ordering")],
|
||||||
|
@ -200,6 +201,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
self._backfill_id_gen = MultiWriterIdGenerator(
|
self._backfill_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
db=database,
|
db=database,
|
||||||
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="backfill",
|
stream_name="backfill",
|
||||||
instance_name=hs.get_instance_name(),
|
instance_name=hs.get_instance_name(),
|
||||||
tables=[("events", "instance_name", "stream_ordering")],
|
tables=[("events", "instance_name", "stream_ordering")],
|
||||||
|
@ -217,12 +219,14 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
# SQLite).
|
# SQLite).
|
||||||
self._stream_id_gen = StreamIdGenerator(
|
self._stream_id_gen = StreamIdGenerator(
|
||||||
db_conn,
|
db_conn,
|
||||||
|
hs.get_replication_notifier(),
|
||||||
"events",
|
"events",
|
||||||
"stream_ordering",
|
"stream_ordering",
|
||||||
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
|
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
|
||||||
)
|
)
|
||||||
self._backfill_id_gen = StreamIdGenerator(
|
self._backfill_id_gen = StreamIdGenerator(
|
||||||
db_conn,
|
db_conn,
|
||||||
|
hs.get_replication_notifier(),
|
||||||
"events",
|
"events",
|
||||||
"stream_ordering",
|
"stream_ordering",
|
||||||
step=-1,
|
step=-1,
|
||||||
|
@ -300,6 +304,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
|
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
db=database,
|
db=database,
|
||||||
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="un_partial_stated_event_stream",
|
stream_name="un_partial_stated_event_stream",
|
||||||
instance_name=hs.get_instance_name(),
|
instance_name=hs.get_instance_name(),
|
||||||
tables=[
|
tables=[
|
||||||
|
@ -311,7 +316,10 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
|
self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
|
||||||
db_conn, "un_partial_stated_event_stream", "stream_id"
|
db_conn,
|
||||||
|
hs.get_replication_notifier(),
|
||||||
|
"un_partial_stated_event_stream",
|
||||||
|
"stream_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_un_partial_stated_events_token(self) -> int:
|
def get_un_partial_stated_events_token(self) -> int:
|
||||||
|
|
|
@ -77,6 +77,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
||||||
self._presence_id_gen = MultiWriterIdGenerator(
|
self._presence_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
db=database,
|
db=database,
|
||||||
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="presence_stream",
|
stream_name="presence_stream",
|
||||||
instance_name=self._instance_name,
|
instance_name=self._instance_name,
|
||||||
tables=[("presence_stream", "instance_name", "stream_id")],
|
tables=[("presence_stream", "instance_name", "stream_id")],
|
||||||
|
@ -85,7 +86,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._presence_id_gen = StreamIdGenerator(
|
self._presence_id_gen = StreamIdGenerator(
|
||||||
db_conn, "presence_stream", "stream_id"
|
db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
|
@ -118,6 +118,7 @@ class PushRulesWorkerStore(
|
||||||
# class below that is used on the main process.
|
# class below that is used on the main process.
|
||||||
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
|
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
|
||||||
db_conn,
|
db_conn,
|
||||||
|
hs.get_replication_notifier(),
|
||||||
"push_rules_stream",
|
"push_rules_stream",
|
||||||
"stream_id",
|
"stream_id",
|
||||||
is_writer=hs.config.worker.worker_app is None,
|
is_writer=hs.config.worker.worker_app is None,
|
||||||
|
|
|
@ -62,6 +62,7 @@ class PusherWorkerStore(SQLBaseStore):
|
||||||
# class below that is used on the main process.
|
# class below that is used on the main process.
|
||||||
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
|
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
|
||||||
db_conn,
|
db_conn,
|
||||||
|
hs.get_replication_notifier(),
|
||||||
"pushers",
|
"pushers",
|
||||||
"id",
|
"id",
|
||||||
extra_tables=[("deleted_pushers", "stream_id")],
|
extra_tables=[("deleted_pushers", "stream_id")],
|
||||||
|
|
|
@ -73,6 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
self._receipts_id_gen = MultiWriterIdGenerator(
|
self._receipts_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
db=database,
|
db=database,
|
||||||
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="receipts",
|
stream_name="receipts",
|
||||||
instance_name=self._instance_name,
|
instance_name=self._instance_name,
|
||||||
tables=[("receipts_linearized", "instance_name", "stream_id")],
|
tables=[("receipts_linearized", "instance_name", "stream_id")],
|
||||||
|
@ -91,6 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
# SQLite).
|
# SQLite).
|
||||||
self._receipts_id_gen = StreamIdGenerator(
|
self._receipts_id_gen = StreamIdGenerator(
|
||||||
db_conn,
|
db_conn,
|
||||||
|
hs.get_replication_notifier(),
|
||||||
"receipts_linearized",
|
"receipts_linearized",
|
||||||
"stream_id",
|
"stream_id",
|
||||||
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
|
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
|
||||||
|
|
|
@ -126,6 +126,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
|
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
db=database,
|
db=database,
|
||||||
|
notifier=hs.get_replication_notifier(),
|
||||||
stream_name="un_partial_stated_room_stream",
|
stream_name="un_partial_stated_room_stream",
|
||||||
instance_name=self._instance_name,
|
instance_name=self._instance_name,
|
||||||
tables=[
|
tables=[
|
||||||
|
@ -137,7 +138,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
|
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
|
||||||
db_conn, "un_partial_stated_room_stream", "stream_id"
|
db_conn,
|
||||||
|
hs.get_replication_notifier(),
|
||||||
|
"un_partial_stated_room_stream",
|
||||||
|
"stream_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def store_room(
|
async def store_room(
|
||||||
|
|
|
@ -20,6 +20,7 @@ from collections import OrderedDict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
AsyncContextManager,
|
AsyncContextManager,
|
||||||
ContextManager,
|
ContextManager,
|
||||||
Dict,
|
Dict,
|
||||||
|
@ -49,6 +50,9 @@ from synapse.storage.database import (
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.sequence import PostgresSequenceGenerator
|
from synapse.storage.util.sequence import PostgresSequenceGenerator
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.notifier import ReplicationNotifier
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
db_conn: LoggingDatabaseConnection,
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
notifier: "ReplicationNotifier",
|
||||||
table: str,
|
table: str,
|
||||||
column: str,
|
column: str,
|
||||||
extra_tables: Iterable[Tuple[str, str]] = (),
|
extra_tables: Iterable[Tuple[str, str]] = (),
|
||||||
|
@ -205,6 +210,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
|
||||||
# The key and values are the same, but we never look at the values.
|
# The key and values are the same, but we never look at the values.
|
||||||
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
|
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
|
||||||
|
|
||||||
|
self._notifier = notifier
|
||||||
|
|
||||||
def advance(self, instance_name: str, new_id: int) -> None:
|
def advance(self, instance_name: str, new_id: int) -> None:
|
||||||
# Advance should never be called on a writer instance, only over replication
|
# Advance should never be called on a writer instance, only over replication
|
||||||
if self._is_writer:
|
if self._is_writer:
|
||||||
|
@ -227,6 +234,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._unfinished_ids.pop(next_id)
|
self._unfinished_ids.pop(next_id)
|
||||||
|
|
||||||
|
self._notifier.notify_replication()
|
||||||
|
|
||||||
return _AsyncCtxManagerWrapper(manager())
|
return _AsyncCtxManagerWrapper(manager())
|
||||||
|
|
||||||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
||||||
|
@ -250,6 +259,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
|
||||||
for next_id in next_ids:
|
for next_id in next_ids:
|
||||||
self._unfinished_ids.pop(next_id)
|
self._unfinished_ids.pop(next_id)
|
||||||
|
|
||||||
|
self._notifier.notify_replication()
|
||||||
|
|
||||||
return _AsyncCtxManagerWrapper(manager())
|
return _AsyncCtxManagerWrapper(manager())
|
||||||
|
|
||||||
def get_current_token(self) -> int:
|
def get_current_token(self) -> int:
|
||||||
|
@ -296,6 +307,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
self,
|
self,
|
||||||
db_conn: LoggingDatabaseConnection,
|
db_conn: LoggingDatabaseConnection,
|
||||||
db: DatabasePool,
|
db: DatabasePool,
|
||||||
|
notifier: "ReplicationNotifier",
|
||||||
stream_name: str,
|
stream_name: str,
|
||||||
instance_name: str,
|
instance_name: str,
|
||||||
tables: List[Tuple[str, str, str]],
|
tables: List[Tuple[str, str, str]],
|
||||||
|
@ -304,6 +316,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
positive: bool = True,
|
positive: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._db = db
|
self._db = db
|
||||||
|
self._notifier = notifier
|
||||||
self._stream_name = stream_name
|
self._stream_name = stream_name
|
||||||
self._instance_name = instance_name
|
self._instance_name = instance_name
|
||||||
self._positive = positive
|
self._positive = positive
|
||||||
|
@ -535,7 +548,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
|
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
|
||||||
# controls the return type. If `None` or omitted, the context manager yields
|
# controls the return type. If `None` or omitted, the context manager yields
|
||||||
# a single integer stream_id; otherwise it yields a list of stream_ids.
|
# a single integer stream_id; otherwise it yields a list of stream_ids.
|
||||||
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
|
return cast(
|
||||||
|
AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier)
|
||||||
|
)
|
||||||
|
|
||||||
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
|
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
|
||||||
# If we have a list of instances that are allowed to write to this
|
# If we have a list of instances that are allowed to write to this
|
||||||
|
@ -544,7 +559,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
raise Exception("Tried to allocate stream ID on non-writer")
|
raise Exception("Tried to allocate stream ID on non-writer")
|
||||||
|
|
||||||
# Cast safety: see get_next.
|
# Cast safety: see get_next.
|
||||||
return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
|
return cast(
|
||||||
|
AsyncContextManager[List[int]],
|
||||||
|
_MultiWriterCtxManager(self, self._notifier, n),
|
||||||
|
)
|
||||||
|
|
||||||
def get_next_txn(self, txn: LoggingTransaction) -> int:
|
def get_next_txn(self, txn: LoggingTransaction) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -563,6 +581,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
|
|
||||||
txn.call_after(self._mark_id_as_finished, next_id)
|
txn.call_after(self._mark_id_as_finished, next_id)
|
||||||
txn.call_on_exception(self._mark_id_as_finished, next_id)
|
txn.call_on_exception(self._mark_id_as_finished, next_id)
|
||||||
|
txn.call_after(self._notifier.notify_replication)
|
||||||
|
|
||||||
# Update the `stream_positions` table with newly updated stream
|
# Update the `stream_positions` table with newly updated stream
|
||||||
# ID (unless self._writers is not set in which case we don't
|
# ID (unless self._writers is not set in which case we don't
|
||||||
|
@ -787,6 +806,7 @@ class _MultiWriterCtxManager:
|
||||||
"""Async context manager returned by MultiWriterIdGenerator"""
|
"""Async context manager returned by MultiWriterIdGenerator"""
|
||||||
|
|
||||||
id_gen: MultiWriterIdGenerator
|
id_gen: MultiWriterIdGenerator
|
||||||
|
notifier: "ReplicationNotifier"
|
||||||
multiple_ids: Optional[int] = None
|
multiple_ids: Optional[int] = None
|
||||||
stream_ids: List[int] = attr.Factory(list)
|
stream_ids: List[int] = attr.Factory(list)
|
||||||
|
|
||||||
|
@ -814,6 +834,8 @@ class _MultiWriterCtxManager:
|
||||||
for i in self.stream_ids:
|
for i in self.stream_ids:
|
||||||
self.id_gen._mark_id_as_finished(i)
|
self.id_gen._mark_id_as_finished(i)
|
||||||
|
|
||||||
|
self.notifier.notify_replication()
|
||||||
|
|
||||||
if exc_type is not None:
|
if exc_type is not None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -12,10 +12,11 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import cast
|
from typing import Collection, Optional, cast
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
f"Stale partial-stated room flag left over for {room_id} after a"
|
f"Stale partial-stated room flag left over for {room_id} after a"
|
||||||
f" failed do_invite_join!",
|
f" failed do_invite_join!",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_duplicate_partial_state_room_syncs(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests that concurrent partial state syncs are not started for the same room.
|
||||||
|
"""
|
||||||
|
is_partial_state = True
|
||||||
|
end_sync: "Deferred[None]" = Deferred()
|
||||||
|
|
||||||
|
async def is_partial_state_room(room_id: str) -> bool:
|
||||||
|
return is_partial_state
|
||||||
|
|
||||||
|
async def sync_partial_state_room(
|
||||||
|
initial_destination: Optional[str],
|
||||||
|
other_destinations: Collection[str],
|
||||||
|
room_id: str,
|
||||||
|
) -> None:
|
||||||
|
nonlocal end_sync
|
||||||
|
try:
|
||||||
|
await end_sync
|
||||||
|
finally:
|
||||||
|
end_sync = Deferred()
|
||||||
|
|
||||||
|
mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
|
||||||
|
mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
|
||||||
|
|
||||||
|
fed_handler = self.hs.get_federation_handler()
|
||||||
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
|
||||||
|
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
|
||||||
|
# Start the partial state sync.
|
||||||
|
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
|
||||||
|
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
||||||
|
|
||||||
|
# Try to start another partial state sync.
|
||||||
|
# Nothing should happen.
|
||||||
|
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
|
||||||
|
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
||||||
|
|
||||||
|
# End the partial state sync
|
||||||
|
is_partial_state = False
|
||||||
|
end_sync.callback(None)
|
||||||
|
|
||||||
|
# The partial state sync should not be restarted.
|
||||||
|
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
||||||
|
|
||||||
|
# The next attempt to start the partial state sync should work.
|
||||||
|
is_partial_state = True
|
||||||
|
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
|
||||||
|
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
|
||||||
|
|
||||||
|
def test_partial_state_room_sync_restart(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests that partial state syncs are restarted when a second partial state sync
|
||||||
|
was deduplicated and the first partial state sync fails.
|
||||||
|
"""
|
||||||
|
is_partial_state = True
|
||||||
|
end_sync: "Deferred[None]" = Deferred()
|
||||||
|
|
||||||
|
async def is_partial_state_room(room_id: str) -> bool:
|
||||||
|
return is_partial_state
|
||||||
|
|
||||||
|
async def sync_partial_state_room(
|
||||||
|
initial_destination: Optional[str],
|
||||||
|
other_destinations: Collection[str],
|
||||||
|
room_id: str,
|
||||||
|
) -> None:
|
||||||
|
nonlocal end_sync
|
||||||
|
try:
|
||||||
|
await end_sync
|
||||||
|
finally:
|
||||||
|
end_sync = Deferred()
|
||||||
|
|
||||||
|
mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
|
||||||
|
mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
|
||||||
|
|
||||||
|
fed_handler = self.hs.get_federation_handler()
|
||||||
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
|
||||||
|
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
|
||||||
|
# Start the partial state sync.
|
||||||
|
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
|
||||||
|
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
||||||
|
|
||||||
|
# Fail the partial state sync.
|
||||||
|
# The partial state sync should not be restarted.
|
||||||
|
end_sync.errback(Exception("Failed to request /state_ids"))
|
||||||
|
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
||||||
|
|
||||||
|
# Start the partial state sync again.
|
||||||
|
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
|
||||||
|
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
|
||||||
|
|
||||||
|
# Deduplicate another partial state sync.
|
||||||
|
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
|
||||||
|
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
|
||||||
|
|
||||||
|
# Fail the partial state sync.
|
||||||
|
# It should restart with the latest parameters.
|
||||||
|
end_sync.errback(Exception("Failed to request /state_ids"))
|
||||||
|
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
|
||||||
|
mock_sync_partial_state_room.assert_called_with(
|
||||||
|
initial_destination="hs3",
|
||||||
|
other_destinations=["hs2"],
|
||||||
|
room_id="room_id",
|
||||||
|
)
|
||||||
|
|
|
@ -404,6 +404,9 @@ class ModuleApiTestCase(HomeserverTestCase):
|
||||||
self.module_api.send_local_online_presence_to([remote_user_id])
|
self.module_api.send_local_online_presence_to([remote_user_id])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# We don't always send out federation immediately, so we advance the clock.
|
||||||
|
self.reactor.advance(1000)
|
||||||
|
|
||||||
# Check that a presence update was sent as part of a federation transaction
|
# Check that a presence update was sent as part of a federation transaction
|
||||||
found_update = False
|
found_update = False
|
||||||
calls = (
|
calls = (
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.replication.tcp.commands import PositionCommand, RdataCommand
|
from synapse.replication.tcp.commands import PositionCommand
|
||||||
|
|
||||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
|
|
||||||
|
@ -111,20 +111,14 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
next_token = self.get_success(ctx.__aenter__())
|
next_token = self.get_success(ctx.__aenter__())
|
||||||
self.get_success(ctx.__aexit__(None, None, None))
|
self.get_success(ctx.__aexit__(None, None, None))
|
||||||
|
|
||||||
cmd_handler.send_command(
|
|
||||||
RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
|
|
||||||
)
|
|
||||||
self.replicate()
|
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
data_handler.wait_for_stream_position("worker1", "caches", next_token)
|
data_handler.wait_for_stream_position("worker1", "caches", next_token)
|
||||||
)
|
)
|
||||||
|
|
||||||
# `wait_for_stream_position` should only return once master receives an
|
# `wait_for_stream_position` should only return once master receives a
|
||||||
# RDATA from the worker
|
# notification that `next_token` has persisted.
|
||||||
ctx = cache_id_gen.get_next()
|
ctx_worker1 = cache_id_gen.get_next()
|
||||||
next_token = self.get_success(ctx.__aenter__())
|
next_token = self.get_success(ctx_worker1.__aenter__())
|
||||||
self.get_success(ctx.__aexit__(None, None, None))
|
|
||||||
|
|
||||||
d = defer.ensureDeferred(
|
d = defer.ensureDeferred(
|
||||||
data_handler.wait_for_stream_position("worker1", "caches", next_token)
|
data_handler.wait_for_stream_position("worker1", "caches", next_token)
|
||||||
|
@ -142,10 +136,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(d.called)
|
self.assertFalse(d.called)
|
||||||
|
|
||||||
# ... but receiving the RDATA should
|
# ... but worker1 finishing (and so sending an update) should.
|
||||||
cmd_handler.send_command(
|
self.get_success(ctx_worker1.__aexit__(None, None, None))
|
||||||
RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
|
|
||||||
)
|
|
||||||
self.replicate()
|
|
||||||
|
|
||||||
self.assertTrue(d.called)
|
self.assertTrue(d.called)
|
||||||
|
|
|
@ -52,6 +52,7 @@ class StreamIdGeneratorTestCase(HomeserverTestCase):
|
||||||
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
|
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
|
||||||
return StreamIdGenerator(
|
return StreamIdGenerator(
|
||||||
db_conn=conn,
|
db_conn=conn,
|
||||||
|
notifier=self.hs.get_replication_notifier(),
|
||||||
table="foobar",
|
table="foobar",
|
||||||
column="stream_id",
|
column="stream_id",
|
||||||
)
|
)
|
||||||
|
@ -196,6 +197,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
return MultiWriterIdGenerator(
|
return MultiWriterIdGenerator(
|
||||||
conn,
|
conn,
|
||||||
self.db_pool,
|
self.db_pool,
|
||||||
|
notifier=self.hs.get_replication_notifier(),
|
||||||
stream_name="test_stream",
|
stream_name="test_stream",
|
||||||
instance_name=instance_name,
|
instance_name=instance_name,
|
||||||
tables=[("foobar", "instance_name", "stream_id")],
|
tables=[("foobar", "instance_name", "stream_id")],
|
||||||
|
@ -630,6 +632,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
return MultiWriterIdGenerator(
|
return MultiWriterIdGenerator(
|
||||||
conn,
|
conn,
|
||||||
self.db_pool,
|
self.db_pool,
|
||||||
|
notifier=self.hs.get_replication_notifier(),
|
||||||
stream_name="test_stream",
|
stream_name="test_stream",
|
||||||
instance_name=instance_name,
|
instance_name=instance_name,
|
||||||
tables=[("foobar", "instance_name", "stream_id")],
|
tables=[("foobar", "instance_name", "stream_id")],
|
||||||
|
@ -766,6 +769,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
return MultiWriterIdGenerator(
|
return MultiWriterIdGenerator(
|
||||||
conn,
|
conn,
|
||||||
self.db_pool,
|
self.db_pool,
|
||||||
|
notifier=self.hs.get_replication_notifier(),
|
||||||
stream_name="test_stream",
|
stream_name="test_stream",
|
||||||
instance_name=instance_name,
|
instance_name=instance_name,
|
||||||
tables=[
|
tables=[
|
||||||
|
|
Loading…
Reference in a new issue