Merge branch 'madlittlemods/msc3575-sliding-sync-0.0.1' into madlittlemods/msc3575-sliding-sync-filtering

Eric Eastwood 2024-05-29 22:48:14 -05:00
commit f74cc3f166
30 changed files with 589 additions and 532 deletions

changelog.d/17083.misc
@@ -0,0 +1 @@
+Improve DB usage when fetching related events.

changelog.d/17226.misc
@@ -0,0 +1 @@
+Move towards using `MultiWriterIdGenerator` everywhere.

changelog.d/17238.misc
@@ -0,0 +1 @@
+Change the `allow_unsafe_locale` config option to also apply when setting up new databases.

changelog.d/17239.misc
@@ -0,0 +1 @@
+Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.

changelog.d/17240.bugfix
@@ -0,0 +1 @@
+Ignore attempts to send to-device messages to bad users, to avoid log spam when we try to connect to the bad server.

changelog.d/17241.bugfix
@@ -0,0 +1 @@
+Fix handling of duplicate concurrent uploading of device one-time-keys.


@@ -242,12 +242,11 @@ host all all ::1/128 ident
### Fixing incorrect `COLLATE` or `CTYPE`
-Synapse will refuse to set up a new database if it has the wrong values of
-`COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values
-of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
-`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from
-underneath the database, or if a different version of the locale is used on any
-replicas.
+Synapse will refuse to start when using a database with incorrect values of
+`COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
+`database` section of the config, is set to true. Using different locales can
+cause issues if the locale library is updated from underneath the database, or
+if a different version of the locale is used on any replicas.
If you have a database with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with
the correct locale parameter (as shown above). It is also possible to change the
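For reference, a minimal sketch of where this flag lives in the config, assuming an otherwise standard Postgres `database` block (the connection arguments below are illustrative, not a recommended configuration):

```yaml
database:
  name: psycopg2
  allow_unsafe_locale: true   # skip the COLLATE/CTYPE check described above
  args:
    user: synapse_user        # illustrative connection settings
    database: synapse
    host: localhost
```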


@@ -236,6 +236,13 @@ class DeviceMessageHandler:
local_messages = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
+if not UserID.is_valid(user_id):
+logger.warning(
+"Ignoring attempt to send device message to invalid user: %r",
+user_id,
+)
+continue
# add an opentracing log entry for each message
for device_id, message_content in by_device.items():
log_kv(
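The check above simply drops messages addressed to malformed user IDs, so Synapse never tries to contact the (non-existent) server named in the domain part. A rough standalone approximation of that filter (Synapse's real check is `UserID.is_valid`; the shape test below is illustrative only):

```python
# Illustrative only: Matrix user IDs look like "@localpart:domain".
# Synapse uses UserID.is_valid(); this sketch just checks the basic shape.
def looks_like_user_id(user_id: str) -> bool:
    if not user_id.startswith("@"):
        return False
    _, sep, domain = user_id[1:].partition(":")
    return bool(sep) and bool(domain)

messages = {"@alice:example.org": {"DEVICEID": {"body": "hi"}}, "bad-user": {}}
# Drop entries we would otherwise try (and fail) to route to a remote server.
messages = {u: m for u, m in messages.items() if looks_like_user_id(u)}
```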


@@ -53,6 +53,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
class E2eKeysHandler:
def __init__(self, hs: "HomeServer"):
self.config = hs.config
@@ -62,6 +65,7 @@ class E2eKeysHandler:
self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
+self._worker_lock_handler = hs.get_worker_locks_handler()
federation_registry = hs.get_federation_registry()
@@ -855,6 +859,12 @@ class E2eKeysHandler:
async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
) -> None:
+# We take out a lock so that we don't have to worry about a client
+# sending duplicate requests.
+lock_key = f"{user_id}_{device_id}"
+async with self._worker_lock_handler.acquire_lock(
+ONE_TIME_KEY_UPLOAD, lock_key
+):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
@@ -893,7 +903,9 @@
)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
-await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
+await self.store.add_e2e_one_time_keys(
+user_id, device_id, time_now, new_keys
+)
async def upload_signing_keys_for_user(
self, user_id: str, keys: JsonDict
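The change above serialises concurrent one-time-key uploads for the same device behind a named worker lock, so a duplicate request sees the keys its twin already wrote instead of racing it. A minimal standalone sketch of the same idea using plain `asyncio` locks rather than Synapse's worker lock handler (the lock name and store call above are the real ones; everything below is illustrative):

```python
import asyncio
from collections import defaultdict
from typing import Dict, Tuple

# Illustrative only: one lock per (user_id, device_id), so concurrent uploads
# for the same device run one at a time.
_upload_locks: Dict[Tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)

async def upload_one_time_keys(user_id: str, device_id: str, keys: dict) -> None:
    async with _upload_locks[(user_id, device_id)]:
        # Inside the lock: compare against already-stored keys and insert only
        # the genuinely new ones, as the handler above does.
        await store_new_keys(user_id, device_id, keys)

async def store_new_keys(user_id: str, device_id: str, keys: dict) -> None:
    ...  # hypothetical stand-in for the real database write
```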


@@ -393,9 +393,9 @@ class RelationsHandler:
# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
+room_id,
event_id,
event,
-room_id,
RelationTypes.THREAD,
direction=Direction.FORWARDS,
)


@@ -650,7 +650,7 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id)
-with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
length, headers = await self.client.download_media(
server_name,
@@ -693,8 +693,6 @@
)
raise SynapseError(502, "Failed to fetch remote media")
-await finish()
if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii")
else:
@@ -1045,14 +1043,9 @@
),
)
-with self.media_storage.store_into_file(file_info) as (
-f,
-fname,
-finish,
-):
+async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
await self.media_storage.write_to_file(t_byte_source, f)
-await finish()
finally:
t_byte_source.close()


@@ -27,10 +27,9 @@ from typing import (
IO,
TYPE_CHECKING,
Any,
-Awaitable,
+AsyncIterator,
BinaryIO,
Callable,
-Generator,
Optional,
Sequence,
Tuple,
@@ -97,11 +96,9 @@ class MediaStorage:
the file path written to in the primary media store
"""
-with self.store_into_file(file_info) as (f, fname, finish_cb):
+async with self.store_into_file(file_info) as (f, fname):
# Write to the main media repository
await self.write_to_file(source, f)
-# Write to the other storage providers
-await finish_cb()
return fname
@@ -111,32 +108,27 @@ class MediaStorage:
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@trace_with_opname("MediaStorage.store_into_file")
-@contextlib.contextmanager
-def store_into_file(
+@contextlib.asynccontextmanager
+async def store_into_file(
self, file_info: FileInfo
-) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
+) -> AsyncIterator[Tuple[BinaryIO, str]]:
-"""Context manager used to get a file like object to write into, as
+"""Async Context manager used to get a file like object to write into, as
described by file_info.
-Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
-like object that can be written to, fname is the absolute path of file
-on disk, and finish_cb is a function that returns an awaitable.
+Actually yields a 2-tuple (file, fname,), where file is a file
+like object that can be written to and fname is the absolute path of file
+on disk.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
-finish_cb must be called and waited on after the file has been successfully been
-written to. Should not be called if there was an error. Checks for spam and
-stores the file into the configured storage providers.
Args:
file_info: Info about the file to store
Example:
-with media_storage.store_into_file(info) as (f, fname, finish_cb):
+async with media_storage.store_into_file(info) as (f, fname,):
# .. write into f ...
-await finish_cb()
"""
path = self._file_info_to_path(file_info)
@@ -145,29 +137,30 @@ class MediaStorage:
dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True)
-finished_called = [False]
main_media_repo_write_trace_scope = start_active_span(
"writing to main media repo"
)
main_media_repo_write_trace_scope.__enter__()
+with main_media_repo_write_trace_scope:
try:
with open(fname, "wb") as f:
-async def finish() -> None:
-# When someone calls finish, we assume they are done writing to the main media repo
-main_media_repo_write_trace_scope.__exit__(None, None, None)
+yield f, fname
+except Exception as e:
+try:
+os.remove(fname)
+except Exception:
+pass
+raise e from None
with start_active_span("writing to other storage providers"):
-# Ensure that all writes have been flushed and close the
-# file.
-f.flush()
-f.close()
-spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
+spam_check = (
+await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
+)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
@@ -181,27 +174,6 @@ class MediaStorage:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)
-finished_called[0] = True
-yield f, fname, finish
-except Exception as e:
-try:
-main_media_repo_write_trace_scope.__exit__(
-type(e), None, e.__traceback__
-)
-os.remove(fname)
-except Exception:
-pass
-raise e from None
-if not finished_called:
-exc = Exception("Finished callback not called")
-main_media_repo_write_trace_scope.__exit__(
-type(exc), None, exc.__traceback__
-)
-raise exc
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
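The hunk above replaces the explicit `finish_cb` callback with an async context manager, so the post-write work (spam check, copying to storage providers) runs automatically after the caller's `async with` block and cleanup happens on error. A minimal standalone sketch of that pattern, assuming a scratch temp file and a stand-in post-processing step (not Synapse's actual logic):

```python
import contextlib
import os
import tempfile
from typing import AsyncIterator, BinaryIO, Tuple

@contextlib.asynccontextmanager
async def store_into_file() -> AsyncIterator[Tuple[BinaryIO, str]]:
    # Illustrative only: open a scratch file for the caller to write into.
    fd, fname = tempfile.mkstemp()
    f = os.fdopen(fd, "wb")
    try:
        with f:
            yield f, fname
    except Exception:
        # On error, clean up the partially written file and re-raise.
        os.remove(fname)
        raise
    # Code after the `yield` replaces the old finish_cb: it runs once the
    # caller's `async with` block exits successfully.
    print(f"post-processing {fname}")

# Caller side: there is no finish() to remember any more.
# async with store_into_file() as (f, fname):
#     f.write(b"...")
```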


@@ -592,7 +592,7 @@ class UrlPreviewer:
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
-with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+async with self.media_storage.store_into_file(file_info) as (f, fname):
if url.startswith("data:"):
if not allow_data_urls:
raise SynapseError(
@@ -603,8 +603,6 @@
else:
download_result = await self._download_url(url, f)
-await finish()
try:
time_now_ms = self.clock.time_msec()


@@ -2461,7 +2461,11 @@ class DatabasePool:
def make_in_list_sql_clause(
-database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
+database_engine: BaseDatabaseEngine,
+column: str,
+iterable: Collection[Any],
+*,
+negative: bool = False,
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
@@ -2474,6 +2478,7 @@ def make_in_list_sql_clause(
database_engine
column: Name of the column
iterable: The values to check the column against.
+negative: Whether we should check for inequality, i.e. `NOT IN`
Returns:
A tuple of SQL query and the args
@@ -2482,9 +2487,19 @@
if database_engine.supports_using_any_list:
# This should hopefully be faster, but also makes postgres query
# stats easier to understand.
-return "%s = ANY(?)" % (column,), [list(iterable)]
+if not negative:
+clause = f"{column} = ANY(?)"
else:
-return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
+clause = f"{column} != ALL(?)"
+return clause, [list(iterable)]
+else:
+params = ",".join("?" for _ in iterable)
+if not negative:
+clause = f"{column} IN ({params})"
+else:
+clause = f"{column} NOT IN ({params})"
+return clause, list(iterable)
# These overloads ensure that `columns` and `iterable` values have the same length.
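A self-contained sketch of the clause shapes this produces (a simplified reimplementation for illustration; the real function takes the database engine object and returns the clause plus its bind arguments as above):

```python
from typing import Any, Collection, List, Tuple

def in_list_clause(
    column: str,
    values: Collection[Any],
    *,
    supports_any_list: bool,  # True for Postgres, False for SQLite
    negative: bool = False,
) -> Tuple[str, List[Any]]:
    if supports_any_list:
        # Postgres: pass the whole list as a single array parameter.
        clause = f"{column} = ANY(?)" if not negative else f"{column} != ALL(?)"
        return clause, [list(values)]
    # SQLite: expand one placeholder per value.
    params = ",".join("?" for _ in values)
    clause = f"{column} IN ({params})" if not negative else f"{column} NOT IN ({params})"
    return clause, list(values)

print(in_list_clause("instance_name", ["a", "b"], supports_any_list=True, negative=True))
# -> ('instance_name != ALL(?)', [['a', 'b']])
```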


@@ -43,11 +43,9 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
-from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
-StreamIdGenerator,
)
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
@@ -75,7 +73,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._account_data_id_gen: AbstractStreamIdGenerator
-if isinstance(database.engine, PostgresEngine):
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
@@ -90,22 +87,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sequence_name="account_data_sequence",
writers=hs.config.worker.writers.account_data,
)
-else:
-# Multiple writers are not supported for SQLite.
-#
-# We shouldn't be running in worker mode with SQLite, but its useful
-# to support it for unit tests.
-self._account_data_id_gen = StreamIdGenerator(
-db_conn,
-hs.get_replication_notifier(),
-"room_account_data",
-"stream_id",
-extra_tables=[
-("account_data", "stream_id"),
-("room_tags_revisions", "stream_id"),
-],
-is_writer=self._instance_name in hs.config.worker.writers.account_data,
-)
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(


@@ -318,7 +318,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined]
# Caches which might leak edits must be invalidated for the event being
# redacted.
-self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
+self._attempt_to_invalidate_cache(
+"get_relations_for_event",
+(
+room_id,
+redacts,
+),
+)
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
@@ -345,7 +351,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
if relates_to:
-self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
+self._attempt_to_invalidate_cache(
+"get_relations_for_event",
+(
+room_id,
+relates_to,
+),
+)
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
@@ -380,9 +392,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache(
"get_unread_event_push_actions_by_room_for_user", (room_id,)
)
+self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,))
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
-self._attempt_to_invalidate_cache("get_relations_for_event", None)
self._attempt_to_invalidate_cache("get_applicable_edit", None)
self._attempt_to_invalidate_cache("get_thread_id", None)
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)


@@ -50,11 +50,9 @@ from synapse.storage.database import (
LoggingTransaction,
make_in_list_sql_clause,
)
-from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
-StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -89,13 +87,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
expiry_ms=30 * 60 * 1000,
)
-if isinstance(database.engine, PostgresEngine):
self._can_write_to_device = (
self._instance_name in hs.config.worker.writers.to_device
)
-self._to_device_msg_id_gen: AbstractStreamIdGenerator = (
-MultiWriterIdGenerator(
+self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
@@ -108,16 +104,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
)
-)
-else:
-self._can_write_to_device = True
-self._to_device_msg_id_gen = StreamIdGenerator(
-db_conn,
-hs.get_replication_notifier(),
-"device_inbox",
-"stream_id",
-extra_tables=[("device_federation_outbox", "stream_id")],
-)
max_device_inbox_id = self._to_device_msg_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(


@@ -1923,7 +1923,12 @@ class PersistEventsStore:
# Any relation information for the related event must be cleared.
self.store._invalidate_cache_and_stream(
-txn, self.store.get_relations_for_event, (redacted_relates_to,)
+txn,
+self.store.get_relations_for_event,
+(
+room_id,
+redacted_relates_to,
+),
)
if rel_type == RelationTypes.REFERENCE:
self.store._invalidate_cache_and_stream(


@@ -1181,7 +1181,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
results = list(txn)
# (event_id, parent_id, rel_type) for each relation
-relations_to_insert: List[Tuple[str, str, str]] = []
+relations_to_insert: List[Tuple[str, str, str, str]] = []
for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1214,7 +1214,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if not isinstance(parent_id, str):
continue
-relations_to_insert.append((event_id, parent_id, rel_type))
+room_id = event_json["room_id"]
+relations_to_insert.append((room_id, event_id, parent_id, rel_type))
# Insert the missing data, note that we upsert here in case the event
# has already been processed.
@@ -1223,18 +1224,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="event_relations",
key_names=("event_id",),
-key_values=[(r[0],) for r in relations_to_insert],
+key_values=[(r[1],) for r in relations_to_insert],
value_names=("relates_to_id", "relation_type"),
-value_values=[r[1:] for r in relations_to_insert],
+value_values=[r[2:] for r in relations_to_insert],
)
# Iterate the parent IDs and invalidate caches.
-cache_tuples = {(r[1],) for r in relations_to_insert}
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
-txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined]
+txn,
+self.get_relations_for_event, # type: ignore[attr-defined]
+{
+(
+r[0], # room_id
+r[2], # parent_id
+)
+for r in relations_to_insert
+},
)
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
-txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined]
+txn,
+self.get_thread_summary, # type: ignore[attr-defined]
+{(r[1],) for r in relations_to_insert},
)
if results:


@@ -75,12 +75,10 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
-StreamIdGenerator,
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
@@ -195,9 +193,7 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen: AbstractStreamIdGenerator
self._backfill_id_gen: AbstractStreamIdGenerator
-if isinstance(database.engine, PostgresEngine):
-# If we're using Postgres than we can use `MultiWriterIdGenerator`
-# regardless of whether this process writes to the streams or not.
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
@@ -219,27 +215,6 @@
positive=False,
writers=hs.config.worker.writers.events,
)
-else:
-# Multiple writers are not supported for SQLite.
-#
-# We shouldn't be running in worker mode with SQLite, but its useful
-# to support it for unit tests.
-self._stream_id_gen = StreamIdGenerator(
-db_conn,
-hs.get_replication_notifier(),
-"events",
-"stream_ordering",
-is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
-)
-self._backfill_id_gen = StreamIdGenerator(
-db_conn,
-hs.get_replication_notifier(),
-"events",
-"stream_ordering",
-step=-1,
-extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
-)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
@@ -309,27 +284,17 @@
self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator
-if isinstance(database.engine, PostgresEngine):
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_event_stream",
instance_name=hs.get_instance_name(),
-tables=[
-("un_partial_stated_event_stream", "instance_name", "stream_id")
-],
+tables=[("un_partial_stated_event_stream", "instance_name", "stream_id")],
sequence_name="un_partial_stated_event_stream_sequence",
# TODO(faster_joins, multiple writers) Support multiple writers.
writers=["master"],
)
-else:
-self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
-db_conn,
-hs.get_replication_notifier(),
-"un_partial_stated_event_stream",
-"stream_id",
-)
def get_un_partial_stated_events_token(self, instance_name: str) -> int:
return (


@@ -40,13 +40,11 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
-StreamIdGenerator,
)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -91,7 +89,6 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._instance_name in hs.config.worker.writers.presence
)
-if isinstance(database.engine, PostgresEngine):
self._presence_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
@@ -102,10 +99,6 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
sequence_name="presence_stream_sequence",
writers=hs.config.worker.writers.presence,
)
-else:
-self._presence_id_gen = StreamIdGenerator(
-db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
-)
self.hs = hs
self._presence_on_startup = self._get_active_presence(db_conn)


@@ -44,12 +44,10 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
-StreamIdGenerator,
)
from synapse.types import (
JsonDict,
@@ -80,7 +78,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
# class below that is used on the main process.
self._receipts_id_gen: AbstractStreamIdGenerator
-if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
self._instance_name in hs.config.worker.writers.receipts
)
@@ -95,20 +92,6 @@
sequence_name="receipts_sequence",
writers=hs.config.worker.writers.receipts,
)
-else:
-self._can_write_to_receipts = True
-# Multiple writers are not supported for SQLite.
-#
-# We shouldn't be running in worker mode with SQLite, but its useful
-# to support it for unit tests.
-self._receipts_id_gen = StreamIdGenerator(
-db_conn,
-hs.get_replication_notifier(),
-"receipts_linearized",
-"stream_id",
-is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
-)
super().__init__(database, db_conn, hs)


@@ -169,9 +169,9 @@ class RelationsWorkerStore(SQLBaseStore):
@cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
+room_id: str,
event_id: str,
event: EventBase,
-room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = 5,


@@ -58,13 +58,11 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
IdGenerator,
MultiWriterIdGenerator,
-StreamIdGenerator,
)
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
from synapse.util import json_encoder
@@ -155,27 +153,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
-if isinstance(database.engine, PostgresEngine):
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_room_stream",
instance_name=self._instance_name,
-tables=[
-("un_partial_stated_room_stream", "instance_name", "stream_id")
-],
+tables=[("un_partial_stated_room_stream", "instance_name", "stream_id")],
sequence_name="un_partial_stated_room_stream_sequence",
# TODO(faster_joins, multiple writers) Support multiple writers.
writers=["master"],
)
-else:
-self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
-db_conn,
-hs.get_replication_notifier(),
-"un_partial_stated_room_stream",
-"stream_id",
-)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int


@@ -142,6 +142,10 @@ class PostgresEngine(
apply stricter checks on new databases versus existing database.
"""
+allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
+if allow_unsafe_locale:
+return
collation, ctype = self.get_db_locale(txn)
errors = []
@@ -155,7 +159,9 @@
if errors:
raise IncorrectDatabaseSetup(
"Database is incorrectly configured:\n\n%s\n\n"
-"See docs/postgres.md for more information." % ("\n".join(errors))
+"See docs/postgres.md for more information. You can override this check by"
+"setting 'allow_unsafe_locale' to true in the database config.",
+"\n".join(errors),
)
def convert_param_style(self, sql: str) -> str:


@@ -53,9 +53,11 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+make_in_list_sql_clause,
)
+from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
-from synapse.storage.util.sequence import PostgresSequenceGenerator
+from synapse.storage.util.sequence import build_sequence_generator
if TYPE_CHECKING:
from synapse.notifier import ReplicationNotifier
@@ -432,7 +434,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# no active writes in progress.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
-self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+# This goes and fills out the above state from the database.
+self._load_current_ids(db_conn, tables)
+self._sequence_gen = build_sequence_generator(
+db_conn=db_conn,
+database_engine=db.engine,
+get_first_callback=lambda _: self._persisted_upto_position,
+sequence_name=sequence_name,
+# We only need to set the below if we want it to call
+# `check_consistency`, but we do that ourselves below so we can
+# leave them blank.
+table=None,
+id_column=None,
+stream_name=None,
+positive=positive,
+)
# We check that the table and sequence haven't diverged.
for table, _, id_column in tables:
@@ -444,9 +461,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
positive=positive,
)
-# This goes and fills out the above state from the database.
-self._load_current_ids(db_conn, tables)
self._max_seen_allocated_stream_id = max(
self._current_positions.values(), default=1
)
@@ -480,13 +494,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# important if we add back a writer after a long time; we want to
# consider that a "new" writer, rather than using the old stale
# entry here.
-sql = """
+clause, args = make_in_list_sql_clause(
+self._db.engine, "instance_name", self._writers, negative=True
+)
+sql = f"""
DELETE FROM stream_positions
WHERE
stream_name = ?
-AND instance_name != ALL(?)
+AND {clause}
"""
-cur.execute(sql, (self._stream_name, self._writers))
+cur.execute(sql, [self._stream_name] + args)
sql = """
SELECT instance_name, stream_id FROM stream_positions
@@ -508,12 +526,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled).
+greatest_func = (
+"GREATEST" if isinstance(self._db.engine, PostgresEngine) else "MAX"
+)
max_stream_id = 1
for table, _, id_column in tables:
sql = """
-SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+SELECT %(greatest_func)s(COALESCE(%(agg)s(%(id)s), 1), 1)
FROM %(table)s
""" % {
+"greatest_func": greatest_func,
"id": id_column,
"table": table,
"agg": "MAX" if self._positive else "-MIN",
@@ -913,6 +935,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# We upsert the value, ensuring on conflict that we always increase the
# value (or decrease if stream goes backwards).
+if isinstance(self._db.engine, PostgresEngine):
+agg = "GREATEST" if self._positive else "LEAST"
+else:
+agg = "MAX" if self._positive else "MIN"
sql = """
INSERT INTO stream_positions (stream_name, instance_name, stream_id)
VALUES (?, ?, ?)
@@ -920,10 +947,10 @@
DO UPDATE SET
stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
""" % {
-"agg": "GREATEST" if self._positive else "LEAST",
+"agg": agg,
}
-pos = (self.get_current_token_for_writer(self._instance_name),)
+pos = self.get_current_token_for_writer(self._instance_name)
txn.execute(sql, (self._stream_name, self._instance_name, pos))
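The hunk above parametrises the SQL aggregate by engine because Postgres has `GREATEST`/`LEAST` while SQLite's two-argument `MAX`/`MIN` play the same role. A small illustrative helper showing the resulting upsert (a simplified sketch, not the generator's actual code path):

```python
def stream_position_upsert_sql(is_postgres: bool, positive: bool) -> str:
    # Postgres spells the two-argument aggregate GREATEST/LEAST; SQLite's
    # MAX/MIN with two arguments behave the same way.
    if is_postgres:
        agg = "GREATEST" if positive else "LEAST"
    else:
        agg = "MAX" if positive else "MIN"
    return (
        "INSERT INTO stream_positions (stream_name, instance_name, stream_id) "
        "VALUES (?, ?, ?) "
        "ON CONFLICT (stream_name, instance_name) DO UPDATE SET "
        f"stream_id = {agg}(stream_positions.stream_id, EXCLUDED.stream_id)"
    )

# e.g. for a forwards stream on SQLite:
print(stream_position_upsert_sql(is_postgres=False, positive=True))
```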


@@ -93,13 +93,13 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
# from a regular 404.
file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
-with hs.get_media_repository().media_storage.store_into_file(file_info) as (
-f,
-fname,
-finish,
-):
+media_storage = hs.get_media_repository().media_storage
+ctx = media_storage.store_into_file(file_info)
+(f, fname) = self.get_success(ctx.__aenter__())
f.write(SMALL_PNG)
-self.get_success(finish())
+self.get_success(ctx.__aexit__(None, None, None))
self.get_success(
self.store.store_cached_remote_media(


@@ -1204,3 +1204,79 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
class SlidingSyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding MSC3575 Sliding Sync `/sync` endpoint.
"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync"
def test_sync_list(self) -> None:
"""
Test that room IDs show up in the Sliding Sync lists
"""
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(alice_user_id, "correcthorse")
room_id = self.helper.create_room_as(
alice_user_id, tok=alice_access_token, is_public=True
)
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {
"foo-list": {
"ranges": [[0, 99]],
"sort": ["by_notification_level", "by_recency", "by_name"],
"required_state": [
["m.room.join_rules", ""],
["m.room.history_visibility", ""],
["m.space.child", "*"],
],
"timeline_limit": 1,
}
}
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Make sure it has the foo-list we requested
self.assertListEqual(
list(channel.json_body["lists"].keys()),
["foo-list"],
channel.json_body["lists"].keys(),
)
# Make sure the list includes the room we are joined to
self.assertListEqual(
list(channel.json_body["lists"]["foo-list"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [room_id],
}
],
channel.json_body["lists"]["foo-list"],
)


@@ -44,13 +44,13 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
# from a regular 404.
file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
-with hs.get_media_repository().media_storage.store_into_file(file_info) as (
-f,
-fname,
-finish,
-):
+media_storage = hs.get_media_repository().media_storage
+ctx = media_storage.store_into_file(file_info)
+(f, fname) = self.get_success(ctx.__aenter__())
f.write(SMALL_PNG)
-self.get_success(finish())
+self.get_success(ctx.__aexit__(None, None, None))
self.get_success(
self.store.store_cached_remote_media(


@@ -31,6 +31,11 @@ from synapse.storage.database import (
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import (
+LocalSequenceGenerator,
+PostgresSequenceGenerator,
+SequenceGenerator,
+)
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -175,18 +180,22 @@ class StreamIdGeneratorTestCase(HomeserverTestCase):
self.get_success(test_gen_next())
-class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
-if not USE_POSTGRES_FOR_TESTS:
-skip = "Requires Postgres"
+class MultiWriterIdGeneratorBase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+if USE_POSTGRES_FOR_TESTS:
+self.seq_gen: SequenceGenerator = PostgresSequenceGenerator("foobar_seq")
+else:
+self.seq_gen = LocalSequenceGenerator(lambda _: 0)
def _setup_db(self, txn: LoggingTransaction) -> None:
+if USE_POSTGRES_FOR_TESTS:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar (
@@ -221,44 +230,27 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
+next_val = self.seq_gen.get_next_id_txn(txn)
txn.execute(
-"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
-(instance_name,),
-)
-txn.execute(
-"""
-INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
-ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
-""",
-(instance_name,),
-)
-self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
-def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
-"""Insert one row as the given instance with given stream_id, updating
-the postgres sequence position to match.
-"""
-def _insert(txn: LoggingTransaction) -> None:
-txn.execute(
-"INSERT INTO foobar VALUES (?, ?)",
+"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
(
-stream_id,
+next_val,
instance_name,
),
)
-txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
-(instance_name, stream_id, stream_id),
+(instance_name, next_val, next_val),
)
-self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
+self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+class MultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
def test_empty(self) -> None:
"""Test an ID generator against an empty database gives sensible
current positions.
@@ -347,6 +339,176 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
def _get_next_txn(txn: LoggingTransaction) -> None:
stream_id = id_gen.get_next_txn(txn)
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_restart_during_out_of_order_persistence(self) -> None:
"""Test that restarting a process while another process is writing out
of order updates are handled correctly.
"""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Persist two rows at once
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__())
self.assertEqual(s1, 8)
self.assertEqual(s2, 9)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# We finish persisting the second row before restart
self.get_success(ctx2.__aexit__(None, None, None))
# We simulate a restart of another worker by just creating a new ID gen.
id_gen_worker = self._create_id_generator("worker")
# Restarted worker should not see the second persisted row
self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
# Now if we persist the first row then both instances should jump ahead
# correctly.
self.get_success(ctx1.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
id_gen_worker.advance("master", 9)
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id, updating
the postgres sequence position to match.
"""
def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
(
stream_id,
instance_name,
),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, stream_id, stream_id),
)
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
def test_get_persisted_upto_position(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions.
"""
# The following tests are a bit cheeky in that we notify about new
# positions via `advance` without *actually* advancing the postgres
# sequence.
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
# Min is 3 and there is a gap between 5, so we expect it to be 3.
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
# We advance "first" straight to 6. Min is now 5 but there is no gap so
# we expect it to be 6
id_gen.advance("first", 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# No gap, so we expect 7.
id_gen.advance("second", 7)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# We haven't seen 8 yet, so we expect 7 still.
id_gen.advance("second", 9)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# Now that we've seen 7, 8 and 9 we can got straight to 9.
id_gen.advance("first", 8)
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
# Jump forward with gaps. The minimum is 11, even though we haven't seen
# 10 we know that everything before 11 must be persisted.
id_gen.advance("first", 11)
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
def test_get_persisted_upto_position_get_next(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions when `get_next` is called.
"""
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
def test_multi_instance(self) -> None:
"""Test that reads and writes from multiple processes are handled
correctly.
@@ -453,145 +615,6 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
)
def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
def _get_next_txn(txn: LoggingTransaction) -> None:
stream_id = id_gen.get_next_txn(txn)
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_get_persisted_upto_position(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions.
"""
# The following tests are a bit cheeky in that we notify about new
# positions via `advance` without *actually* advancing the postgres
# sequence.
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
# Min is 3 and there is a gap between 5, so we expect it to be 3.
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
# We advance "first" straight to 6. Min is now 5 but there is no gap so
# we expect it to be 6
id_gen.advance("first", 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# No gap, so we expect 7.
id_gen.advance("second", 7)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# We haven't seen 8 yet, so we expect 7 still.
id_gen.advance("second", 9)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# Now that we've seen 7, 8 and 9 we can got straight to 9.
id_gen.advance("first", 8)
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
# Jump forward with gaps. The minimum is 11, even though we haven't seen
# 10 we know that everything before 11 must be persisted.
id_gen.advance("first", 11)
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
def test_get_persisted_upto_position_get_next(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions when `get_next` is called.
"""
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
def test_restart_during_out_of_order_persistence(self) -> None:
"""Test that restarting a process while another process is writing out
of order updates are handled correctly.
"""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Persist two rows at once
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__())
self.assertEqual(s1, 8)
self.assertEqual(s2, 9)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# We finish persisting the second row before restart
self.get_success(ctx2.__aexit__(None, None, None))
# We simulate a restart of another worker by just creating a new ID gen.
id_gen_worker = self._create_id_generator("worker")
# Restarted worker should not see the second persisted row
self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
# Now if we persist the first row then both instances should jump ahead
# correctly.
self.get_success(ctx1.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
id_gen_worker.advance("master", 9)
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
def test_writer_config_change(self) -> None:
"""Test that changing the writer config correctly works."""