Merge branch 'madlittlemods/msc3575-sliding-sync-0.0.1' into madlittlemods/msc3575-sliding-sync-filtering

This commit is contained in:
Eric Eastwood 2024-05-29 22:48:14 -05:00
commit f74cc3f166
30 changed files with 589 additions and 532 deletions

1
changelog.d/17083.misc Normal file
View file

@ -0,0 +1 @@
Improve DB usage when fetching related events.

1
changelog.d/17226.misc Normal file
View file

@ -0,0 +1 @@
Move towards using `MultiWriterIdGenerator` everywhere.

1
changelog.d/17238.misc Normal file
View file

@ -0,0 +1 @@
Change the `allow_unsafe_locale` config option to also apply when setting up new databases.

1
changelog.d/17239.misc Normal file
View file

@ -0,0 +1 @@
Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.

1
changelog.d/17240.bugfix Normal file
View file

@ -0,0 +1 @@
Ignore attempts to send to-device messages to bad users, to avoid log spam when we try to connect to the bad server.

1
changelog.d/17241.bugfix Normal file
View file

@ -0,0 +1 @@
Fix handling of duplicate concurrent uploading of device one-time-keys.

View file

@ -242,12 +242,11 @@ host all all ::1/128 ident
### Fixing incorrect `COLLATE` or `CTYPE`
Synapse will refuse to set up a new database if it has the wrong values of
`COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values
of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from
underneath the database, or if a different version of the locale is used on any
replicas.
Synapse will refuse to start when using a database with incorrect values of
`COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
`database` section of the config, is set to true. Using different locales can
cause issues if the locale library is updated from underneath the database, or
if a different version of the locale is used on any replicas.
If you have a database with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with
the correct locale parameter (as shown above). It is also possible to change the

View file

@ -236,6 +236,13 @@ class DeviceMessageHandler:
local_messages = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
if not UserID.is_valid(user_id):
logger.warning(
"Ignoring attempt to send device message to invalid user: %r",
user_id,
)
continue
# add an opentracing log entry for each message
for device_id, message_content in by_device.items():
log_kv(

View file

@ -53,6 +53,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
class E2eKeysHandler:
def __init__(self, hs: "HomeServer"):
self.config = hs.config
@ -62,6 +65,7 @@ class E2eKeysHandler:
self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
self._worker_lock_handler = hs.get_worker_locks_handler()
federation_registry = hs.get_federation_registry()
@ -855,6 +859,12 @@ class E2eKeysHandler:
async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
) -> None:
# We take out a lock so that we don't have to worry about a client
# sending duplicate requests.
lock_key = f"{user_id}_{device_id}"
async with self._worker_lock_handler.acquire_lock(
ONE_TIME_KEY_UPLOAD, lock_key
):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
@ -893,7 +903,9 @@ class E2eKeysHandler:
)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
await self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
)
async def upload_signing_keys_for_user(
self, user_id: str, keys: JsonDict

View file

@ -393,9 +393,9 @@ class RelationsHandler:
# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
room_id,
event_id,
event,
room_id,
RelationTypes.THREAD,
direction=Direction.FORWARDS,
)

View file

@ -650,7 +650,7 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
length, headers = await self.client.download_media(
server_name,
@ -693,8 +693,6 @@ class MediaRepository:
)
raise SynapseError(502, "Failed to fetch remote media")
await finish()
if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii")
else:
@ -1045,14 +1043,9 @@ class MediaRepository:
),
)
with self.media_storage.store_into_file(file_info) as (
f,
fname,
finish,
):
async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
await self.media_storage.write_to_file(t_byte_source, f)
await finish()
finally:
t_byte_source.close()

View file

@ -27,10 +27,9 @@ from typing import (
IO,
TYPE_CHECKING,
Any,
Awaitable,
AsyncIterator,
BinaryIO,
Callable,
Generator,
Optional,
Sequence,
Tuple,
@ -97,11 +96,9 @@ class MediaStorage:
the file path written to in the primary media store
"""
with self.store_into_file(file_info) as (f, fname, finish_cb):
async with self.store_into_file(file_info) as (f, fname):
# Write to the main media repository
await self.write_to_file(source, f)
# Write to the other storage providers
await finish_cb()
return fname
@ -111,32 +108,27 @@ class MediaStorage:
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@trace_with_opname("MediaStorage.store_into_file")
@contextlib.contextmanager
def store_into_file(
@contextlib.asynccontextmanager
async def store_into_file(
self, file_info: FileInfo
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as
) -> AsyncIterator[Tuple[BinaryIO, str]]:
"""Async Context manager used to get a file like object to write into, as
described by file_info.
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
on disk, and finish_cb is a function that returns an awaitable.
Actually yields a 2-tuple (file, fname,), where file is a file
like object that can be written to and fname is the absolute path of file
on disk.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
finish_cb must be called and waited on after the file has been successfully been
written to. Should not be called if there was an error. Checks for spam and
stores the file into the configured storage providers.
Args:
file_info: Info about the file to store
Example:
with media_storage.store_into_file(info) as (f, fname, finish_cb):
async with media_storage.store_into_file(info) as (f, fname,):
# .. write into f ...
await finish_cb()
"""
path = self._file_info_to_path(file_info)
@ -145,29 +137,30 @@ class MediaStorage:
dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True)
finished_called = [False]
main_media_repo_write_trace_scope = start_active_span(
"writing to main media repo"
)
main_media_repo_write_trace_scope.__enter__()
with main_media_repo_write_trace_scope:
try:
with open(fname, "wb") as f:
yield f, fname
async def finish() -> None:
# When someone calls finish, we assume they are done writing to the main media repo
main_media_repo_write_trace_scope.__exit__(None, None, None)
except Exception as e:
try:
os.remove(fname)
except Exception:
pass
raise e from None
with start_active_span("writing to other storage providers"):
# Ensure that all writes have been flushed and close the
# file.
f.flush()
f.close()
spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
spam_check = (
await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
@ -181,27 +174,6 @@ class MediaStorage:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)
finished_called[0] = True
yield f, fname, finish
except Exception as e:
try:
main_media_repo_write_trace_scope.__exit__(
type(e), None, e.__traceback__
)
os.remove(fname)
except Exception:
pass
raise e from None
if not finished_called:
exc = Exception("Finished callback not called")
main_media_repo_write_trace_scope.__exit__(
type(exc), None, exc.__traceback__
)
raise exc
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.

View file

@ -592,7 +592,7 @@ class UrlPreviewer:
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
async with self.media_storage.store_into_file(file_info) as (f, fname):
if url.startswith("data:"):
if not allow_data_urls:
raise SynapseError(
@ -603,8 +603,6 @@ class UrlPreviewer:
else:
download_result = await self._download_url(url, f)
await finish()
try:
time_now_ms = self.clock.time_msec()

View file

@ -2461,7 +2461,11 @@ class DatabasePool:
def make_in_list_sql_clause(
database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
database_engine: BaseDatabaseEngine,
column: str,
iterable: Collection[Any],
*,
negative: bool = False,
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
@ -2474,6 +2478,7 @@ def make_in_list_sql_clause(
database_engine
column: Name of the column
iterable: The values to check the column against.
negative: Whether we should check for inequality, i.e. `NOT IN`
Returns:
A tuple of SQL query and the args
@ -2482,9 +2487,19 @@ def make_in_list_sql_clause(
if database_engine.supports_using_any_list:
# This should hopefully be faster, but also makes postgres query
# stats easier to understand.
return "%s = ANY(?)" % (column,), [list(iterable)]
if not negative:
clause = f"{column} = ANY(?)"
else:
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
clause = f"{column} != ALL(?)"
return clause, [list(iterable)]
else:
params = ",".join("?" for _ in iterable)
if not negative:
clause = f"{column} IN ({params})"
else:
clause = f"{column} NOT IN ({params})"
return clause, list(iterable)
# These overloads ensure that `columns` and `iterable` values have the same length.

View file

@ -43,11 +43,9 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
@ -75,7 +73,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._account_data_id_gen: AbstractStreamIdGenerator
if isinstance(database.engine, PostgresEngine):
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
@ -90,22 +87,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sequence_name="account_data_sequence",
writers=hs.config.worker.writers.account_data,
)
else:
# Multiple writers are not supported for SQLite.
#
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
self._account_data_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"room_account_data",
"stream_id",
extra_tables=[
("account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
],
is_writer=self._instance_name in hs.config.worker.writers.account_data,
)
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(

View file

@ -318,7 +318,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined]
# Caches which might leak edits must be invalidated for the event being
# redacted.
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
self._attempt_to_invalidate_cache(
"get_relations_for_event",
(
room_id,
redacts,
),
)
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
@ -345,7 +351,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
if relates_to:
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
self._attempt_to_invalidate_cache(
"get_relations_for_event",
(
room_id,
relates_to,
),
)
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
@ -380,9 +392,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache(
"get_unread_event_push_actions_by_room_for_user", (room_id,)
)
self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,))
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
self._attempt_to_invalidate_cache("get_relations_for_event", None)
self._attempt_to_invalidate_cache("get_applicable_edit", None)
self._attempt_to_invalidate_cache("get_thread_id", None)
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)

View file

@ -50,11 +50,9 @@ from synapse.storage.database import (
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
@ -89,13 +87,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
expiry_ms=30 * 60 * 1000,
)
if isinstance(database.engine, PostgresEngine):
self._can_write_to_device = (
self._instance_name in hs.config.worker.writers.to_device
)
self._to_device_msg_id_gen: AbstractStreamIdGenerator = (
MultiWriterIdGenerator(
self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
@ -108,16 +104,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
)
)
else:
self._can_write_to_device = True
self._to_device_msg_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"device_inbox",
"stream_id",
extra_tables=[("device_federation_outbox", "stream_id")],
)
max_device_inbox_id = self._to_device_msg_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(

View file

@ -1923,7 +1923,12 @@ class PersistEventsStore:
# Any relation information for the related event must be cleared.
self.store._invalidate_cache_and_stream(
txn, self.store.get_relations_for_event, (redacted_relates_to,)
txn,
self.store.get_relations_for_event,
(
room_id,
redacted_relates_to,
),
)
if rel_type == RelationTypes.REFERENCE:
self.store._invalidate_cache_and_stream(

View file

@ -1181,7 +1181,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
results = list(txn)
# (event_id, parent_id, rel_type) for each relation
relations_to_insert: List[Tuple[str, str, str]] = []
relations_to_insert: List[Tuple[str, str, str, str]] = []
for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
@ -1214,7 +1214,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if not isinstance(parent_id, str):
continue
relations_to_insert.append((event_id, parent_id, rel_type))
room_id = event_json["room_id"]
relations_to_insert.append((room_id, event_id, parent_id, rel_type))
# Insert the missing data, note that we upsert here in case the event
# has already been processed.
@ -1223,18 +1224,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="event_relations",
key_names=("event_id",),
key_values=[(r[0],) for r in relations_to_insert],
key_values=[(r[1],) for r in relations_to_insert],
value_names=("relates_to_id", "relation_type"),
value_values=[r[1:] for r in relations_to_insert],
value_values=[r[2:] for r in relations_to_insert],
)
# Iterate the parent IDs and invalidate caches.
cache_tuples = {(r[1],) for r in relations_to_insert}
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined]
txn,
self.get_relations_for_event, # type: ignore[attr-defined]
{
(
r[0], # room_id
r[2], # parent_id
)
for r in relations_to_insert
},
)
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined]
txn,
self.get_thread_summary, # type: ignore[attr-defined]
{(r[1],) for r in relations_to_insert},
)
if results:

View file

@ -75,12 +75,10 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
@ -195,9 +193,7 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen: AbstractStreamIdGenerator
self._backfill_id_gen: AbstractStreamIdGenerator
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
@ -219,27 +215,6 @@ class EventsWorkerStore(SQLBaseStore):
positive=False,
writers=hs.config.worker.writers.events,
)
else:
# Multiple writers are not supported for SQLite.
#
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
self._stream_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"events",
"stream_ordering",
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
@ -309,27 +284,17 @@ class EventsWorkerStore(SQLBaseStore):
self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator
if isinstance(database.engine, PostgresEngine):
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_event_stream",
instance_name=hs.get_instance_name(),
tables=[
("un_partial_stated_event_stream", "instance_name", "stream_id")
],
tables=[("un_partial_stated_event_stream", "instance_name", "stream_id")],
sequence_name="un_partial_stated_event_stream_sequence",
# TODO(faster_joins, multiple writers) Support multiple writers.
writers=["master"],
)
else:
self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"un_partial_stated_event_stream",
"stream_id",
)
def get_un_partial_stated_events_token(self, instance_name: str) -> int:
return (

View file

@ -40,13 +40,11 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -91,7 +89,6 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._instance_name in hs.config.worker.writers.presence
)
if isinstance(database.engine, PostgresEngine):
self._presence_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
@ -102,10 +99,6 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
sequence_name="presence_stream_sequence",
writers=hs.config.worker.writers.presence,
)
else:
self._presence_id_gen = StreamIdGenerator(
db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
)
self.hs = hs
self._presence_on_startup = self._get_active_presence(db_conn)

View file

@ -44,12 +44,10 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import (
JsonDict,
@ -80,7 +78,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
# class below that is used on the main process.
self._receipts_id_gen: AbstractStreamIdGenerator
if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
self._instance_name in hs.config.worker.writers.receipts
)
@ -95,20 +92,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
sequence_name="receipts_sequence",
writers=hs.config.worker.writers.receipts,
)
else:
self._can_write_to_receipts = True
# Multiple writers are not supported for SQLite.
#
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
self._receipts_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"receipts_linearized",
"stream_id",
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
)
super().__init__(database, db_conn, hs)

View file

@ -169,9 +169,9 @@ class RelationsWorkerStore(SQLBaseStore):
@cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
room_id: str,
event_id: str,
event: EventBase,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = 5,

View file

@ -58,13 +58,11 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
from synapse.util import json_encoder
@ -155,27 +153,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
if isinstance(database.engine, PostgresEngine):
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_room_stream",
instance_name=self._instance_name,
tables=[
("un_partial_stated_room_stream", "instance_name", "stream_id")
],
tables=[("un_partial_stated_room_stream", "instance_name", "stream_id")],
sequence_name="un_partial_stated_room_stream_sequence",
# TODO(faster_joins, multiple writers) Support multiple writers.
writers=["master"],
)
else:
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"un_partial_stated_room_stream",
"stream_id",
)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int

View file

@ -142,6 +142,10 @@ class PostgresEngine(
apply stricter checks on new databases versus existing database.
"""
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
if allow_unsafe_locale:
return
collation, ctype = self.get_db_locale(txn)
errors = []
@ -155,7 +159,9 @@ class PostgresEngine(
if errors:
raise IncorrectDatabaseSetup(
"Database is incorrectly configured:\n\n%s\n\n"
"See docs/postgres.md for more information." % ("\n".join(errors))
"See docs/postgres.md for more information. You can override this check by"
"setting 'allow_unsafe_locale' to true in the database config.",
"\n".join(errors),
)
def convert_param_style(self, sql: str) -> str:

View file

@ -53,9 +53,11 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator
from synapse.storage.util.sequence import build_sequence_generator
if TYPE_CHECKING:
from synapse.notifier import ReplicationNotifier
@ -432,7 +434,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# no active writes in progress.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# This goes and fills out the above state from the database.
self._load_current_ids(db_conn, tables)
self._sequence_gen = build_sequence_generator(
db_conn=db_conn,
database_engine=db.engine,
get_first_callback=lambda _: self._persisted_upto_position,
sequence_name=sequence_name,
# We only need to set the below if we want it to call
# `check_consistency`, but we do that ourselves below so we can
# leave them blank.
table=None,
id_column=None,
stream_name=None,
positive=positive,
)
# We check that the table and sequence haven't diverged.
for table, _, id_column in tables:
@ -444,9 +461,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
positive=positive,
)
# This goes and fills out the above state from the database.
self._load_current_ids(db_conn, tables)
self._max_seen_allocated_stream_id = max(
self._current_positions.values(), default=1
)
@ -480,13 +494,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# important if we add back a writer after a long time; we want to
# consider that a "new" writer, rather than using the old stale
# entry here.
sql = """
clause, args = make_in_list_sql_clause(
self._db.engine, "instance_name", self._writers, negative=True
)
sql = f"""
DELETE FROM stream_positions
WHERE
stream_name = ?
AND instance_name != ALL(?)
AND {clause}
"""
cur.execute(sql, (self._stream_name, self._writers))
cur.execute(sql, [self._stream_name] + args)
sql = """
SELECT instance_name, stream_id FROM stream_positions
@ -508,12 +526,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled).
greatest_func = (
"GREATEST" if isinstance(self._db.engine, PostgresEngine) else "MAX"
)
max_stream_id = 1
for table, _, id_column in tables:
sql = """
SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
SELECT %(greatest_func)s(COALESCE(%(agg)s(%(id)s), 1), 1)
FROM %(table)s
""" % {
"greatest_func": greatest_func,
"id": id_column,
"table": table,
"agg": "MAX" if self._positive else "-MIN",
@ -913,6 +935,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# We upsert the value, ensuring on conflict that we always increase the
# value (or decrease if stream goes backwards).
if isinstance(self._db.engine, PostgresEngine):
agg = "GREATEST" if self._positive else "LEAST"
else:
agg = "MAX" if self._positive else "MIN"
sql = """
INSERT INTO stream_positions (stream_name, instance_name, stream_id)
VALUES (?, ?, ?)
@ -920,10 +947,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
DO UPDATE SET
stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
""" % {
"agg": "GREATEST" if self._positive else "LEAST",
"agg": agg,
}
pos = (self.get_current_token_for_writer(self._instance_name),)
pos = self.get_current_token_for_writer(self._instance_name)
txn.execute(sql, (self._stream_name, self._instance_name, pos))

View file

@ -93,13 +93,13 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
# from a regular 404.
file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
f,
fname,
finish,
):
media_storage = hs.get_media_repository().media_storage
ctx = media_storage.store_into_file(file_info)
(f, fname) = self.get_success(ctx.__aenter__())
f.write(SMALL_PNG)
self.get_success(finish())
self.get_success(ctx.__aexit__(None, None, None))
self.get_success(
self.store.store_cached_remote_media(

View file

@ -1204,3 +1204,79 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
class SlidingSyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding MSC3575 Sliding Sync `/sync` endpoint.
"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync"
def test_sync_list(self) -> None:
"""
Test that room IDs show up in the Sliding Sync lists
"""
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(alice_user_id, "correcthorse")
room_id = self.helper.create_room_as(
alice_user_id, tok=alice_access_token, is_public=True
)
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {
"foo-list": {
"ranges": [[0, 99]],
"sort": ["by_notification_level", "by_recency", "by_name"],
"required_state": [
["m.room.join_rules", ""],
["m.room.history_visibility", ""],
["m.space.child", "*"],
],
"timeline_limit": 1,
}
}
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Make sure it has the foo-list we requested
self.assertListEqual(
list(channel.json_body["lists"].keys()),
["foo-list"],
channel.json_body["lists"].keys(),
)
# Make sure the list includes the room we are joined to
self.assertListEqual(
list(channel.json_body["lists"]["foo-list"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [room_id],
}
],
channel.json_body["lists"]["foo-list"],
)

View file

@ -44,13 +44,13 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
# from a regular 404.
file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
f,
fname,
finish,
):
media_storage = hs.get_media_repository().media_storage
ctx = media_storage.store_into_file(file_info)
(f, fname) = self.get_success(ctx.__aenter__())
f.write(SMALL_PNG)
self.get_success(finish())
self.get_success(ctx.__aexit__(None, None, None))
self.get_success(
self.store.store_cached_remote_media(

View file

@ -31,6 +31,11 @@ from synapse.storage.database import (
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import (
LocalSequenceGenerator,
PostgresSequenceGenerator,
SequenceGenerator,
)
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@ -175,18 +180,22 @@ class StreamIdGeneratorTestCase(HomeserverTestCase):
self.get_success(test_gen_next())
class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
class MultiWriterIdGeneratorBase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
if USE_POSTGRES_FOR_TESTS:
self.seq_gen: SequenceGenerator = PostgresSequenceGenerator("foobar_seq")
else:
self.seq_gen = LocalSequenceGenerator(lambda _: 0)
def _setup_db(self, txn: LoggingTransaction) -> None:
if USE_POSTGRES_FOR_TESTS:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar (
@ -221,44 +230,27 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
next_val = self.seq_gen.get_next_id_txn(txn)
txn.execute(
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
(instance_name,),
)
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
""",
(instance_name,),
)
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id, updating
the postgres sequence position to match.
"""
def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar VALUES (?, ?)",
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
(
stream_id,
next_val,
instance_name,
),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, stream_id, stream_id),
(instance_name, next_val, next_val),
)
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
class MultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
def test_empty(self) -> None:
"""Test an ID generator against an empty database gives sensible
current positions.
@ -347,6 +339,176 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
def _get_next_txn(txn: LoggingTransaction) -> None:
stream_id = id_gen.get_next_txn(txn)
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_restart_during_out_of_order_persistence(self) -> None:
"""Test that restarting a process while another process is writing out
of order updates are handled correctly.
"""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Persist two rows at once
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__())
self.assertEqual(s1, 8)
self.assertEqual(s2, 9)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# We finish persisting the second row before restart
self.get_success(ctx2.__aexit__(None, None, None))
# We simulate a restart of another worker by just creating a new ID gen.
id_gen_worker = self._create_id_generator("worker")
# Restarted worker should not see the second persisted row
self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
# Now if we persist the first row then both instances should jump ahead
# correctly.
self.get_success(ctx1.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
id_gen_worker.advance("master", 9)
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id, updating
the postgres sequence position to match.
"""
def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
(
stream_id,
instance_name,
),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, stream_id, stream_id),
)
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
def test_get_persisted_upto_position(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions.
"""
# The following tests are a bit cheeky in that we notify about new
# positions via `advance` without *actually* advancing the postgres
# sequence.
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
# Min is 3 and there is a gap between 5, so we expect it to be 3.
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
# We advance "first" straight to 6. Min is now 5 but there is no gap so
# we expect it to be 6
id_gen.advance("first", 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# No gap, so we expect 7.
id_gen.advance("second", 7)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# We haven't seen 8 yet, so we expect 7 still.
id_gen.advance("second", 9)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# Now that we've seen 7, 8 and 9 we can got straight to 9.
id_gen.advance("first", 8)
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
# Jump forward with gaps. The minimum is 11, even though we haven't seen
# 10 we know that everything before 11 must be persisted.
id_gen.advance("first", 11)
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
def test_get_persisted_upto_position_get_next(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions when `get_next` is called.
"""
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
def test_multi_instance(self) -> None:
"""Test that reads and writes from multiple processes are handled
correctly.
@ -453,145 +615,6 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
)
def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
def _get_next_txn(txn: LoggingTransaction) -> None:
stream_id = id_gen.get_next_txn(txn)
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_get_persisted_upto_position(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions.
"""
# The following tests are a bit cheeky in that we notify about new
# positions via `advance` without *actually* advancing the postgres
# sequence.
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("worker", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
# Min is 3 and there is a gap between 5, so we expect it to be 3.
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
# We advance "first" straight to 6. Min is now 5 but there is no gap so
# we expect it to be 6
id_gen.advance("first", 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# No gap, so we expect 7.
id_gen.advance("second", 7)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# We haven't seen 8 yet, so we expect 7 still.
id_gen.advance("second", 9)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# Now that we've seen 7, 8 and 9 we can got straight to 9.
id_gen.advance("first", 8)
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
# Jump forward with gaps. The minimum is 11, even though we haven't seen
# 10 we know that everything before 11 must be persisted.
id_gen.advance("first", 11)
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
def test_get_persisted_upto_position_get_next(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions when `get_next` is called.
"""
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
def test_restart_during_out_of_order_persistence(self) -> None:
"""Test that restarting a process while another process is writing out
of order updates are handled correctly.
"""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Persist two rows at once
ctx1 = id_gen.get_next()
ctx2 = id_gen.get_next()
s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__())
self.assertEqual(s1, 8)
self.assertEqual(s2, 9)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# We finish persisting the second row before restart
self.get_success(ctx2.__aexit__(None, None, None))
# We simulate a restart of another worker by just creating a new ID gen.
id_gen_worker = self._create_id_generator("worker")
# Restarted worker should not see the second persisted row
self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
# Now if we persist the first row then both instances should jump ahead
# correctly.
self.get_success(ctx1.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
id_gen_worker.advance("master", 9)
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
def test_writer_config_change(self) -> None:
"""Test that changing the writer config correctly works."""